diff --git a/checks.sh b/checks.sh index 6fa3e08..eede3bc 100755 --- a/checks.sh +++ b/checks.sh @@ -31,7 +31,9 @@ for plugin in ./src/synack/plugins/*.py; do if [[ $? != 0 ]]; then grep "def ${def}(" ${plugin} -B1 | grep "@property" > /dev/null 2>&1 if [[ $? != 0 ]]; then - echo ${p} missing documentation for: ${def} + if [[ "${def}" != "_"* ]]; then + echo ${p} missing documentation for: ${def} + fi fi fi done @@ -59,6 +61,6 @@ for doc in ./docs/src/usage/plugins/*.md; do fi done -coverage run --source=src --omit=src/synack/db/alembic/env.py,src/synack/db/alembic/versions/*.py -m unittest discover test -coverage report | egrep -v "^[^T].*100%" -coverage html +python3-coverage run --source=src --omit=src/synack/db/alembic/env.py,src/synack/db/alembic/versions/*.py -m unittest discover test +python3-coverage report | egrep -v "^[^T].*100%" +python3-coverage html diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index bff259c..276ab62 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -18,8 +18,8 @@ - [Api](./usage/plugins/api.md) - [Auth](./usage/plugins/auth.md) - [Db](./usage/plugins/db.md) + - [Duo](./usage/plugins/duo.md) - [Debug](./usage/plugins/debug.md) - - [Hydra](./usage/plugins/hydra.md) - [Missions](./usage/plugins/missions.md) - [Notifications](./usage/plugins/notifications.md) - [Scratchspace](./usage/plugins/scratchspace.md) diff --git a/docs/src/usage/index.md b/docs/src/usage/index.md index 1031492..f343803 100644 --- a/docs/src/usage/index.md +++ b/docs/src/usage/index.md @@ -17,12 +17,8 @@ With that in mind, I would highly recommend you become familiar with the [Plugin ## Authentication The first time you try to do anything which requires authentication, you will be automatically prompted for your credentials. -This prompt will expect the `Synack Email` and `Synack Password`, which are fairly self explanitory, but it also asks for the `Synack OTP Secret`. +This prompt will expect the `Synack Email` and `Synack Password`, which are fairly self explanatory. -The `Synack OTP Secret` is NOT the 8 digit code you pull out of Authy. -Instead, it is a string that you must extract from Authy via a method similar to the one found [here](https://gist.github.com/gboudreau/94bb0c11a6209c82418d01a59d958c93). - -Use the above instructions at your own discression. -I TAKE NO RESPONSIBILITY IF SOMETHING BAD HAPPENS AS A RESULT. +For Duo MFA setup options, see the [Duo plugin documentation](./plugins/duo.md). Once you complete these steps, your credentials are stored in a SQLiteDB at `~/.config/synack/synackapi.db`. diff --git a/docs/src/usage/main-components/files.md b/docs/src/usage/main-components/files.md index 5d1a5c2..34eedb0 100644 --- a/docs/src/usage/main-components/files.md +++ b/docs/src/usage/main-components/files.md @@ -34,6 +34,7 @@ It is intended to be used with the following TamperMonkey script in order to do // @version 0.1 // @description Go to the platform automatically // @author You +// @run-at document-start // @match https://*.synack.com/* // @require file:///home//.config/synack/login.js // @grant none diff --git a/docs/src/usage/main-components/state.md b/docs/src/usage/main-components/state.md index 21df449..4f8fc90 100644 --- a/docs/src/usage/main-components/state.md +++ b/docs/src/usage/main-components/state.md @@ -55,12 +55,16 @@ In the event that one of the State variables is set and is **not** constantly at | api_token | str | This is the Synack Access Token used to authenticate requests | config_dir | pathlib.Path | The location of the Database and Login script | debug | bool | Used to show/hide debugging messages +| duo_push_akey | str | Duo device activation key for push auto-approval +| duo_push_host | str | Duo API hostname for push auto-approval +| duo_push_pkey | str | Duo device private key for push auto-approval +| duo_push_rsa_key_path | str | Path to RSA private key for signing Duo API requests | email | str | Your email address used to log into Synack | http_proxy | str | A Web Proxy (Burp, etc.) to intercept requests | https_proxy | str | A Web Proxy (Burp, etc.) to intercept requests | login | bool | Used to enable/disable a check of the api_token upon creation of the Handler | notifications_token | str | Token used for authentication when dealing with Synack Notifications -| otp_secret | str | OTP Secret held by Authy. NOT an OTP. For more information, read the Usage page +| otp_secret | str | OTP Secret held by Duo Mobile. NOT an OTP. For more information, read the Usage page | password | str | Your Synack Password | session | requests.Session | Tracks cookies and headers across various functions | template_dir | pathlib.Path | The location of your Mission Templates diff --git a/docs/src/usage/plugins/auth.md b/docs/src/usage/plugins/auth.md index dfd362f..83cbad0 100644 --- a/docs/src/usage/plugins/auth.md +++ b/docs/src/usage/plugins/auth.md @@ -2,16 +2,6 @@ This plugin deals with authenticating the user to Synack. -## auth.build_otp() - -> Use your stored otp_secret to generate a current OTP code -> ->> Examples ->> ```python3 ->> >>> h.auth.build_otp() ->> '1234567' ->> ``` - ## auth.get_api_token() > Walks through the whole authentication workflow to get a new api_token @@ -22,31 +12,29 @@ This plugin deals with authenticating the user to Synack. >> '489hr98hf...eh59' >> ``` -## auth.get_login_csrf() +## auth.get_authentication_response(csrf) -> Pulls a CSRF Token from the Login page +> Send the username and password to Synack and returns the response +> +> | Arguments | Description +> | --- | --- +> | `csrf` | CSRF token issued by Synack Authentication Workflow > >> Examples >> ```python3 ->> >>> h.auth.get_login_csrf() ->> '45h998h4g5...45wh89g9wh' +>> >>> csrf = h.auth.get_login_csrf() +>> >>> h.auth.get_authentication_response(csrf) +>> {'success': True, ..., 'duo_auth_url': 'https://...'} >> ``` -## auth.get_login_grant_token(csrf, progress_token) +## auth.get_login_csrf() -> Get a Login Grant Token by providing an OTP Code -> -> | Argument | Type | Description -> | --- | --- | --- -> | `csrf` | str | A CSRF Token used while logging in -> | `progress_token` | str | A token returned after submitting a valid username and password +> Pulls a CSRF Token from the Login page > >> Examples >> ```python3 ->> >>> csrf = h.auth.get_login_csrf() ->> >>> lpt = h.auth.get_login_progress_token(csrf) ->> >>> h.auth.get_login_grant_token(csrf, lpt) ->> '58t7i...rh87g58' +>> >>> h.auth.get_login_csrf() +>> '45h998h4g5...45wh89g9wh' >> ``` ## auth.get_login_progress_token(csrf) @@ -74,6 +62,15 @@ This plugin deals with authenticating the user to Synack. >> '958htiu...h98f5ht' >> ``` +## auth.set_api_token_invalid() + +> Invalidates the API Token by logging out +> +>> Examples +>> ```python3 +>> >>> h.auth.set_api_token_invalid() +>> ``` + ## auth.set_login_script() > Writes the current api_token to `~/.config/synack/login.js` JavaScript file to help with staying logged in. diff --git a/docs/src/usage/plugins/db.md b/docs/src/usage/plugins/db.md index 64dbc9d..c77e6df 100644 --- a/docs/src/usage/plugins/db.md +++ b/docs/src/usage/plugins/db.md @@ -82,7 +82,7 @@ Additionally, some properties can be overridden by the State, which allows you t > > | Arguments | Type | Description > | --- | --- | --- -> | `results` | list(dict) | A list of dictionaries containing results from some scan, Hydra, etc. +> | `results` | list(dict) | A list of dictionaries containing results from some scan, etc. > >> Examples >> ```python3 @@ -96,11 +96,8 @@ Additionally, some properties can be overridden by the State, which allows you t >> ... "port": "443", >> ... "protocol": "tcp", >> ... "service": "Super Apache NGINX Deluxe", ->> ... "screenshot_url": "http://127.0.0.1/h3298h23.png", ->> ... "url": "http://bubba.net", >> ... "open": True, >> ... "updated": 1654969137 ->> ... >> ... }, >> ... { >> ... "port": "53", @@ -132,7 +129,7 @@ Additionally, some properties can be overridden by the State, which allows you t > > | Arguments | Type | Description > | --- | --- | --- -> | `results` | list(dict) | A list of dictionaries containing results from some scan, Hydra, etc. +> | `results` | list(dict) | A list of dictionaries containing results from some scan, etc. > >> Examples >> ```python3 @@ -178,7 +175,7 @@ Additionally, some properties can be overridden by the State, which allows you t > | --- | --- | --- > | `port` | int | Port number to search for (443, 80, 25, etc.) > | `protocol` | str | Protocol to search for (tcp, udp, etc.) -> | `source` | str | Source to search for (hydra, nmap, etc.) +> | `source` | str | Source to search for (masscan, nmap, etc.) > | `ip` | str | IP Address to search for > | `kwargs` | kwargs | Any attribute of the Target Database Model (codename, slug, is_active, etc.) > @@ -187,7 +184,7 @@ Additionally, some properties can be overridden by the State, which allows you t >> >>> h.db.find_ports(codename="SLEEPYPUPPY") >> [ >> { ->> 'ip': '1.2.3.4', 'source': 'hydra', 'target': '123hg912', +>> 'ip': '1.2.3.4', 'source': 'masscan', 'target': '123hg912', >> 'ports': [ >> { 'open': True, 'port': '443', 'protocol': 'tcp', 'service': 'https - Wordpress', 'updated': 1654840021 }, >> ... diff --git a/docs/src/usage/plugins/duo.md b/docs/src/usage/plugins/duo.md new file mode 100644 index 0000000..990f4ef --- /dev/null +++ b/docs/src/usage/plugins/duo.md @@ -0,0 +1,86 @@ +# Duo + +## Duo MFA Options + +When prompted during authentication, you can choose from three options: + +**Option 1: Manual Push Approval (Simplest)** +- Press Enter when prompted for OTP Secret +- Approve push notifications on your phone each time the token is expired +- No additional setup required + +**Option 2: Automated OTP (Preferred)** +- Enter your OTP Secret when prompted (accepts both hex and base32 formats) +- Automatically generates OTP codes using a counter (saved in the database) +- Extract the `hotp_secret` from Duo Mobile using [synackDUO](https://github.com/dinosn/synackDUO) (see `response.json`) +- **Note:** This is NOT the 8-digit codes from Duo Mobile, but the HOTP secret key + +**Option 3: Automated Duo Push** +- Uses Duo credentials to auto-approve push requests +- Can also approve push requests using duo.approve_pending_push(timeout) +- Extract credentials using [synackDUO](https://github.com/dinosn/synackDUO) (see `response.json`) (see below) + + +**Disclaimer:** Use the above instructions at your own discretion. I TAKE NO RESPONSIBILITY IF SOMETHING BAD HAPPENS AS A RESULT. + +## Duo Push Auto-Approval Setup + +The Duo plugin supports push notification approval using device credentials. + +### Prerequisites + +You must extract and configure four credentials from Duo Mobile: + +| Credential | Description | Example +| --- | --- | --- +| `duo_push_akey` | Device activation key | `DAXXXXXXXXXXXXXXXXXXXX` +| `duo_push_pkey` | Device private key | `DPXXXXXXXXXXXXXXXXXXXX` +| `duo_push_host` | Duo API hostname | `api-xxxxxxxx.duosecurity.com` +| `duo_push_rsa_key_path` | Path to RSA private key | `~/.config/synack/duo/key.pem` + +### Configuration + +Set credentials in the database: +PP +```python +import synack + +h = synack.Handler(login=False) + +h.db.set_config('duo_push_akey', 'DAXXXXXXXXXXXXXXXXXX') +h.db.set_config('duo_push_pkey', 'DPXXXXXXXXXXXXXXXXXX') +h.db.set_config('duo_push_host', 'api-xxxxxxxx.duosecurity.com') +h.db.set_config('duo_push_rsa_key_path', 'synackDUO/key.pem') +``` + +## duo.get_grant_token(auth_url) + +> Handles Duo Security MFA stages and returns the grant_token used to finish logging into Synack +> +> | Arguments | Description +> | --- | --- +> | `auth_url` | Duo Security Authentication URL generaated by sending credentials to Synack +> +>> Examples +>> ```python3 +>> >>> h.duo.get_grant_token('https:///...duosecurity.com/...') +>> 'Y8....6g' +>> ``` + +## duo.approve_pending_push(timeout) + +> Wait for and approve a single Duo push notification +> +> Polls Duo's device API for pending push notifications and automatically approves the first one found. Useful for automated workflows that need to handle Duo MFA. +> +> | Argument | Type | Default | Description +> | --- | --- | --- | --- +> | `timeout` | int | 30 | Maximum seconds to wait for a push notification +> +> Returns `True` if a push was approved, `False` if timeout or error occurred. +> +>> Examples +>> ```python3 +>> >>> h.duo.approve_pending_push(timeout=60) +>> True +>> ``` diff --git a/docs/src/usage/plugins/hydra.md b/docs/src/usage/plugins/hydra.md deleted file mode 100644 index 8ca26a6..0000000 --- a/docs/src/usage/plugins/hydra.md +++ /dev/null @@ -1,38 +0,0 @@ -# Hydra - -## hydra.build_db_input() - -> Builds a list of ports ready to be ingested by the Database from Hydra output -> ->> Examples ->> ```python3 ->> >>> h.hydra.build_db_input(h.hydra.get_hydra(codename='SLEEPYPUPPY', update_db=False)) ->> [ ->> { ->> 'ip': '1.2.3.4', 'source': 'hydra', 'target': '123hg912', ->> 'ports': [ ->> { 'open': True, 'port': '443', 'protocol': 'tcp', 'service': 'https - Wordpress', 'updated': 1654840021 }, ->> ... ->> ] ->> }, ->> ... ->> ] ->> ``` - -## hydra.get_hydra(page, max_page, update_db, **kwargs) - -> Returns information from Synack Hydra Service -> -> | Arguments | Type | Description -> | --- | --- | --- -> | `page` | int | Page of the Hydra Service to start on (Default: 1) -> | `max_page` | int | Highest page that should be queried (Default: 5) -> | `update_db` | bool | Store the results in the database -> ->> Examples ->> ```python3 ->> >>> h.hydra.get_hydra(codename='SLEEPYPUPPY') ->> [{'host_plugins': {}, 'ip': '1.2.3.4', 'last_changed_dt': '2022-01-01T01:02:03Z', ... }, ... ] ->> >>> h.hydra.get_hydra(codename='SLEEPYPUPPY', page=3, max_page=5, update_db=False) ->> [{'host_plugins': {}, 'ip': '3.4.5.6', 'last_changed_dt': '2022-01-01T01:02:03Z', ... }, ... ] ->> ``` diff --git a/docs/src/usage/plugins/notifications.md b/docs/src/usage/plugins/notifications.md index f34f50e..1a319e9 100644 --- a/docs/src/usage/plugins/notifications.md +++ b/docs/src/usage/plugins/notifications.md @@ -12,10 +12,19 @@ ## notifications.get_unread_count() -> Get the number of unread notifications. +> Get the number of unread notifications > >> Examples >> ```python3 >> >>> h.notifications.get_unread_count() >> 7 >> ``` + +## notifications.set_read() + +> Set all notifications as read +> +>> Examples +>> ```python3 +>> >>> h.notifications.set_read() +>> ``` diff --git a/docs/src/usage/plugins/scratchspace.md b/docs/src/usage/plugins/scratchspace.md index fd4a0a8..5a117b7 100644 --- a/docs/src/usage/plugins/scratchspace.md +++ b/docs/src/usage/plugins/scratchspace.md @@ -18,7 +18,7 @@ ## scratchspace.set_assets_file(content, target=None, codename=None) -> This function will save a `assets.txt` scope file within a `codename` folder in within the `self.db.scratchspace_dir` folder +> This function will save a `assets.txt` scope file within a `codename` folder within the `self.db.scratchspace_dir` folder > If `self.db.use_scratchspace` is `True`, this function is automatically run when you do `targets.get_scope()` or `targets.get_scope_web()` > > | Arguments | Type | Description @@ -36,7 +36,7 @@ ## scratchspace.set_burp_file(content, target=None, codename=None) -> This function will save a `burp.txt` scope file within a `codename` folder in within the `self.db.scratchspace_dir` folder +> This function will save a `burp.txt` scope file within a `codename` folder within the `self.db.scratchspace_dir` folder > If `self.db.use_scratchspace` is `True`, this function is automatically run when you do `targets.get_scope()` or `targets.get_scope_web()` > > | Arguments | Type | Description @@ -82,9 +82,27 @@ >> [PosixPath('/home/user/Scratchspace/SLEEPYTURTLE/file1.txt'), ...] >> ``` +## scratchspace.set_file(content, filename, target=None, codename=None) + +> This function will save a text file with a given name within a `codename` folder within the `sself.db.scratchspace_dir` folder. +> If `self.db.use_scratchspace` is `True`, this function is automatically run when you do `targets.get_scope()` or similar. +> +> | Arguments | Type | Description +> | --- | --- | --- +> | `content` | str,list(str) | Either a preformatted string or (more likely) the return of `targets.get_scope_host()` +> | `filename` | str | Desired filename +> | `target` | db.models.Target | A Target Database Object +> | `codename` | str | Codename of a Target +> +>> Examples +>> ```python3 +>> >>> h.scratchspace.set_file('some unique string', 'mystring.txt', codename='ADAMANTARDVARK') +>> '/tmp/Scratchspace/ADAMANTARDVARK/mystring.txt' +>> ``` + ## scratchspace.set_hosts_file(content, target=None, codename=None) -> This function will save a `hosts.txt` scope file within a `codename` folder in within the `self.db.scratchspace_dir` folder. +> This function will save a `hosts.txt` scope file within a `codename` folder within the `self.db.scratchspace_dir` folder. > If `self.db.use_scratchspace` is `True`, this function is automatically run when you do `targets.get_scope()` or `targets.get_scope_host()` > > | Arguments | Type | Description diff --git a/docs/src/usage/plugins/targets.md b/docs/src/usage/plugins/targets.md index 2135395..96a37d5 100644 --- a/docs/src/usage/plugins/targets.md +++ b/docs/src/usage/plugins/targets.md @@ -210,7 +210,7 @@ >> [{"credentials": [{...},...],...}] >> ``` -## targets.get_query(status='registered', query_changes={}) +## targets.get(status='registered', query_changes={}) > Pulls back a list of targets matching the specified query > @@ -221,7 +221,7 @@ > >> Examples >> ```python3 ->> >>> h.targets.get_query(status='unregistered') +>> >>> h.targets.get(status='unregistered') >> [{"codename": "SLEEPYSLUG", ...}, ...] >> ``` diff --git a/docs/src/usage/plugins/utils.md b/docs/src/usage/plugins/utils.md new file mode 100644 index 0000000..7a6971a --- /dev/null +++ b/docs/src/usage/plugins/utils.md @@ -0,0 +1,17 @@ +# Utils + +## utils.get_html_tag_value(field, text) + +> Looks for an HTML tag in raw HTML and returns its value +> +> | Arguments | Description +> | --- | --- +> | `field` | name of HTML field to find value for +> | `text` | raw HTML content +> +>> Examples +>> ```python3 +>> >>> html = '......' +>> >>> h.utils.get_html_tag_value('tacos', html) +>> 'tasty' +>> ``` diff --git a/setup.py b/setup.py index ce5bf3b..9537d4b 100755 --- a/setup.py +++ b/setup.py @@ -26,14 +26,14 @@ }, package_dir={'': 'src'}, install_requires=[ - "alembic==1.8.1", - "netaddr==0.8.0", - "pathlib2==2.3.6", - "psycopg2-binary==2.9.5", - "pyaml==21.10.1", - "pyotp==2.7.0", - "requests==2.28.1", - "SQLAlchemy==1.4.44", - "urllib3==1.26.13", + "alembic>=1.8.1", + "netaddr>=0.8.0", + "pathlib2>=2.3.6", + "psycopg2-binary>=2.9.5", + "pyaml>=21.10.1", + "pyotp>=2.7.0", + "requests>=2.28.1", + "SQLAlchemy>=1.4.44", + "urllib3>=1.26.13", ] ) diff --git a/src/synack/_handler.py b/src/synack/_handler.py index 914b57a..b29fd9a 100644 --- a/src/synack/_handler.py +++ b/src/synack/_handler.py @@ -11,16 +11,18 @@ class Handler: def __init__(self, state=State(), **kwargs): self.state = state + for name, subclass in Plugin._registry.items(): + instance = subclass(self.state) + setattr(self, name.lower(), instance) + + self.state._db = self.db + for key in kwargs.keys(): if hasattr(self.state, key): setattr(self.state, key, kwargs.get(key)) - for name, subclass in Plugin.registry.items(): - instance = subclass(self.state) - setattr(self, name.lower(), instance) - - self.login() + self._login() - def login(self): + def _login(self): if self.state.login: self.auth.get_api_token() diff --git a/src/synack/_state.py b/src/synack/_state.py index d780505..ff09c87 100644 --- a/src/synack/_state.py +++ b/src/synack/_state.py @@ -11,7 +11,9 @@ class State(object): def __init__(self): + self._api_token = None self._config_dir = None + self._db = None self._debug = None self._email = None self._http_proxy = None @@ -19,14 +21,119 @@ def __init__(self): self._login = None self._notifications_token = None self._otp_secret = None + self._otp_count = None self._password = None self._proxies = None + self._scratchspace_dir = None self._session = None + self._slack_app_token = None + self._slack_channel = None + self._slack_url = None + self._smtp_email_from = None + self._smtp_email_to = None + self._smtp_password = None + self._smtp_port = None + self._smtp_server = None + self._smtp_starttls = None + self._smtp_username = None + self._synack_domain = None self._template_dir = None - self._scratchspace_dir = None self._use_proxies = None self._use_scratchspace = None self._user_id = None + self._duo_push_akey = None + self._duo_push_pkey = None + self._duo_push_host = None + self._duo_push_rsa_key_path = None + self._duo_device = None + + @property + def smtp_email_from(self) -> str: + ret = self._smtp_email_from + if ret is None: + ret = self._db.smtp_email_from + return ret + + @smtp_email_from.setter + def smtp_email_from(self, value: str) -> None: + self._smtp_email_from = value + + @property + def smtp_email_to(self) -> str: + ret = self._smtp_email_to + if ret is None: + ret = self._db.smtp_email_to + return ret + + @smtp_email_to.setter + def smtp_email_to(self, value: str) -> None: + self._smtp_email_to = value + + @property + def smtp_password(self) -> str: + ret = self._smtp_password + if ret is None: + ret = self._db.smtp_password + return ret + + @smtp_password.setter + def smtp_password(self, value: str) -> None: + self._smtp_password = value + + @property + def smtp_port(self) -> str: + ret = self._smtp_port + if ret is None: + ret = self._db.smtp_port + return ret + + @smtp_port.setter + def smtp_port(self, value: str) -> None: + self._smtp_port = value + + @property + def smtp_server(self) -> str: + ret = self._smtp_server + if ret is None: + ret = self._db.smtp_server + return ret + + @smtp_server.setter + def smtp_server(self, value: str) -> None: + self._smtp_server = value + + @property + def smtp_starttls(self) -> str: + ret = self._smtp_starttls + if ret is None: + ret = self._db.smtp_starttls + return ret + + @smtp_starttls.setter + def smtp_starttls(self, value: str) -> None: + self._smtp_starttls = value + + @property + def smtp_username(self) -> str: + ret = self._smtp_username + if ret is None: + ret = self._db.smtp_username + return ret + + @smtp_username.setter + def smtp_username(self, value: str) -> None: + self._smtp_username = value + + @property + def api_token(self) -> str: + ret = self._api_token + if ret is None: + ret = self._db.api_token + return ret + + @api_token.setter + def api_token(self, value: str) -> None: + self._api_token = value @property def config_dir(self) -> pathlib.PosixPath: @@ -42,11 +149,23 @@ def config_dir(self, value: Union[str, pathlib.PosixPath]) -> None: value = pathlib.Path(value).expanduser().resolve() self._config_dir = value + @property + def synack_domain(self): + ret = self._synack_domain + if ret is None: + ret = self._db.synack_domain + return ret + + @synack_domain.setter + def synack_domain(self, value): + self._synack_domain = value + @property def template_dir(self) -> pathlib.PosixPath: ret = self._template_dir - if ret: - ret.mkdir(parents=True, exist_ok=True) + if ret is None: + ret = self._db.template_dir + ret.mkdir(parents=True, exist_ok=True) return ret @template_dir.setter @@ -58,8 +177,9 @@ def template_dir(self, value: Union[str, pathlib.PosixPath]) -> None: @property def scratchspace_dir(self) -> pathlib.PosixPath: ret = self._scratchspace_dir - if ret: - ret.mkdir(parents=True, exist_ok=True) + if ret is None: + ret = self._db.scratchspace_dir + ret.mkdir(parents=True, exist_ok=True) return ret @scratchspace_dir.setter @@ -70,7 +190,10 @@ def scratchspace_dir(self, value: Union[str, pathlib.PosixPath]) -> None: @property def debug(self) -> bool: - return self._debug + ret = self._debug + if ret is None: + ret = self._db.debug + return ret @debug.setter def debug(self, value: bool) -> None: @@ -90,9 +213,23 @@ def login(self) -> bool: def login(self, value: bool) -> None: self._login = value + @property + def notifications_token(self) -> str: + ret = self._notifications_token + if ret is None: + ret = self._db.notifications_token + return ret + + @notifications_token.setter + def notifications_token(self, value: str) -> None: + self._notifications_token = value + @property def use_proxies(self) -> bool: - return self._use_proxies + ret = self._use_proxies + if ret is None: + ret = self._db.use_proxies + return ret @use_proxies.setter def use_proxies(self, value: bool) -> None: @@ -100,7 +237,10 @@ def use_proxies(self, value: bool) -> None: @property def use_scratchspace(self) -> bool: - return self._use_scratchspace + ret = self._use_scratchspace + if ret is None: + ret = self._db.use_scratchspace + return ret @use_scratchspace.setter def use_scratchspace(self, value: bool) -> None: @@ -108,7 +248,10 @@ def use_scratchspace(self, value: bool) -> None: @property def http_proxy(self) -> str: - return self._http_proxy + ret = self._http_proxy + if ret is None: + ret = self._db.http_proxy + return ret @http_proxy.setter def http_proxy(self, value: str) -> None: @@ -116,7 +259,10 @@ def http_proxy(self, value: str) -> None: @property def https_proxy(self) -> str: - return self._https_proxy + ret = self._https_proxy + if ret is None: + ret = self._db.https_proxy + return ret @https_proxy.setter def https_proxy(self, value: str) -> None: @@ -131,23 +277,76 @@ def proxies(self) -> dict(): @property def otp_secret(self) -> str: - return self._otp_secret + ret = self._otp_secret + if ret is None: + ret = self._db.otp_secret + return ret @otp_secret.setter def otp_secret(self, value: str) -> None: self._otp_secret = value + @property + def otp_count(self) -> str: + ret = self._otp_count + if ret is None: + ret = self._db.otp_count + return ret + + @otp_count.setter + def otp_count(self, value: int) -> None: + self._otp_count = value + @property def email(self) -> str: - return self._email + ret = self._email + if ret is None: + ret = self._db.email + return ret @email.setter def email(self, value: str) -> None: self._email = value + @property + def slack_app_token(self) -> str: + ret = self._slack_app_token + if ret is None: + ret = self._db.slack_app_token + return ret + + @slack_app_token.setter + def slack_app_token(self, value: str) -> None: + self._slack_app_token = value + + @property + def slack_channel(self) -> str: + ret = self._slack_channel + if ret is None: + ret = self._db.slack_channel + return ret + + @slack_channel.setter + def slack_channel(self, value: str) -> None: + self._slack_channel = value + + @property + def slack_url(self) -> str: + ret = self._slack_url + if ret is None: + ret = self._db.slack_url + return ret + + @slack_url.setter + def slack_url(self, value: str) -> None: + self._slack_url = value + @property def password(self) -> str: - return self._password + ret = self._password + if ret is None: + ret = self._db.password + return ret @password.setter def password(self, value: str) -> None: @@ -155,8 +354,66 @@ def password(self, value: str) -> None: @property def user_id(self) -> str: - return self._user_id + ret = self._user_id + if ret is None: + ret = self._db.user_id + return ret @user_id.setter def user_id(self, value: str) -> None: self._user_id = value + + @property + def duo_push_akey(self) -> str: + ret = self._duo_push_akey + if ret is None: + ret = self._db.duo_push_akey + return ret + + @duo_push_akey.setter + def duo_push_akey(self, value: str) -> None: + self._duo_push_akey = value + + @property + def duo_push_pkey(self) -> str: + ret = self._duo_push_pkey + if ret is None: + ret = self._db.duo_push_pkey + return ret + + @duo_push_pkey.setter + def duo_push_pkey(self, value: str) -> None: + self._duo_push_pkey = value + + @property + def duo_push_host(self) -> str: + ret = self._duo_push_host + if ret is None: + ret = self._db.duo_push_host + return ret + + @duo_push_host.setter + def duo_push_host(self, value: str) -> None: + self._duo_push_host = value + + @property + def duo_push_rsa_key_path(self) -> str: + ret = self._duo_push_rsa_key_path + if ret is None: + ret = self._db.duo_push_rsa_key_path + return ret + + @duo_push_rsa_key_path.setter + def duo_push_rsa_key_path(self, value: str) -> None: + self._duo_push_rsa_key_path = value + + @property + def duo_device(self) -> str: + ret = self._duo_device + if ret is None: + ret = self._db.duo_device + return ret + + @duo_device.setter + def duo_device(self, value: str) -> None: + self._duo_device = value diff --git a/src/synack/db/alembic/versions/1434aa7ed47c_add_synack_domain.py b/src/synack/db/alembic/versions/1434aa7ed47c_add_synack_domain.py new file mode 100644 index 0000000..454d08e --- /dev/null +++ b/src/synack/db/alembic/versions/1434aa7ed47c_add_synack_domain.py @@ -0,0 +1,25 @@ +"""add synack_domain + +Revision ID: 1434aa7ed47c +Revises: 6814001a4ed4 +Create Date: 2025-02-06 04:19:28.655055 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1434aa7ed47c' +down_revision = '6814001a4ed4' +branch_labels = None +depends_on = None + +def upgrade(): + with op.batch_alter_table('config') as batch_op: + batch_op.add_column(sa.Column('synack_domain', sa.VARCHAR(100), server_default='synack.com')) + + +def downgrade(): + with op.batch_alter_table('config') as batch_op: + batch_op.drop_column('synack_domain') diff --git a/src/synack/db/alembic/versions/6814001a4ed4_add_unique_ips_constraint.py b/src/synack/db/alembic/versions/6814001a4ed4_add_unique_ips_constraint.py new file mode 100644 index 0000000..00058de --- /dev/null +++ b/src/synack/db/alembic/versions/6814001a4ed4_add_unique_ips_constraint.py @@ -0,0 +1,25 @@ +"""add unique ips constraint + +Revision ID: 6814001a4ed4 +Revises: 753c42281f78 +Create Date: 2025-01-26 05:19:35.150476 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = '6814001a4ed4' +down_revision = '753c42281f78' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('ips') as batch_op: + batch_op.create_unique_constraint('uq_ip', ['ip', 'target']) + + +def downgrade(): + with op.batch_alter_table('ips') as batch_op: + batch_op.drop_constraint('uq_ip', type_='unique') diff --git a/src/synack/db/alembic/versions/6f542023f57e_add_duo_push_method.py b/src/synack/db/alembic/versions/6f542023f57e_add_duo_push_method.py new file mode 100644 index 0000000..0fdde9a --- /dev/null +++ b/src/synack/db/alembic/versions/6f542023f57e_add_duo_push_method.py @@ -0,0 +1,38 @@ +"""add duo push method + +Revision ID: 6f542023f57e +Revises: 1434aa7ed47c +Create Date: 2025-11-06 08:23:42.181054 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '6f542023f57e' +down_revision = '1434aa7ed47c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('config') as batch_op: + batch_op.add_column(sa.Column('duo_push_akey', sa.VARCHAR(length=200), nullable=True)) + batch_op.add_column(sa.Column('duo_push_pkey', sa.VARCHAR(length=200), nullable=True)) + batch_op.add_column(sa.Column('duo_push_host', sa.VARCHAR(length=100), nullable=True)) + batch_op.add_column(sa.Column('duo_push_rsa_key_path', sa.VARCHAR(length=250), nullable=True)) + batch_op.add_column(sa.Column('duo_device', sa.VARCHAR(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('config') as batch_op: + batch_op.drop_column('duo_device') + batch_op.drop_column('duo_push_rsa_key_path') + batch_op.drop_column('duo_push_host') + batch_op.drop_column('duo_push_pkey') + batch_op.drop_column('duo_push_akey') + # ### end Alembic commands ### diff --git a/src/synack/db/alembic/versions/753c42281f78_add_unique_ports_constraint.py b/src/synack/db/alembic/versions/753c42281f78_add_unique_ports_constraint.py new file mode 100644 index 0000000..7401f81 --- /dev/null +++ b/src/synack/db/alembic/versions/753c42281f78_add_unique_ports_constraint.py @@ -0,0 +1,25 @@ +"""add unique ports constraint + +Revision ID: 753c42281f78 +Revises: c2e6de9ffc5e +Create Date: 2025-01-26 05:07:23.252004 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = '753c42281f78' +down_revision = 'c2e6de9ffc5e' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('ports') as batch_op: + batch_op.create_unique_constraint('uq_port', ['port', 'protocol', 'ip', 'source']) + + +def downgrade(): + with op.batch_alter_table('ports') as batch_op: + batch_op.drop_constraint('uq_port', type_='unique') diff --git a/src/synack/db/alembic/versions/8b478a84c1a6_add_organization_name.py b/src/synack/db/alembic/versions/8b478a84c1a6_add_organization_name.py new file mode 100644 index 0000000..1bf270f --- /dev/null +++ b/src/synack/db/alembic/versions/8b478a84c1a6_add_organization_name.py @@ -0,0 +1,26 @@ +"""add organization name + +Revision ID: 8b478a84c1a6 +Revises: 6f542023f57e +Create Date: 2025-02-11 13:05:40.939271 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '8b478a84c1a6' +down_revision = '6f542023f57e' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('organizations') as batch_op: + batch_op.add_column(sa.Column('name', sa.VARCHAR(100))) + + +def downgrade(): + with op.batch_alter_table('organizations') as batch_op: + batch_op.drop_column('name') diff --git a/src/synack/db/alembic/versions/c2e6de9ffc5e_add_otp_count.py b/src/synack/db/alembic/versions/c2e6de9ffc5e_add_otp_count.py new file mode 100644 index 0000000..4f92c4a --- /dev/null +++ b/src/synack/db/alembic/versions/c2e6de9ffc5e_add_otp_count.py @@ -0,0 +1,26 @@ +"""add otp_count + +Revision ID: c2e6de9ffc5e +Revises: f627018b273f +Create Date: 2025-01-11 22:29:05.822904 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'c2e6de9ffc5e' +down_revision = 'f627018b273f' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('config') as batch_op: + batch_op.add_column(sa.Column('otp_count', sa.INTEGER, server_default='0')) + + +def downgrade(): + with op.batch_alter_table('config') as batch_op: + batch_op.drop_column('otp_count') diff --git a/src/synack/db/alembic/versions/f627018b273f_add_slack_app_variables.py b/src/synack/db/alembic/versions/f627018b273f_add_slack_app_variables.py new file mode 100644 index 0000000..712680f --- /dev/null +++ b/src/synack/db/alembic/versions/f627018b273f_add_slack_app_variables.py @@ -0,0 +1,28 @@ +"""add slack app variables + +Revision ID: f627018b273f +Revises: 349c447c0d37 +Create Date: 2025-01-06 20:44:52.383303 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f627018b273f' +down_revision = '349c447c0d37' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('config') as batch_op: + batch_op.add_column(sa.Column('slack_app_token', sa.VARCHAR(100), server_default='')) + batch_op.add_column(sa.Column('slack_channel', sa.VARCHAR(100), server_default='')) + + +def downgrade(): + with op.batch_alter_table('config') as batch_op: + batch_op.drop_column('slack_app_token') + batch_op.drop_column('slack_channel') diff --git a/src/synack/db/models/category.py b/src/synack/db/models/category.py old mode 100755 new mode 100644 diff --git a/src/synack/db/models/config.py b/src/synack/db/models/config.py old mode 100755 new mode 100644 index e0ac890..216b467 --- a/src/synack/db/models/config.py +++ b/src/synack/db/models/config.py @@ -20,9 +20,12 @@ class Config(Base): login = sa.Column(sa.BOOLEAN, default=True) notifications_token = sa.Column(sa.VARCHAR(1000), default='') otp_secret = sa.Column(sa.VARCHAR(50), default='') + otp_count = sa.Column(sa.INTEGER, default=0) password = sa.Column(sa.VARCHAR(150), default='') scratchspace_dir = sa.Column(sa.VARCHAR(250), default='~/Scratchspace') slack_url = sa.Column(sa.VARCHAR(500), default='') + slack_app_token = sa.Column(sa.VARCHAR(100), default='') + slack_channel = sa.Column(sa.VARCHAR(100), default='') smtp_email_from = sa.Column(sa.VARCHAR(250), default='') smtp_password = sa.Column(sa.VARCHAR(250), default='') smtp_port = sa.Column(sa.INTEGER, default=465) @@ -30,7 +33,15 @@ class Config(Base): smtp_email_to = sa.Column(sa.VARCHAR(250), default='') smtp_username = sa.Column(sa.VARCHAR(250), default='') smtp_starttls = sa.Column(sa.BOOLEAN, default=True) + synack_domain = sa.Column(sa.VARCHAR(100), default='synack.com') template_dir = sa.Column(sa.VARCHAR(250), default='~/Templates') user_id = sa.Column(sa.VARCHAR(20), default='') use_proxies = sa.Column(sa.BOOLEAN, default=False) use_scratchspace = sa.Column(sa.BOOLEAN, default=False) + duo_push_akey = sa.Column(sa.VARCHAR(200), default='') + duo_push_pkey = sa.Column(sa.VARCHAR(200), default='') + duo_push_host = sa.Column(sa.VARCHAR(100), default='') + duo_push_rsa_key_path = sa.Column( + sa.VARCHAR(250), default='' + ) + duo_device = sa.Column(sa.VARCHAR(50), default='') diff --git a/src/synack/db/models/ip.py b/src/synack/db/models/ip.py index a45be8c..83e6c36 100644 --- a/src/synack/db/models/ip.py +++ b/src/synack/db/models/ip.py @@ -15,3 +15,4 @@ class IP(Base): id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) ip = sa.Column(sa.VARCHAR(40)) target = sa.Column(sa.VARCHAR(20), sa.ForeignKey(Target.slug)) + __table_args__ = (sa.UniqueConstraint('ip', 'target', name='uq_ip'),) diff --git a/src/synack/db/models/organization.py b/src/synack/db/models/organization.py old mode 100755 new mode 100644 index c352174..2c67868 --- a/src/synack/db/models/organization.py +++ b/src/synack/db/models/organization.py @@ -12,3 +12,4 @@ class Organization(Base): __tablename__ = 'organizations' slug = sa.Column(sa.VARCHAR(20), primary_key=True) + name = sa.Column(sa.VARCHAR(100)) diff --git a/src/synack/db/models/port.py b/src/synack/db/models/port.py index a592b0e..a7a9044 100644 --- a/src/synack/db/models/port.py +++ b/src/synack/db/models/port.py @@ -20,5 +20,4 @@ class Port(Base): open = sa.Column(sa.BOOLEAN, default=False) service = sa.Column(sa.VARCHAR(200), default="") updated = sa.Column(sa.INTEGER, default=0) - url = sa.Column(sa.VARCHAR(200), default="") - screenshot_url = sa.Column(sa.VARCHAR(1000), default="") + __table_args__ = (sa.UniqueConstraint('port', 'protocol', 'ip', 'source', name='uq_port'),) diff --git a/src/synack/db/models/target.py b/src/synack/db/models/target.py old mode 100755 new mode 100644 diff --git a/src/synack/plugins/__init__.py b/src/synack/plugins/__init__.py index 27450df..e6f9f92 100644 --- a/src/synack/plugins/__init__.py +++ b/src/synack/plugins/__init__.py @@ -5,7 +5,7 @@ from .auth import Auth from .db import Db from .debug import Debug -from .hydra import Hydra +from .duo import Duo from .missions import Missions from .notifications import Notifications from .scratchspace import Scratchspace @@ -13,3 +13,4 @@ from .templates import Templates from .transactions import Transactions from .users import Users +from .utils import Utils diff --git a/src/synack/plugins/alerts.py b/src/synack/plugins/alerts.py index 92086c7..e3ef7b2 100644 --- a/src/synack/plugins/alerts.py +++ b/src/synack/plugins/alerts.py @@ -9,6 +9,7 @@ import re import requests import smtplib +import warnings from .base import Plugin @@ -18,23 +19,23 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Db']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def email(self, subject='Test Alert', message='This is a test'): message += f'\nTime: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}' msg = email.message.EmailMessage() msg.set_content(message) msg['Subject'] = subject - msg['From'] = self.db.smtp_email_from - msg['To'] = self.db.smtp_email_to + msg['From'] = self._state.smtp_email_from + msg['To'] = self._state.smtp_email_to - if self.db.smtp_starttls: - server = smtplib.SMTP_SSL(self.db.smtp_server, self.db.smtp_port) + if self._state.smtp_starttls: + server = smtplib.SMTP_SSL(self._state.smtp_server, self._state.smtp_port) else: - server = smtplib.SMTP(self.db.smtp_server, self.db.smtp_port) + server = smtplib.SMTP(self._state.smtp_server, self._state.smtp_port) - server.login(self.db.smtp_username, self.db.smtp_password) + server.login(self._state.smtp_username, self._state.smtp_password) server.send_message(msg) def sanitize(self, message): @@ -58,7 +59,17 @@ def sanitize(self, message): r'(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))(?=\s|$)', '[IPv6]', message) return message - def slack(self, message='This is a test'): - requests.post(self.db.slack_url, - data=json.dumps({'text': message}), - headers={'Content-Type': 'application/json'}) + def slack(self, message='This is a test', channel=None): + if channel is None: + channel = self._state.slack_channel + warnings.filterwarnings("ignore") + requests.post('https://slack.com/api/chat.postMessage', + data=json.dumps({ + 'text': message, + 'channel': channel, + }), + headers={ + 'Authorization': f'Bearer {self._state.slack_app_token}', + 'Content-Type': 'application/json' + }, + verify=False) diff --git a/src/synack/plugins/api.py b/src/synack/plugins/api.py index 2d8fd2e..19f4ef3 100644 --- a/src/synack/plugins/api.py +++ b/src/synack/plugins/api.py @@ -3,6 +3,7 @@ Functions to handle interacting with the Synack APIs """ +import time import warnings from .base import Plugin @@ -12,7 +13,7 @@ class Api(Plugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Debug', 'Db']: - setattr(self, plugin.lower(), self.registry.get(plugin)(self.state)) + setattr(self, '_'+plugin.lower(), self._registry.get(plugin)(self._state)) def login(self, method, path, **kwargs): """Modify API Request for Login @@ -29,7 +30,7 @@ def login(self, method, path, **kwargs): if path.startswith('http'): base = '' else: - base = 'https://login.synack.com/api/' + base = f'https://login.{self._state.synack_domain}/api/' url = f'{base}{path}' res = self.request(method, url, **kwargs) return res @@ -49,20 +50,20 @@ def notifications(self, method, path, **kwargs): if path.startswith('http'): base = '' else: - base = 'https://notifications.synack.com/api/v2/' + base = f'https://notifications.{self._state.synack_domain}/api/v2/' url = f'{base}{path}' if not kwargs.get('headers'): kwargs['headers'] = dict() - auth = "Bearer " + self.db.notifications_token + auth = "Bearer " + self._state.notifications_token kwargs['headers']['Authorization'] = auth res = self.request(method, url, **kwargs) if res.status_code == 422: - self.db.notifications_token = "" + self._db.notifications_token = '' return res - def request(self, method, path, **kwargs): + def request(self, method, path, attempts=0, **kwargs): """Send API Request Arguments: @@ -70,6 +71,7 @@ def request(self, method, path, **kwargs): (GET, POST, etc.) path -- API endpoint path Can be an endpoint on platform.synack.com or a full URL + attempts -- Number of times the request has been attempted headers -- Additional headers to be added for only this request data -- POST body dictionary query -- GET query string dictionary @@ -77,62 +79,98 @@ def request(self, method, path, **kwargs): if path.startswith('http'): base = '' else: - base = 'https://platform.synack.com/api/' + base = f'https://platform.{self._state.synack_domain}/api/' url = f'{base}{path}' - if self.db.use_proxies: - warnings.filterwarnings("ignore") - verify = False - proxies = self.db.proxies - else: - verify = True - proxies = None + verify = False + warnings.filterwarnings('ignore') + + proxies = self._state.proxies if self._state.use_proxies else None - headers = { - 'Authorization': f'Bearer {self.db.api_token}', - 'user_id': self.db.user_id - } + if f'{self._state.synack_domain}/api/' in url: + headers = { + 'Authorization': f'Bearer {self._state.api_token}', + 'user_id': self._state.user_id + } + else: + headers = dict() if kwargs.get('headers'): headers.update(kwargs.get('headers', {})) query = kwargs.get('query') data = kwargs.get('data') if method.upper() == 'GET': - res = self.state.session.get(url, - headers=headers, - proxies=proxies, - params=query, - verify=verify) - elif method.upper() == 'HEAD': - res = self.state.session.head(url, + res = self._state.session.get(url, headers=headers, proxies=proxies, params=query, verify=verify) - elif method.upper() == 'PATCH': - res = self.state.session.patch(url, - json=data, + elif method.upper() == 'HEAD': + res = self._state.session.head(url, headers=headers, proxies=proxies, + params=query, verify=verify) + elif method.upper() == 'PATCH': + res = self._state.session.patch(url, + json=data, + headers=headers, + proxies=proxies, + verify=verify) elif method.upper() == 'POST': - res = self.state.session.post(url, - json=data, + if 'urlencoded' in headers.get('Content-Type', ''): + res = self._state.session.post(url, + data=data, + headers=headers, + proxies=proxies, + verify=verify) + else: + res = self._state.session.post(url, + json=data, + headers=headers, + proxies=proxies, + verify=verify) + elif method.upper() == 'PUT': + res = self._state.session.put(url, headers=headers, proxies=proxies, + params=data, verify=verify) - elif method.upper() == 'PUT': - res = self.state.session.put(url, - headers=headers, - proxies=proxies, - params=data, - verify=verify) - - self.debug.log("Network Request", - f"{res.status_code} -- {method.upper()} -- {url}" + - f"\n\tHeaders: {headers}" + - f"\n\tQuery: {query}" + - f"\n\tData: {data}" + - f"\n\tContent: {res.content}") + + self._debug.log("Network Request", + f"{res.status_code} -- {method.upper()} -- {url}" + + f"\n\tHeaders: {headers}" + + f"\n\tQuery: {query}" + + f"\n\tData: {data}" + + f"\n\tContent: {res.content}") + + reason_failed = None + if res.status_code == 400: + reason_failed = 'Bad request' + elif res.status_code == 401: + reason_failed = 'Unauthorized' + elif res.status_code == 403: + reason_failed = 'Logged out' + elif res.status_code == 412: + reason_failed = 'Mission already claimed' + elif res.status_code == 423: + reason_failed = 'Locked' + elif res.status_code == 429: + self._debug.log('Too many requests', f'({res.status_code} - {res.reason}) {res.url}') + if attempts < 5: + self._debug.log('Pausing', 'Retrying in 30 seconds...') + time.sleep(30) + attempts += 1 + return self.request(method, path, attempts, **kwargs) + elif res.status_code >= 400: + self._debug.log(f'Request failed', f'({res.status_code} - {res.reason}) {res.url}') + if attempts < 5: + self._debug.log('Retrying', f'Attempt #{attempts + 1}') + attempts += 1 + return self.request(method, path, attempts, **kwargs) + + # Log terminal failures (non-retryable errors) + if res.status_code in [400, 401, 403, 412, 423]: + self._debug.log(reason_failed, f'({res.status_code} - {res.reason}) {res.url}') return res diff --git a/src/synack/plugins/auth.py b/src/synack/plugins/auth.py index d9cec6a..1195ce8 100644 --- a/src/synack/plugins/auth.py +++ b/src/synack/plugins/auth.py @@ -3,7 +3,6 @@ Functions related to handling and checking authentication. """ -import pyotp import re from .base import Plugin @@ -12,106 +11,96 @@ class Auth(Plugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - for plugin in ['Api', 'Db', 'Users']: + for plugin in ['Api', 'Db', 'Duo', 'Users']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) - - def build_otp(self): - """Generate and return a OTP.""" - totp = pyotp.TOTP(self.db.otp_secret) - totp.digits = 7 - totp.interval = 10 - totp.issuer = 'synack' - return totp.now() + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def get_api_token(self): """Log in to get a new API token.""" - if self.users.get_profile(): - return self.db.api_token + if self._api.request('HEAD', 'profiles/me').status_code == 200: + return self._state.api_token csrf = self.get_login_csrf() - progress_token = None + duo_auth_url = None grant_token = None if csrf: - progress_token = self.get_login_progress_token(csrf) - if progress_token: - grant_token = self.get_login_grant_token(csrf, progress_token) + auth_response = self.get_authentication_response(csrf) + duo_auth_url = auth_response.get('duo_auth_url', '') + if duo_auth_url: + grant_token = self._duo.get_grant_token(duo_auth_url) if grant_token: - url = 'https://platform.synack.com/' + url = f'https://platform.{self._state.synack_domain}/' headers = { 'X-Requested-With': 'XMLHttpRequest' } query = { "grant_token": grant_token } - res = self.api.request('GET', - url + 'token', - headers=headers, - query=query) + res = self._api.request('GET', + url + 'token', + headers=headers, + query=query) if res.status_code == 200: j = res.json() - self.db.api_token = j.get('access_token') + self._db.api_token = j.get('access_token') self.set_login_script() return j.get('access_token') - def get_login_csrf(self): - """Get the CSRF Token from the login page""" - res = self.api.request('GET', 'https://login.synack.com') - m = re.search(' {" +\ "const loc = window.location;" +\ - "if(loc.href.startsWith('https://login.synack.com/')) {" +\ - "loc.replace('https://platform.synack.com');" +\ + "if(loc.href.startsWith('https://login." + self._state.synack_domain + "/')) {" +\ + "loc.replace('https://platform." + self._state.synack_domain + "');" +\ "}};" +\ "(function() {" +\ - "sessionStorage.setItem('shared-session-com.synack.accessToken'" +\ - ",'" +\ - self.db.api_token +\ - "');" +\ "setTimeout(forceLogin,60000);" +\ "let btn = document.createElement('button');" +\ "btn.addEventListener('click',forceLogin);" +\ @@ -122,7 +111,7 @@ def set_login_script(self): "document.getElementsByClassName('onboarding-form')[0]" +\ ".appendChild(btn)}" +\ ")();" - with open(self.state.config_dir / 'login.js', 'w') as fp: + with open(self._state.config_dir / 'login.js', 'w') as fp: fp.write(script) return script diff --git a/src/synack/plugins/base.py b/src/synack/plugins/base.py index 102434d..1b07e33 100644 --- a/src/synack/plugins/base.py +++ b/src/synack/plugins/base.py @@ -1,9 +1,9 @@ class Plugin: - registry = {} + _registry = {} def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - cls.registry[cls.__name__] = cls + cls._registry[cls.__name__] = cls def __init__(self, state, **kwargs): - self.state = state + self._state = state diff --git a/src/synack/plugins/db.py b/src/synack/plugins/db.py index eb89fba..0e0a9d8 100644 --- a/src/synack/plugins/db.py +++ b/src/synack/plugins/db.py @@ -7,6 +7,8 @@ import alembic.command import sqlalchemy as sa +from sqlalchemy.dialects.sqlite import insert as sqlite_insert + from pathlib import Path from sqlalchemy.orm import sessionmaker from synack.db.models import Target @@ -23,11 +25,14 @@ class Db(Plugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.sqlite_db = self.state.config_dir / 'synackapi.db' + self.sqlite_db = self._state.config_dir / 'synackapi.db' self.set_migration() engine = sa.create_engine(f'sqlite:///{str(self.sqlite_db)}') + metadata = sa.MetaData() + metadata.reflect(bind=engine) + metadata.clear() sa.event.listen(engine, 'connect', self._fk_pragma_on_connect) self.Session = sessionmaker(bind=engine) @@ -44,139 +49,226 @@ def add_categories(self, categories): db_c = Category(id=c['category_id']) session.add(db_c) db_c.name = c['category_name'] - db_c.passed_practical = c['practical_assessment']['passed'] - db_c.passed_written = c['written_assessment']['passed'] + db_c.passed_practical = c['passed'] + db_c.passed_written = c['passed'] session.commit() session.close() def add_ips(self, results, session=None): close = False + if session is None: session = self.Session() close = True - q = session.query(IP) + + ips_data = list() + for result in results: - if result.get('ip'): - filt = sa.and_( - IP.ip.like(result.get('ip')), - IP.target.like(result.get('target')) - ) - db_ip = q.filter(filt).first() - if not db_ip: - db_ip = IP( - ip=result.get('ip'), - target=result.get('target')) - session.add(db_ip) + if result.get('ip') and result.get('target'): + ips_data.append({ + 'ip': result['ip'], + 'target': result['target'] + }) + if len(ips_data) > 15000: + stmt = sqlite_insert(IP).values(ips_data) + stmt = stmt.on_conflict_do_nothing( + index_elements=['ip', 'target'], + ) + session.execute(stmt) + ips_data = list() + + + if ips_data: + stmt = sqlite_insert(IP).values(ips_data) + stmt = stmt.on_conflict_do_nothing( + index_elements=['ip', 'target'], + ) + session.execute(stmt) + if close: session.commit() session.close() def add_organizations(self, targets, session=None): close = False + if session is None: session = self.Session() close = True - q = session.query(Organization) - for t in targets: - if t.get('organization'): - slug = t['organization']['slug'] + + organizations_data = list() + + if isinstance(targets, dict): + targets = [value for key, value in targets.items()] + + for target in targets: + if isinstance(target.get('organization'), str): + slug = target.get('organization') + name = None # No name available in this case else: - slug = t.get('organization_id') - db_o = q.filter_by(slug=slug).first() - if not db_o: - db_o = Organization(slug=slug) - session.add(db_o) + org = target.get('organization', {}) + slug = target.get('organization_id', org.get('slug')) + name = org.get('name') + if slug: + organizations_data.append({ + 'slug': slug, + 'name': name + }) + + if organizations_data: + stmt = sqlite_insert(Organization).values(organizations_data) + stmt = stmt.on_conflict_do_update( + index_elements=['slug'], + set_={ + 'name': stmt.excluded.name + } + ) + session.execute(stmt) + if close: session.commit() session.close() def add_ports(self, results): - self.add_ips(results) session = self.Session() - q = session.query(Port) - ips = session.query(IP) + + self.add_ips(results, session) + ip_map = {ip.ip: ip.id for ip in session.query(IP.ip, IP.id).all()} + + ports_data = list() + for result in results: - ip = ips.filter_by(ip=result.get('ip')) - if ip: - ip = ip.first() + ip_id = ip_map.get(result.get('ip')) + if ip_id: for port in result.get('ports', []): - filt = sa.and_( - Port.port.like(port.get('port')), - Port.protocol.like(port.get('protocol')), - Port.ip.like(ip.id), - Port.source.like(result.get('source'))) - db_port = q.filter(filt) - if not db_port: - db_port = Port( - port=port.get('port'), - protocol=port.get('protocol'), - service=port.get('service'), - ip=ip.id, - source=result.get('source'), - open=port.get('open'), - updated=port.get('updated') + ports_data.append({ + 'port': port.get('port'), + 'protocol': port.get('protocol'), + 'service': port.get('service'), + 'ip': ip_id, + 'source': result.get('source'), + 'open': port.get('open'), + 'updated': port.get('updated') + }) + if len(ports_data) > 15000: + stmt = sqlite_insert(Port).values(ports_data) + stmt = stmt.on_conflict_do_update( + index_elements=['port', 'protocol', 'ip', 'source'], + set_={ + 'service': stmt.excluded.service, + 'open': stmt.excluded.open, + 'updated': stmt.excluded.updated + } ) - else: - db_port = db_port.first() - db_port.service = port.get('service', db_port.service) - db_port.open = port.get('open', db_port.open) - db_port.updated = port.get('updated', db_port.updated) - session.add(db_port) + session.execute(stmt) + ports_data = list() + + + if ports_data: + stmt = sqlite_insert(Port).values(ports_data) + stmt = stmt.on_conflict_do_update( + index_elements=['port', 'protocol', 'ip', 'source'], + set_={ + 'service': stmt.excluded.service, + 'open': stmt.excluded.open, + 'updated': stmt.excluded.updated + } + ) + session.execute(stmt) + session.commit() session.close() def add_targets(self, targets, **kwargs): session = self.Session() + self.add_organizations(targets, session) - q = session.query(Target) - for t in targets: - if t.get('organization'): - org_slug = t['organization']['slug'] + db_orgs = [org[0] for org in session.query(Organization.slug).all()] + + targets_data = list() + + if isinstance(targets, dict): + targets = [value for key, value in targets.items()] + + for target in targets: + if isinstance(target.get('organization'), str): + org_slug = target.get('organization') else: - org_slug = t.get('organization_id') - slug = t.get('slug', t.get('id')) - db_t = q.filter_by(slug=slug).first() - if not db_t: - db_t = Target(slug=slug) - session.add(db_t) - for k in t.keys(): - setattr(db_t, k, t[k]) - db_t.category = t['category']['id'] - db_t.organization = org_slug - db_t.date_updated = t.get('dateUpdated') - db_t.is_active = t.get('isActive') - db_t.is_new = t.get('isNew') - db_t.is_registered = t.get('isRegistered') - db_t.is_updated = t.get('isUpdated') - db_t.last_submitted = t.get('lastSubmitted') - for k in kwargs.keys(): - setattr(db_t, k, kwargs[k]) + org_slug = target.get('organization_id', target.get('organization', {}).get('slug')) + if isinstance(target.get('category'), int): + category = target.get('category') + else: + category = target.get('category', {}).get('id') + if org_slug in db_orgs: + target_data = { + 'slug': target.get('id', target.get('slug')), + 'codename': target.get('codename'), + 'category': category, + 'organization': org_slug, + 'date_updated': target.get('dateUpdated', target.get('date_updated')), + 'is_active': target.get('isActive', target.get('is_active')), + 'is_new': target.get('isNew', target.get('is_new')), + 'is_registered': target.get('isRegistered', target.get('is_registered')), + 'last_submitted': target.get('lastSubmitted', target.get('last_submitted')), + 'average_payout': target.get('averagePayout'), + 'start_date': target.get('start_date'), + 'end_date': target.get('end_date'), + 'is_updated': target.get('isUpdated', False) + } + target_data.update(kwargs) + targets_data.append(target_data) + + if targets_data: + stmt = sqlite_insert(Target).values(targets_data) + stmt = stmt.on_conflict_do_update( + index_elements=['slug'], + set_={ + 'category': stmt.excluded.category, + 'codename': stmt.excluded.codename, + 'organization': stmt.excluded.organization, + 'date_updated': stmt.excluded.date_updated, + 'is_active': stmt.excluded.is_active, + 'is_new': stmt.excluded.is_new, + 'is_registered': stmt.excluded.is_registered, + 'last_submitted': stmt.excluded.last_submitted, + 'average_payout': stmt.excluded.average_payout, + 'start_date': stmt.excluded.start_date, + 'end_date': stmt.excluded.end_date, + 'is_updated': stmt.excluded.is_updated + } + ) + session.execute(stmt) + session.commit() session.close() - def add_urls(self, results, **kwargs): - self.add_ips(results) + def add_urls(self, results): session = self.Session() - q = session.query(Url) - ips = session.query(IP) + + self.add_ips(results, session) + ip_map = {ip.ip: ip.id for ip in session.query(IP.ip, IP.id).all()} + + urls_data = list() + for result in results: - ip = ips.filter_by(ip=result.get('ip')).first() - for url in result.get('urls', []): - if ip: - filt = sa.and_( - Url.url.like(url.get('url')), - Url.ip.like(ip.id)) - else: - filt = sa.and_( - Url.url.like(url.get('url'))) - db_url = q.filter(filt).first() - if not db_url: - db_url = Url() - db_url.url = url.get('url') - db_url.screenshot_url = url.get('screenshot_url') - if ip: - db_url.ip = ip.id - session.add(db_url) + ip_id = ip_map.get(result.get('ip')) + if ip_id: + for url in result.get('urls', []): + urls_data.append({ + 'url': url.get('url'), + 'screenshot_url': url.get('screenshot_url') + }) + + if urls_data: + stmt = sqlite_insert(Url).values(urls_data) + stmt = stmt.on_conflict_do_update( + index_elements=['ip', 'url'], + set_={ + 'screenshot_url': stmt.excluded.screenshot_url + } + ) + session.execute(stmt) + session.commit() session.close() @@ -197,30 +289,22 @@ def categories(self): @property def debug(self): - if self.state.debug is None: - return self.get_config('debug') - else: - return self.state.debug + return self.get_config('debug') @debug.setter def debug(self, value): - self.state.debug = value self.set_config('debug', value) @property def email(self): - if self.state.email is None: - ret = self.get_config('email') - if not ret: - ret = input("Synack Email: ") - self.email = ret - return ret - else: - return self.state.email + ret = self.get_config('email') + if not ret: + ret = input('Synack Email: ') + self.email = ret + return ret @email.setter def email(self, value): - self.state.email = value self.set_config('email', value) def find_ips(self, ip=None, **kwargs): @@ -295,7 +379,25 @@ def find_ports(self, port=None, protocol=None, source=None, ip=None, **kwargs): def find_targets(self, **kwargs): session = self.Session() - targets = session.query(Target).filter_by(**kwargs).all() + query = session.query(Target) + + filters = list() + + for key, value in kwargs.items(): + if hasattr(Target, key): + if kwargs.get('like'): + filters.append(getattr(Target, key).like(f'%{value}%')) + else: + filters.append(getattr(Target, key) == value) + + if filters: + if kwargs.get('or'): + query = query.filter(sa.or_(*filters)) + else: + query = query.filter(sa.and_(*filters)) + + targets = query.all() + session.expunge_all() session.close() return targets @@ -343,8 +445,10 @@ def get_config(self, name=None): if not config: config = Config() session.add(config) + session.commit() + ret = getattr(config, name) if name else config session.close() - return getattr(config, name) if name else config + return ret @property def http_proxy(self): @@ -377,39 +481,117 @@ def notifications_token(self): def notifications_token(self, value): self.set_config('notifications_token', value) + @property + def otp_count(self): + ret = self.get_config('otp_count') + if not ret: + ret = input('Synack OTP Count: ') + self.otp_count = int(ret) + return ret + + @otp_count.setter + def otp_count(self, value): + self.set_config('otp_count', value) + @property def otp_secret(self): - if self.state.otp_secret is None: - ret = self.get_config('otp_secret') - if not ret: - ret = input("Synack OTP Secret: ") + ret = self.get_config('otp_secret') + if not ret: + # Skip prompt if automated push credentials are already configured + if self.duo_push_akey and self.duo_push_pkey and self.duo_push_host: + ret = '' self.otp_secret = ret - self.state.otp_secret = ret - return ret - else: - return self.state.otp_secret + # Skip prompt if user has already selected a device for manual push + elif self.duo_device: + ret = '' + else: + print("\nDuo MFA Authentication Setup:") + print( + "1. Press Enter to use Duo Push notifications " + "(you'll approve on your phone)" + ) + print("2. OR enter your Duo OTP Secret for automated passcode generation") + print(" (Accepts hex (hotp_secret) or base32 (otpauth://) format)") + ret = input('\nDuo OTP Secret (or press Enter for push): ').strip() + self.otp_secret = ret if ret else '' + return ret @otp_secret.setter def otp_secret(self, value): - self.state.otp_secret = value + # Auto-detect and convert hex format to base32 + # Duo's hotp_secret is a hex string, but needs to be treated as UTF-8 + # not as hex bytes (based on duo-hotp reference implementation) + if value and self._is_hex_secret(value): + import base64 + # Encode the hex string as UTF-8 bytes, then base32 + value = base64.b32encode(value.encode('utf-8')).decode('ascii').rstrip('=') self.set_config('otp_secret', value) + def _is_hex_secret(self, value): + """Check if the secret appears to be in hex format (not base32)""" + # Hex: 32 chars using only 0-9, a-f + # Base32: variable length using A-Z, 2-7 + if len(value) != 32: + return False + try: + # If it can be decoded as hex, it's hex + bytes.fromhex(value) + return True + except ValueError: + return False + @property def password(self): - if self.state.password is None: - ret = self.get_config('password') - if not ret: - ret = input("Synack Password: ") - self.password = ret - return ret - else: - return self.state.password + ret = self.get_config('password') + if not ret: + ret = input('Synack Password: ') + self.password = ret + return ret @password.setter def password(self, value): - self.state.password = value self.set_config('password', value) + @property + def duo_push_akey(self): + return self.get_config('duo_push_akey') + + @duo_push_akey.setter + def duo_push_akey(self, value): + self.set_config('duo_push_akey', value) + + @property + def duo_push_pkey(self): + return self.get_config('duo_push_pkey') + + @duo_push_pkey.setter + def duo_push_pkey(self, value): + self.set_config('duo_push_pkey', value) + + @property + def duo_push_host(self): + return self.get_config('duo_push_host') + + @duo_push_host.setter + def duo_push_host(self, value): + self.set_config('duo_push_host', value) + + @property + def duo_push_rsa_key_path(self): + return self.get_config('duo_push_rsa_key_path') + + @duo_push_rsa_key_path.setter + def duo_push_rsa_key_path(self, value): + self.set_config('duo_push_rsa_key_path', value) + + @property + def duo_device(self): + return self.get_config('duo_device') + + @duo_device.setter + def duo_device(self, value): + self.set_config('duo_device', value) + @property def ports(self): session = self.Session() @@ -419,19 +601,9 @@ def ports(self): @property def proxies(self): - if self.state.http_proxy is None: - http_proxy = self.get_config('http_proxy') - else: - http_proxy = self.state.http_proxy - - if self.state.https_proxy is None: - https_proxy = self.get_config('https_proxy') - else: - https_proxy = self.state.https_proxy - return { - 'http': http_proxy, - 'https': https_proxy + 'http': self.get_config('http_proxy'), + 'https': self.get_config('https_proxy') } def remove_targets(self, **kwargs): @@ -442,16 +614,11 @@ def remove_targets(self, **kwargs): @property def scratchspace_dir(self): - if self.state.scratchspace_dir is None: - ret = Path(self.get_config('scratchspace_dir')).expanduser().resolve() - self.state.scratchspace_dir = ret - else: - ret = self.state.scratchspace_dir - return ret + return Path(self.get_config('scratchspace_dir')).expanduser().resolve() @scratchspace_dir.setter def scratchspace_dir(self, value): - self.set_config('scratchspace_dir', value) + self.set_config('scratchspace_dir', str(value)) def set_config(self, name, value): session = self.Session() @@ -474,6 +641,30 @@ def set_migration(self): f'sqlite:///{str(self.sqlite_db)}') alembic.command.upgrade(config, 'head') + @property + def slack_app_token(self): + ret = self.get_config('slack_app_token') + if not ret: + ret = input('Slack App Token: ') + self.slack_app_token = ret + return ret + + @slack_app_token.setter + def slack_app_token(self, value): + self.set_config('slack_app_token', value) + + @property + def slack_channel(self): + ret = self.get_config('slack_channel') + if not ret: + ret = input('Slack Channel: ') + self.slack_channel = ret + return ret + + @slack_channel.setter + def slack_channel(self, value): + self.set_config('slack_channel', value) + @property def slack_url(self): return self.get_config('slack_url') @@ -545,14 +736,17 @@ def targets(self): session.close() return targets + @property + def synack_domain(self): + return self.get_config('synack_domain') + + @synack_domain.setter + def synack_domain(self, value): + self.set_config('synack_domain', value) + @property def template_dir(self): - if self.state.template_dir is None: - ret = Path(self.get_config('template_dir')).expanduser().resolve() - self.state.template_dir = ret - else: - ret = self.state.template_dir - return ret + return Path(self.get_config('template_dir')).expanduser().resolve() @template_dir.setter def template_dir(self, value): @@ -567,14 +761,10 @@ def urls(self): @property def use_proxies(self): - if self.state.use_proxies is None: - return self.get_config('use_proxies') - else: - return self.state.use_proxies + return self.get_config('use_proxies') @use_proxies.setter def use_proxies(self, value): - self.state.use_proxies = value self.set_config('use_proxies', value) @property @@ -587,12 +777,8 @@ def user_id(self, value): @property def use_scratchspace(self): - if self.state.use_scratchspace is None: - return self.get_config('use_scratchspace') - else: - return self.state.use_scratchspace + return self.get_config('use_scratchspace') @use_scratchspace.setter def use_scratchspace(self, value): - self.state.use_scratchspace = value self.set_config('use_scratchspace', value) diff --git a/src/synack/plugins/debug.py b/src/synack/plugins/debug.py index e0e0a29..35d80b0 100644 --- a/src/synack/plugins/debug.py +++ b/src/synack/plugins/debug.py @@ -13,10 +13,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Db']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def log(self, title, message): - if self.db.debug: + if self._state.debug: t = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S") print(f'{t} -- {title.upper()}\n\t{message}') diff --git a/src/synack/plugins/duo.py b/src/synack/plugins/duo.py new file mode 100644 index 0000000..2ddcb4a --- /dev/null +++ b/src/synack/plugins/duo.py @@ -0,0 +1,542 @@ +"""plugins/duo.py + +Functions related to handling Duo Security Multi-Factor Authentication. +""" + +from .base import Plugin + +import base64 +import json +import pyotp +import re +import requests +import time +from datetime import UTC, datetime +from pathlib import Path +from urllib.parse import urlencode +from Crypto.Hash import SHA512 +from Crypto.PublicKey import RSA +from Crypto.Signature import pkcs1_15 + + +class Duo(Plugin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for plugin in ['Api', 'Db', 'Utils']: + setattr(self, + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) + + self._auth_url = None + self._base_url = None + self._device = None + self._factor = None + self._grant_token = None + self._hotp = None + self._progress_token = None + self._referrer = None + self._session_vars = None + self._status = None + self._sid = None + self._txid = None + self._xsrf = None + self._pubkey = None + + def _build_headers(self, overrides=None): + headers = { + 'Sec-Ch-Ua': '"Chromium";v="131", "Not_A Brand";v="24"', + 'Sec-Ch-Ua-Mobile': '?0', + 'Sec-Ch-Ua-Platform': '"Linux"', + 'Referrer': self._referrer, + 'Sec-Fetch-Site': 'cross-site', + 'Sec-Fetch-Mode': 'navigate', + 'Sec-Fetch-User': '?1', + 'Sec-Fetch-Dest': 'document' + } + headers.update(overrides if overrides else dict()) + return headers + + def get_grant_token(self, auth_url): + """Get Grant Token from Duo Security""" + self._auth_url = auth_url + self._get_session_variables() + self._set_session_variables() + self._set_session_variables() # Yes, this needs to be called twice... + self._get_txid() + if self._txid: + # Priority 1: OTP (if configured) + if self._state.otp_secret: + # OTP passcode already sent in _get_txid(), just poll for status + self._get_status() + # Priority 2: Auto-approval (if configured) - HARD FAIL if broken + elif self.is_configured(): + if not self.load_rsa_key(): + raise RuntimeError( + "Duo Push auto-approval is enabled but RSA key failed to load" + ) + print("Auto-approving Duo push notification...") + if self._state.debug: + print(f"Using device: {self._device}") + print(f"Configured duo_device: {self._state.duo_device}") + if self._device != self._state.duo_device: + print(f"WARNING: Push sent to {self._device} but credentials are for {self._state.duo_device}") + # Wait 2 seconds before polling to give Duo time to register the push + time.sleep(2) + if not self.approve_pending_push(timeout=25): + raise RuntimeError( + "Duo Push auto-approval failed - check credentials or " + "disable auto-approval. Ensure duo_device matches the device " + "with extracted credentials." + ) + self._get_status() + # Priority 3: Manual push (fallback) + else: + print("Waiting for manual Duo push approval on your device...") + self._get_status() + if self._status == 'SUCCESS': + self._get_oidc_exit() + if self._progress_token: + self._get_grant_token() + return self._grant_token + + def _get_grant_token(self): + headers = { + 'X-Csrf-Token': self._xsrf + } + data = { + 'progress_token': self._progress_token + } + res = self._api.login('POST', + 'authenticate', + data=data, + headers=headers) + if res.status_code == 200: + self._grant_token = res.json().get('grant_token') + + def _get_mfa_details(self): + if self._state.otp_secret: + self._device = 'null' + self._hotp = pyotp.HOTP(s=self._state.otp_secret).generate_otp(int(self._state.otp_count)) + self._factor = 'Passcode' + return + + headers = { + 'Referer': f'{self._base_url}/frame/v4/auth/prompt?sid={self._sid}', + 'Sec-Fetch-Site': 'same-origin', + 'Sec-Fetch-Mode': 'cors', + 'Sec-Fetch-Dest': 'empty', + 'Accept': '*/*', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', + 'X-Xsrftoken': self._xsrf + } + query = { + 'post_auth_action': 'OIDC_EXIT', + 'browser_features': json.dumps({ + 'touch_supported': 'false', + 'platform_authenticator_status': 'unavailable', + 'webauthn_supported': 'true' + }, separators=(',', ':')), + 'sid': self._sid + } + res = self._api.request('GET', f'{self._base_url}/frame/v4/auth/prompt/data', headers=headers, query=query) + + if res.status_code == 200: + response_json = res.json() + response_data = response_json.get('response', {}) + phones = response_data.get('phones', []) + + # If auto-approval credentials are configured, find the matching device + if self.is_configured(): + # Match device by pkey + pkey = self._state.duo_push_pkey + for phone in phones: + if phone.get('key', '') == pkey: + self._device = phone.get('index', '') + self._factor = 'Duo Push' + # Update stored device if it doesn't match + if self._state.duo_device != self._device: + print(f"Auto-correcting duo_device from {self._state.duo_device} to {self._device}") + self._db.duo_device = self._device + return + # If no match found, credentials are for wrong account + print(f"WARNING: duo_push_pkey {pkey} not found in available devices") + print("Falling back to manual device selection") + + # Check if we have a stored device preference + if self._state.duo_device: + # Use the stored device + for phone in phones: + if phone.get('index', '') == self._state.duo_device: + self._device = phone.get('index', '') + self._factor = 'Duo Push' + return + # If stored device not found, fall through to prompt + + # Prompt user to select a device + if phones: + print("\nAvailable Duo devices:") + for i, phone in enumerate(phones, 1): + print(f"{i}. {phone.get('name', 'Unknown')} ({phone.get('index', '')})") + + while True: + try: + choice = input("\nSelect device number (or press Enter for first device): ").strip() + if not choice: + selected_phone = phones[0] + break + choice_num = int(choice) + if 1 <= choice_num <= len(phones): + selected_phone = phones[choice_num - 1] + break + print(f"Please enter a number between 1 and {len(phones)}") + except ValueError: + print("Please enter a valid number") + + self._device = selected_phone.get('index', '') + self._factor = 'Duo Push' + self._db.duo_device = self._device + return + + if not self._device or not self._factor: + raise ValueError( + f'Failed to determine MFA device/factor from Duo API. ' + f'HTTP {res.status_code}, device={self._device}, factor={self._factor}' + ) + + def _get_oidc_exit(self): + headers = { + 'Referer': f'{self._base_url}/frame/v4/auth/prompt?sid={self._sid}', + 'Sec-Fetch-Site': 'same-origin', + 'Sec-Fetch-Mode': 'cors', + 'Sec-Fetch-Dest': 'empty', + 'Accept': '*/*', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', + 'X-Xsrftoken': self._xsrf + } + data = { + 'sid': self._sid, + 'txid': self._txid, + 'factor': self._factor, + 'device_key': self._device, + '_xsrf': self._xsrf, + 'dampen_choice': 'false' + } + res = self._api.request('POST', f'{self._base_url}/frame/v4/oidc/exit', headers=headers, data=data) + if res.status_code == 200: + try: + self._grant_token = re.search('grant_token=([^&]*)', res.url).group(1) + except AttributeError: + self._progress_token = re.search('token=([^&]*)', res.url).group(1) + self._xsrf = self._utils.get_html_tag_value('csrf-token', res.text) + + def _get_session_variables(self): + self._referrer = f'https://login.{self._state.synack_domain}/' + res = self._api.request('GET', self._auth_url, headers=self._build_headers()) + if res.status_code == 200: + self._sid = re.search('sid=([^&]*)', res.url).group(1) + self._referrer = res.url + self._base_url = re.search('(https.*duo[^.]*.com)/', res.url).group(1) + self._xsrf = self._utils.get_html_tag_value('_xsrf', res.text) + + client_hints = base64.b64encode(json.dumps({ + 'brands': [ + {'brand': 'Chromium', 'version': '131'}, + {'brand': 'Not_A Brand', 'version': '24'} + ], + 'fullVersionList': [], + 'mobile': False, + 'platform': 'Linux', + 'platformVersion': '', + 'uaFullVersion': '' + }).encode()).decode() + + analysis_feature = self._utils.get_html_tag_value('has_session_trust_analysis_feature', res.text) + + self._session_vars = { + 'tx': self._utils.get_html_tag_value('tx', res.text), + 'parent': self._utils.get_html_tag_value('parent', res.text), + '_xsrf': self._xsrf, + 'version': self._utils.get_html_tag_value('version', res.text), + 'akey': self._utils.get_html_tag_value('akey', res.text), + 'has_session_trust_analysis_feature': analysis_feature, + 'session_trust_extension_id': self._utils.get_html_tag_value('session_trust_extension_id', res.text), + 'java_version': self._utils.get_html_tag_value('java_version', res.text), + 'flash_version': self._utils.get_html_tag_value('flash_version', res.text), + 'screen_resolution_width': '3422', + 'screen_resolution_height': '1465', + 'extension_instance_key': '', + 'color_depth': '24', + 'has_touch_capability': 'false', + 'ch_ua_error': '', + 'client_hints': client_hints, + 'is_cef_browser': 'false', + 'is_ipad_os': 'false', + 'is_ie_compatibility_mode': '', + 'is_user_verifying_platform_authenticator_available': 'false', + 'user_verifying_platform_authenticator_available_error': '', + 'acting_ie_version': '', + 'react_support': 'false', + 'react_support_error_message': '' + } + + def _get_status(self): + headers = { + 'Referrer': f'{self._base_url}/frame/v4/auth/prompt?sid={self._sid}', + 'Sec-Fetch-Site': 'same-origin', + 'Sec-Fetch-Mode': 'cors', + 'Sec-Fetch-Dest': 'empty', + 'Accept': '*/*', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', + 'X-Xsrftoken': self._xsrf + } + data = { + 'txid': self._txid, + 'sid': self._sid + } + # Increase polling attempts from 5 to 12 (1 minute total with 5s intervals) + for i in range(12): + res = self._api.request('POST', f'{self._base_url}/frame/v4/status', headers=headers, data=data) + if res.status_code == 200: + status_enum = res.json().get('response', {}).get('status_enum', -1) + message_enum = res.json().get('message_enum', -1) + self._status = res.json().get('response', {}).get('result', 'UNKNOWN') + if status_enum == 5 or self._status == 'SUCCESS': # Valid Code + break + elif status_enum == 6: # Push Notification Declined (Normal) + break + elif status_enum == 7: # Push Notification Declined (Suspicious Login) + break + elif status_enum == 11: # Bad Code (or Future Code by 20+) + print("Bad OTP Code Sent") + print(res) + print(res.json()) + break + elif status_enum == 13: # Awaiting Push Notification + pass + elif status_enum == 15: # Push sent, waiting for approval + # Continue polling for both auto-approval and manual approval + pass + elif status_enum == 44: # Prior Code + self._db.otp_count += 5 + break + elif message_enum == 57: # Bad Request + print('Your Request was bad!') + break + else: # IDK + print('Something went wrong!') + print(res) + print(res.json()) + break + time.sleep(5) + + def _get_txid(self): + """Sends Push Notification or Submits HOTP""" + headers = { + 'Referrer': f'{self._base_url}/frame/v4/auth/prompt?sid={self._sid}', + 'Sec-Fetch-Site': 'same-origin', + 'Sec-Fetch-Mode': 'cors', + 'Sec-Fetch-Dest': 'empty', + 'Accept': '*/*', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', + 'X-Xsrftoken': self._xsrf + } + + self._get_mfa_details() + + if self._device and self._factor: + data = { + 'device': self._device, + 'factor': self._factor, + 'postAuthDestination': 'OIDC_EXIT', + 'browser_features': json.dumps({ + 'touch_supported': 'false', + 'platform_authenticator_status': 'unavailable', + 'webauthn_supported': 'true' + }, separators=(',', ':')), + 'sid': self._sid + } + + if self._state.otp_secret: + data['passcode'] = self._hotp + + res = self._api.request('POST', + f'{self._base_url}/frame/v4/prompt', + headers=self._build_headers(headers), + data=data) + if res.status_code == 200: + self._txid = res.json().get('response', {}).get('txid', '') + if self._state.otp_secret: + self._db.otp_count += 1 + + def _set_session_variables(self): + headers = { + 'Sec-Ch-Ua': '"Chromium";v="131", "Not_A Brand";v="24"', + 'Sec-Ch-Ua-Mobile': '?0', + 'Sec-Ch-Ua-Platform': '"Linux"', + 'Referer': self._referrer, + 'Sec-Fetch-Site': 'same-origin', + 'Sec-Fetch-Mode': 'navigate', + 'Sec-Fetch-Dest': 'document', + 'Accept': ';'.join([ + 'text/html,application/xhtml+xml,application/xml', + 'q=0.9,image/avif,image/webp,image/apng,*/*', + 'q=0.8,application/signed-exchange', + 'v=b3;q=0.7' + ]), + 'Content-Type': 'application/x-www-form-urlencoded' + } + res = self._api.request('POST', self._referrer, headers=headers, data=self._session_vars) + if res.status_code == 200: + self._referrer = res.url + + # Duo Push Auto-Approval Methods + + def is_configured(self): + """Check if Duo push auto-approval credentials are configured""" + return ( + self._state.duo_push_akey and + self._state.duo_push_pkey and + self._state.duo_push_host + ) + + def load_rsa_key(self): + """Load RSA key from configured path""" + if not self.is_configured(): + return False + + key_path = Path(self._state.duo_push_rsa_key_path).expanduser() + if not key_path.exists(): + print(f"Duo RSA key not found: {key_path}") + return False + + try: + with open(key_path, 'rb') as f: + self._pubkey = RSA.import_key(f.read()) + return True + except Exception as e: + print(f"Failed to load Duo RSA key: {e}") + return False + + def approve_pending_push(self, timeout=30): + """Wait for and approve a single Duo push notification""" + if not self.is_configured(): + return False + + if not self._pubkey and not self.load_rsa_key(): + print("Cannot approve push: RSA key not available") + return False + + print("Polling for Duo push notification...") + start_time = time.monotonic() + poll_interval = 2 # Poll every 2 seconds + + while time.monotonic() - start_time < timeout: + try: + # Poll for transactions + transactions = self._get_transactions() + if self._state.debug: + print(f"Transactions response: {transactions}") + response_data = transactions.get('response', {}) + pending = response_data.get('transactions', []) + current_time = response_data.get('current_time', 0) + + if self._state.debug: + print(f"Found {len(pending)} pending transactions") + + if pending: + for tx in pending: + tx_id = tx.get('urgid') + expiration = tx.get('expiration', 0) + + if self._state.debug: + print(f"Transaction: {tx}") + + # Skip expired transactions + if expiration and current_time and expiration <= current_time: + if self._state.debug: + print(f"Skipping expired transaction {tx_id}") + continue + + if tx_id: + tx_summary = tx.get('summary', 'N/A') + print(f"Approving Duo push {tx_id[:12]}... ({tx_summary})") + response = self._reply_transaction(tx_id, 'approve') + if response.get('stat') == 'OK': + print("Duo push approved successfully") + return True + else: + print(f"Push approval returned: {response}") + + time.sleep(poll_interval) + + except Exception as e: + print(f"Error during Duo push approval: {e}") + return False + + return False + + def _generate_signature(self, method, path, time_str, data): + """Generate RSA signature for Duo API request""" + encoded_data = urlencode(sorted(data.items())) if data else "" + message_parts = [ + time_str, + method.upper(), + self._state.duo_push_host.lower(), + path, + encoded_data, + ] + message = "\n".join(message_parts).encode('ascii') + h = SHA512.new(message) + signature = pkcs1_15.new(self._pubkey).sign(h) + auth_string = f"{self._state.duo_push_pkey}:{base64.b64encode(signature).decode('ascii')}" + return "Basic " + base64.b64encode(auth_string.encode('ascii')).decode('ascii') + + def _make_request(self, method, path, data): + """Make authenticated request to Duo device API""" + dt = datetime.now(UTC) + # Format as RFC 2822 date for HTTP header (e.g., "Mon, 04 Nov 2025 12:34:56 GMT") + time_str = dt.strftime('%a, %d %b %Y %H:%M:%S GMT') + signature = self._generate_signature(method, path, time_str, data) + + url = f"https://{self._state.duo_push_host}{path}" + headers = { + 'Authorization': signature, + 'x-duo-date': time_str, + 'Host': self._state.duo_push_host, + 'Content-Type': 'application/x-www-form-urlencoded', + } + + try: + if method.upper() == 'GET': + r = requests.get(url, params=data, headers=headers, timeout=10) + else: + r = requests.post(url, data=data, headers=headers, timeout=10) + + r.raise_for_status() + return r.json() + except Exception as e: + print(f"Duo API request failed: {e}") + raise + + def _get_transactions(self): + """Get pending Duo push transactions""" + path = "/push/v2/device/transactions" + params = { + 'akey': self._state.duo_push_akey, + 'fips_status': '1', + 'hsm_status': 'true', + 'pkpush': 'rsa-sha512', + } + return self._make_request('GET', path, params) + + def _reply_transaction(self, transaction_id, answer): + """Reply to a Duo push transaction (approve/deny)""" + path = f"/push/v2/device/transactions/{transaction_id}" + data = { + 'akey': self._state.duo_push_akey, + 'answer': answer, + 'fips_status': '1', + 'hsm_status': 'true', + 'pkpush': 'rsa-sha512', + } + return self._make_request('POST', path, data) diff --git a/src/synack/plugins/hydra.py b/src/synack/plugins/hydra.py deleted file mode 100644 index 399383b..0000000 --- a/src/synack/plugins/hydra.py +++ /dev/null @@ -1,81 +0,0 @@ -"""plugins/hydra.py - -Functions dealing with hydra -""" - -import json -import time - -from .base import Plugin -from datetime import datetime - - -class Hydra(Plugin): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - for plugin in ['Api', 'Db']: - setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) - - def build_db_input(self, results): - """Format the Hydra output so that it can be ingested into the DB""" - db_input = list() - for result in results: - ports = list() - for port in result.get('ports').keys(): - for protocol in result['ports'][port].keys(): - for hydra_src in result['ports'][port][protocol].keys(): - h_src = result['ports'][port][protocol][hydra_src] - service = h_src.get('verified_service', {'parsed': 'unknown'})['parsed'] + \ - ' - ' + \ - h_src.get('product', {'parsed': 'unknown'})['parsed'] - service = service.strip(' - ') - port_open = result['ports'][port][protocol][hydra_src]['open']['parsed'] - epoch = datetime(1970, 1, 1) - try: - last_changed_dt = datetime.strptime(result['last_changed_dt'], "%Y-%m-%dT%H:%M:%SZ") - except ValueError: - last_changed_dt = datetime.strptime(result['last_changed_dt'], "%Y-%m-%dT%H:%M:%S.%fZ") - updated = int((last_changed_dt - epoch).total_seconds()) - - ports.append({ - "port": port, - "protocol": protocol, - "service": service, - "open": port_open, - "updated": updated - }) - db_input.append({ - "ip": result["ip"], - "target": result["listing_uid"], - "source": "hydra", - "ports": ports - }) - return db_input - - def get_hydra(self, page=1, max_page=5, update_db=True, **kwargs): - """Get Hydra results for target identified using kwargs (codename='x', slug='x', etc.)""" - max_page = 1000 if max_page == 0 else max_page - results = list() - targets = self.db.find_targets(**kwargs) - if targets: - target = targets[0] - if target: - query = { - 'page': page, - 'listing_uids': target.slug, - 'q': '+port_is_open:true' - } - time.sleep(page*0.01) - res = self.api.request('GET', - 'hydra_search/search', - query=query) - if res.status_code == 200: - curr_results = json.loads(res.content) - results.extend(curr_results) - if len(curr_results) == 10 and page < max_page: - results.extend(self.get_hydra(page=page+1, max_page=max_page, **kwargs)) - if update_db: - self.db.add_ports(self.build_db_input(results)) - return results diff --git a/src/synack/plugins/missions.py b/src/synack/plugins/missions.py index 8ff562e..d4f45e3 100644 --- a/src/synack/plugins/missions.py +++ b/src/synack/plugins/missions.py @@ -16,8 +16,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Api', 'Db', 'Targets', 'Templates']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def build_order(self, missions, sort="payout-high"): """Sort a list of missions by what's desired first @@ -49,36 +49,42 @@ def build_summary(self, missions): missions -- List of missions from one of the get_missions functions """ ret = { - "count": 0, - "value": 0, - "time": 0 + 'total': { 'count': 0, 'value': 0, 'time': 0 } } - for m in missions: - if m.get("status") == "CLAIMED": + for mission in missions: + codename = mission.get('listingCodename', 'UNKNOWN') + ret[codename] = ret.get(codename, {'count': 0, 'value': 0, 'time': 0, 'titles': list()}) + ret[codename]['count'] += 1 + ret[codename]['value'] += mission['payout']['amount'] + ret[codename]['titles'].append(mission['title']) + ret['total']['count'] += 1 + ret['total']['value'] += mission['payout']['amount'] + + if mission.get('status') == 'CLAIMED': utc = datetime.utcnow() try: - claimed_on = datetime.strptime(m['claimedOn'], - "%Y-%m-%dT%H:%M:%S.%fZ") + claimed_on = datetime.strptime(mission['claimedOn'], + '%Y-%m-%dT%H:%M:%S.%fZ') except ValueError: - claimed_on = datetime.strptime(m['claimedOn'], - "%Y-%m-%dT%H:%M:%SZ") + claimed_on = datetime.strptime(mission['claimedOn'], + '%Y-%m-%dT%H:%M:%SZ') try: - modified_on = datetime.strptime(m['modifiedOn'], - "%Y-%m-%dT%H:%M:%S.%fZ") + modified_on = datetime.strptime(mission['modifiedOn'], + '%Y-%m-%dT%H:%M:%S.%fZ') except ValueError: - modified_on = datetime.strptime(m['modifiedOn'], - "%Y-%m-%dT%H:%M:%SZ") + modified_on = datetime.strptime(mission['modifiedOn'], + '%Y-%m-%dT%H:%M:%SZ') report_time = claimed_on if claimed_on > modified_on else modified_on elapsed = int((utc - report_time).total_seconds()) - time = m['maxCompletionTimeInSecs'] - elapsed - if time < ret['time'] or ret['time'] == 0: - ret['time'] = time - ret['count'] = ret['count'] + 1 - ret['value'] = ret['value'] + m['payout']['amount'] + time = mission['maxCompletionTimeInSecs'] - elapsed + if time < ret['total']['time'] or ret['total']['time'] == 0: + ret['total']['time'] = time + if time < ret[codename]['time'] or ret[codename]['time'] == 0: + ret[codename]['time'] = time + return ret - def get(self, status="PUBLISHED", - max_pages=1, page=1, per_page=20, listing_uids=None): + def get(self, status='PUBLISHED', max_pages=1, page=1, per_page=20, listing_uids=None, **kwargs): """Get a list of missions given a status Arguments: @@ -99,30 +105,36 @@ def get(self, status="PUBLISHED", } if listing_uids: query["listingUids"] = listing_uids - res = self.api.request('GET', - 'tasks/v2/tasks', - query=query) + res = self._api.request('GET', + 'tasks/v2/tasks', + query=query) if res.status_code == 200: ret = res.json() if len(ret) == per_page and page < max_pages: - new = self.get(status, - max_pages, - page+1, - per_page) + new = self.get(status=status, + max_pages=max_pages, + page=page+1, + per_page=per_page) ret.extend(new) return ret + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() + return [] - def get_approved(self): + def get_approved(self, **kwargs): """Get a list of missions currently approved""" - return self.get("APPROVED") + kwargs['status'] = 'APPROVED' + return self.get(**kwargs) - def get_available(self): + def get_available(self, **kwargs): """Get a list of missions currently available""" - return self.get("PUBLISHED") + kwargs['status'] = 'PUBLISHED' + return self.get(**kwargs) - def get_claimed(self): + def get_claimed(self, **kwargs): """Get a list of all missions you currently have""" - return self.get("CLAIMED") + kwargs['status'] = 'CLAIMED' + return self.get(**kwargs) def get_count(self, status="PUBLISHED", listing_uids=None): """Get the number of missions currently available @@ -137,11 +149,14 @@ def get_count(self, status="PUBLISHED", listing_uids=None): } if listing_uids: query["listingUid"] = listing_uids - res = self.api.request('HEAD', - 'tasks/v1/tasks', - query=query) + res = self._api.request('HEAD', + 'tasks/v1/tasks', + query=query) if res.status_code == 204: return int(res.headers.get('x-count', 0)) + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() + return 0 def get_evidences(self, mission): """Download the evidences for a single mission @@ -149,36 +164,43 @@ def get_evidences(self, mission): Arguments: mission -- A single mission """ - evidences = self.api.request('GET', - 'tasks/v2/tasks/' + - mission['id'] + - '/evidences') - if evidences.status_code == 200: - ret = evidences.json() + res = self._api.request('GET', + 'tasks/v2/tasks/' + + mission['id'] + + '/evidences') + if res.status_code == 200: + ret = res.json() ret["title"] = mission["title"] ret["asset"] = mission["assetTypes"][0] ret["taskType"] = mission["taskType"] ret["structuredResponse"] = mission["validResponses"][1]["value"] return ret + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() - def get_in_review(self): + def get_in_review(self, **kwargs): """Get a list of missions currently in review""" - return self.get("FOR_REVIEW") + kwargs['status'] = 'FOR_REVIEW' + return self.get(**kwargs) def get_wallet_claimed(self): """Get Current Claimed Amount for Mission Wallet""" - res = self.api.request('GET', - 'tasks/v2/researcher/claimed_amount') + res = self._api.request('GET', + 'tasks/v2/researcher/claimed_amount') if res.status_code == 200: return int(res.json().get('claimedAmount', '0')) + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_wallet_limit(self): """Get Current Mission Wallet Limit""" - res = self.api.request('GET', - 'profiles/me') + res = self._api.request('GET', + 'profiles/me') if res.status_code == 200: return int(res.json().get('claim_limit', '0')) + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def set_claimed(self, mission): """Try to claim a single mission @@ -196,34 +218,36 @@ def set_disclaimed(self, mission): """ return self.set_status(mission, "DISCLAIM") - def set_evidences(self, mission, template=None): + def set_evidences(self, mission, template=None, force=False): """Upload a template to a mission Arguments: mission -- A single mission """ if template is None: - template = self.templates.get_file(mission) + template = self._templates.get_file(mission) if template: curr = self.get_evidences(mission) safe = True if curr: for f in ['introduction', 'testing_methodology', 'conclusion']: - if len(curr.get(f)) >= 20: + if len(curr.get(f)) >= 20 and force == False: safe = False break if safe: - res = self.api.request('PATCH', - 'tasks/v2/tasks/' + - mission['id'] + - '/evidences', - data=template) + res = self._api.request('PATCH', + 'tasks/v2/tasks/' + + mission['id'] + + '/evidences', + data=template) if res.status_code == 200: ret = res.json() ret["title"] = mission["title"] ret["codename"] = mission["listingCodename"] return ret + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def set_status(self, mission, status): """Interact with single mission @@ -234,21 +258,21 @@ def set_status(self, mission, status): data = { "type": status } - orgId = mission["organizationUid"] - listingId = mission["listingUid"] - campaignId = mission["campaignUid"] - taskId = mission["id"] - payout = str(mission["payout"]["amount"]) - title = mission["title"] - - res = self.api.request('POST', - 'tasks/v1' + - '/organizations/' + orgId + - '/listings/' + listingId + - '/campaigns/' + campaignId + - '/tasks/' + taskId + - '/transitions', - data=data) + orgId = mission.get('organizationUid', 'unk') + listingId = mission.get('listingUid', 'unk') + campaignId = mission.get('campaignUid', 'unk') + taskId = mission.get('id') + payout = str(mission.get('payout', {}).get('amount', 'unk')) + title = mission.get('title', 'unk') + + res = self._api.request('POST', + 'tasks/v1' + + '/organizations/' + orgId + + '/listings/' + listingId + + '/campaigns/' + campaignId + + '/tasks/' + taskId + + '/transitions', + data=data) return { "target": listingId, "title": title, diff --git a/src/synack/plugins/notifications.py b/src/synack/plugins/notifications.py index ce828f1..a492eeb 100644 --- a/src/synack/plugins/notifications.py +++ b/src/synack/plugins/notifications.py @@ -11,24 +11,34 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Api', 'Db']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def get(self): """Get a list of recent notifications""" - res = self.api.notifications('GET', - 'notifications?meta=1') + res = self._api.notifications('GET', + 'notifications?meta=1') if res.status_code == 200: return res.json() def get_unread_count(self): """Get the number of unread notifications""" - token = self.db.notifications_token query = { - "authorization_token": token + "authorization_token": self._state.notifications_token } - res = self.api.notifications('GET', - 'notifications/unread_count', - query=query) + res = self._api.notifications('GET', + 'notifications/unread_count', + query=query) + if res.status_code == 200: + return res.json() + + def set_read(self): + """Set all notifications to read""" + query = { + "authorization_token": self._state.notifications_token + } + res = self._api.notifications('GET', + 'read_all', + query=query) if res.status_code == 200: return res.json() diff --git a/src/synack/plugins/scratchspace.py b/src/synack/plugins/scratchspace.py index 436e496..4b4a277 100644 --- a/src/synack/plugins/scratchspace.py +++ b/src/synack/plugins/scratchspace.py @@ -13,37 +13,25 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Api', 'Db']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def build_filepath(self, filename, target=None, codename=None): if target: codename = target.codename if codename: - f = self.db.scratchspace_dir + f = self._state.scratchspace_dir f = f / codename f.mkdir(parents=True, exist_ok=True) f = f / filename return f def set_assets_file(self, content, target=None, codename=None): - if target or codename: - if type(content) in [list, set]: - content = '\n'.join(content) - dest_file = self.build_filepath('assets.txt', target=target, codename=codename) - with open(dest_file, 'w') as fp: - fp.write(content) - return dest_file + return self.set_file(content=content, filename='assets.txt', target=target, codename=codename) def set_burp_file(self, content, target=None, codename=None): - if target or codename: - if type(content) == dict: - content = json.dumps(content) - dest_file = self.build_filepath('burp.txt', target=target, codename=codename) - with open(dest_file, 'w') as fp: - fp.write(content) - return dest_file + return self.set_file(content=content, filename='burp.txt', target=target, codename=codename) def set_download_attachments(self, attachments, target=None, codename=None, prompt_overwrite=True, overwrite=True): downloads = list() @@ -55,18 +43,25 @@ def set_download_attachments(self, attachments, target=None, codename=None, prom ans = input(f'{attachment.get("filename")} exists. Overwrite? [y/N]: ') overwrite_current = ans.lower().startswith('y') if overwrite_current or not dest_file.exists(): - res = self.api.request('GET', attachment.get('url')) + res = self._api.request('GET', attachment.get('url')) if res.status_code == 200: with open(dest_file, 'wb') as fp: fp.write(res.content) downloads.append(dest_file) + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() return downloads - def set_hosts_file(self, content, target=None, codename=None): + def set_file(self, content, filename, target=None, codename=None): if target or codename: if type(content) in [list, set]: content = '\n'.join(content) - dest_file = self.build_filepath('hosts.txt', target=target, codename=codename) + elif type(content) in [dict]: + content = json.dumps(content) + dest_file = self.build_filepath(filename, target=target, codename=codename) with open(dest_file, 'w') as fp: fp.write(content) return dest_file + + def set_hosts_file(self, content, target=None, codename=None): + return self.set_file(content=content, filename='hosts.txt', target=target, codename=codename) diff --git a/src/synack/plugins/targets.py b/src/synack/plugins/targets.py index 9cdde40..4561934 100644 --- a/src/synack/plugins/targets.py +++ b/src/synack/plugins/targets.py @@ -15,8 +15,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Api', 'Db', 'Scratchspace']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def build_codename_from_slug(self, slug): """Return a codename for a target given its slug @@ -25,10 +25,10 @@ def build_codename_from_slug(self, slug): slug -- Slug of desired target """ codename = 'NONE' - targets = self.db.find_targets(slug=slug) + targets = self._db.find_targets(slug=slug) if not targets: self.get_registered_summary() - targets = self.db.find_targets(slug=slug) + targets = self._db.find_targets(slug=slug) if targets: codename = targets[0].codename return codename @@ -86,20 +86,22 @@ def build_scope_web_db(self, scope): def build_slug_from_codename(self, codename): """Return a slug for a target given its codename""" slug = None - targets = self.db.find_targets(codename=codename) + targets = self._db.find_targets(codename=codename) if not targets: self.get_registered_summary() - targets = self.db.find_targets(codename=codename) + targets = self._db.find_targets(codename=codename) if targets: slug = targets[0].slug return slug def get_assessments(self): """Check which assessments have been completed""" - res = self.api.request('GET', 'assessments') + res = self._api.request('GET', 'assessments') if res.status_code == 200: - self.db.add_categories(res.json()) - return self.db.categories + self._db.add_categories(res.json()) + return self._db.categories + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_assets(self, target=None, asset_type=None, host_type=None, active='true', scope=['in', 'discovered'], sort='location', sort_dir='asc', @@ -107,14 +109,19 @@ def get_assets(self, target=None, asset_type=None, host_type=None, active='true' """Get the assets (scope) of a target""" if target is None: if len(kwargs) > 0: - target = self.db.find_targets(**kwargs) + target = self._db.find_targets(**kwargs) else: curr = self.get_connected() - target = self.db.find_targets(slug=curr.get('slug')) + target = self._db.find_targets(slug=curr.get('slug')) if type(scope) == str: scope = [scope] + if type(host_type) == str: + host_type = [host_type] + elif host_type is None: + host_type = list() + if target: if type(target) is list and len(target) > 0: target = target[0] @@ -125,8 +132,8 @@ def get_assets(self, target=None, asset_type=None, host_type=None, active='true' queries.append(f'organizationUid%5B%5D={organization_uid}') if asset_type is not None: queries.append(f'assetType%5B%5D={asset_type}') - if host_type is not None: - queries.append(f'hostType%5B%5D={host_type}') + for item in host_type: + queries.append(f'hostType%5B%5D={item}') for item in scope: queries.append(f'scope%5B%5D={item}') if sort is not None: @@ -140,27 +147,31 @@ def get_assets(self, target=None, asset_type=None, host_type=None, active='true' if perPage is not None: queries.append(f'perPage={perPage}') - res = self.api.request('GET', f'asset/v2/assets?{"&".join(queries)}') + res = self._api.request('GET', f'asset/v2/assets?{"&".join(queries)}') if res.status_code == 200: - if self.db.use_scratchspace: - self.scratchspace.set_assets_file(res.text, target=target) + if self._state.use_scratchspace: + self._scratchspace.set_assets_file(res.text, target=target) return res.json() + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_attachments(self, target=None, **kwargs): """Get the attachments of a target.""" if target is None: if len(kwargs) == 0: kwargs = {'codename': self.get_connected().get('codename')} - target = self.db.find_targets(**kwargs) + target = self._db.find_targets(**kwargs) if target: target = target[0] - res = self.api.request('GET', f'targets/{target.slug}/resources') + res = self._api.request('GET', f'targets/{target.slug}/resources') if res.status_code == 200: return res.json() + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_connected(self): """Return information about the currenly selected target""" - res = self.api.request('GET', 'launchpoint') + res = self._api.request('GET', 'launchpoint') if res.status_code == 200: j = res.json() slug = j.get('slug') @@ -176,37 +187,43 @@ def get_connected(self): "status": status } return ret + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_connections(self, target=None, **kwargs): """Get the connection details of a target.""" if target is None: if len(kwargs) == 0: kwargs = {'codename': self.get_connected().get('codename')} - target = self.db.find_targets(**kwargs) + target = self._db.find_targets(**kwargs) if target: target = target[0] - res = self.api.request('GET', "listing_analytics/connections", query={"listing_id": target.slug}) + res = self._api.request('GET', "listing_analytics/connections", query={"listing_id": target.slug}) if res.status_code == 200: return res.json()["value"] + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_credentials(self, **kwargs): """Get Credentials for a target""" - target = self.db.find_targets(**kwargs)[0] + target = self._db.find_targets(**kwargs)[0] if target: - res = self.api.request('POST', - f'asset/v1/organizations/{target.organization}' + - f'/owners/listings/{target.slug}' + - f'/users/{self.db.user_id}' + - '/credentials') + res = self._api.request('POST', + f'asset/v1/organizations/{target.organization}' + + f'/owners/listings/{target.slug}' + + f'/users/{self._state.user_id}' + + '/credentials') if res.status_code == 200: return res.json() + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() - def get_query(self, status='registered', query_changes={}): + def get(self, status='registered', query_changes={}): """Get information about targets returned from a query""" - if not self.db.categories: + if not self._db.categories: self.get_assessments() categories = [] - for category in self.db.categories: + for category in self._db.categories: if category.passed_practical and category.passed_written: categories.append(category.id) query = { @@ -216,55 +233,60 @@ def get_query(self, status='registered', query_changes={}): 'filter[category][]': categories } query.update(query_changes) - res = self.api.request('GET', 'targets', query=query) + res = self._api.request('GET', 'targets', query=query) if res.status_code == 200: - self.db.add_targets(res.json(), is_registered=True) + self._db.add_targets(res.json(), is_registered=True) return res.json() + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_registered_summary(self): """Get information on your registered targets""" - res = self.api.request('GET', 'targets/registered_summary') + res = self._api.request('GET', 'targets/registered_summary') ret = [] if res.status_code == 200: - self.db.add_targets(res.json()) + self._db.add_targets(res.json()) ret = dict() for t in res.json(): ret[t['id']] = t + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() return ret - def get_scope(self, add_to_db=False, **kwargs): + def get_scope(self, **kwargs): """Get the scope of a target""" if len(kwargs) > 0: - target = self.db.find_targets(**kwargs) + target = self._db.find_targets(**kwargs) else: curr = self.get_connected() - target = self.db.find_targets(slug=curr.get('slug')) + target = self._db.find_targets(slug=curr.get('slug')) if target: target = target[0] categories = dict() - for category in self.db.categories: + for category in self._db.categories: categories[category.id] = category.name if categories[target.category].lower() == 'host': - return self.get_scope_host(target, add_to_db=add_to_db) + return self.get_scope_host(target) elif categories[target.category].lower() in ['web application', 'mobile']: - return self.get_scope_web(target, add_to_db=add_to_db) + return self.get_scope_web(target) - def get_scope_host(self, target=None, add_to_db=False, **kwargs): + def get_scope_host(self, target=None, **kwargs): """Get the scope of a Host target""" + if target is None: if len(kwargs) > 0: - targets = self.db.find_targets(**kwargs) + targets = self._db.find_targets(**kwargs) else: curr = self.get_connected() - targets = self.db.find_targets(slug=curr.get('slug')) + targets = self._db.find_targets(slug=curr.get('slug')) if targets: target = next(iter(targets), None) scope = set() if target: - assets = self.get_assets(target=target, active='true', asset_type='host', host_type='cidr') + assets = self.get_assets(target=target, active='true', asset_type='host', host_type=['cidr', 'ip']) for asset in assets: if asset.get('active'): try: @@ -275,23 +297,22 @@ def get_scope_host(self, target=None, add_to_db=False, **kwargs): pass scope.discard(None) + scope = list(scope) if len(scope) > 0: - if add_to_db: - self.db.add_ips(self.build_scope_host_db(target.slug, scope)) - if self.db.use_scratchspace: - self.scratchspace.set_hosts_file(scope, target=target) + if self._state.use_scratchspace: + self._scratchspace.set_hosts_file(scope, target=target) return scope - def get_scope_web(self, target=None, add_to_db=False, **kwargs): + def get_scope_web(self, target=None, **kwargs): """Get the scope of a Web target""" if target is None: if len(kwargs) > 0: - targets = self.db.find_targets(**kwargs) + targets = self._db.find_targets(**kwargs) else: curr = self.get_connected() - targets = self.db.find_targets(slug=curr.get('slug')) + targets = self._db.find_targets(slug=curr.get('slug')) if targets: target = next(iter(targets), None) @@ -314,10 +335,8 @@ def get_scope_web(self, target=None, add_to_db=False, **kwargs): }) if len(scope) > 0: - if add_to_db: - self.db.add_urls(self.build_scope_web_db(scope)) - if self.db.use_scratchspace: - self.scratchspace.set_burp_file(self.build_scope_web_burp(scope), target=target) + if self._state.use_scratchspace: + self._scratchspace.set_burp_file(self.build_scope_web_burp(scope), target=target) return scope @@ -328,32 +347,36 @@ def get_submissions(self, target=None, status="accepted", **kwargs): if target is None: if len(kwargs) == 0: kwargs = {'codename': self.get_connected().get('codename')} - target = self.db.find_targets(**kwargs) + target = self._db.find_targets(**kwargs) if target: target = target[0] query = {"listing_id": target.slug, "status": status} - res = self.api.request('GET', "listing_analytics/categories", query=query) + res = self._api.request('GET', "listing_analytics/categories", query=query) if res.status_code == 200: return res.json()["value"] + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_submissions_summary(self, target=None, hours_ago=None, **kwargs): """Get a summary of the submission analytics of a target.""" if target is None: if len(kwargs) == 0: kwargs = {'codename': self.get_connected().get('codename')} - target = self.db.find_targets(**kwargs) + target = self._db.find_targets(**kwargs) if target: target = target[0] query = {"listing_id": target.slug} if hours_ago: query["period"] = f"{hours_ago}h" - res = self.api.request('GET', "listing_analytics/submissions", query=query) + res = self._api.request('GET', "listing_analytics/submissions", query=query) if res.status_code == 200: return res.json()["value"] + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def get_unregistered(self): """Get slugs of all unregistered targets""" - return self.get_query(status='unregistered') + return self.get(status='unregistered') def get_upcoming(self): """Get slugs and upcoming start dates of all upcoming targets""" @@ -361,7 +384,7 @@ def get_upcoming(self): 'sorting[field]': 'upcomingStartDate', 'sorting[direction]': 'asc' } - return self.get_query(status='upcoming', query_changes=query_changes) + return self.get(status='upcoming', query_changes=query_changes) def set_connected(self, target=None, **kwargs): """Connect to a target""" @@ -371,14 +394,16 @@ def set_connected(self, target=None, **kwargs): elif len(kwargs) == 0: slug = '' else: - target = self.db.find_targets(**kwargs) + target = self._db.find_targets(**kwargs) if target: slug = target[0].slug if slug is not None: - res = self.api.request('PUT', 'launchpoint', data={'listing_id': slug}) + res = self._api.request('PUT', 'launchpoint', data={'listing_id': slug}) if res.status_code == 200: return self.get_connected() + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() def set_registered(self, targets=None): """Register all unregistered targets""" @@ -387,11 +412,13 @@ def set_registered(self, targets=None): data = '{"ResearcherListing":{"terms":1}}' ret = [] for t in targets: - res = self.api.request('POST', - f'targets/{t["slug"]}/signup', - data=data) + res = self._api.request('POST', + f'targets/{t["slug"]}/signup', + data=data) if res.status_code == 200: ret.append(t) + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() if len(targets) >= 15: ret.extend(self.set_registered()) return ret diff --git a/src/synack/plugins/templates.py b/src/synack/plugins/templates.py index 57e4305..a1bd30c 100644 --- a/src/synack/plugins/templates.py +++ b/src/synack/plugins/templates.py @@ -14,11 +14,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Alerts', 'Db', 'Targets']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def build_filepath(self, mission, generic_ok=False): - f = self.db.template_dir + f = self._state.template_dir f = f / self.build_safe_name(mission['taskType']) if mission.get('asset'): f = f / self.build_safe_name(mission['asset']) @@ -34,7 +34,7 @@ def build_filepath(self, mission, generic_ok=False): def build_replace_variables(self, text, target=None, **kwargs): """Replaces known variables within text""" if target is None: - target = self.db.find_targets(**kwargs) + target = self._db.find_targets(**kwargs) if target: target = target[0] @@ -44,7 +44,7 @@ def build_replace_variables(self, text, target=None, **kwargs): def build_safe_name(self, name): """Simplify a name to use for a file path""" - name = self.alerts.sanitize(name) + name = self._alerts.sanitize(name) name = name.lower() name = re.sub('[^a-z0-9]', '_', name) return re.sub('_+', '_', name) diff --git a/src/synack/plugins/transactions.py b/src/synack/plugins/transactions.py index c973a9d..e22a130 100644 --- a/src/synack/plugins/transactions.py +++ b/src/synack/plugins/transactions.py @@ -13,11 +13,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Api']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def get_balance(self): """Get your current account balance and requested payout values""" - res = self.api.request('HEAD', 'transactions') + res = self._api.request('HEAD', 'transactions') if res.status_code == 200: return json.loads(res.headers.get('x-balance')) + elif res.status_code == 403 and self._state.login: + self._auth.get_api_token() diff --git a/src/synack/plugins/users.py b/src/synack/plugins/users.py index 3388826..ea461df 100644 --- a/src/synack/plugins/users.py +++ b/src/synack/plugins/users.py @@ -11,12 +11,12 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for plugin in ['Api', 'Db']: setattr(self, - plugin.lower(), - self.registry.get(plugin)(self.state)) + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) def get_profile(self, user_id="me"): """Get a user's profile""" - res = self.api.request('GET', f'profiles/{user_id}') + res = self._api.request('GET', f'profiles/{user_id}') if res.status_code == 200: - self.db.user_id = res.json().get('user_id') + self._db.user_id = res.json().get('user_id') return res.json() diff --git a/src/synack/plugins/utils.py b/src/synack/plugins/utils.py new file mode 100644 index 0000000..0bed3fc --- /dev/null +++ b/src/synack/plugins/utils.py @@ -0,0 +1,24 @@ +"""plugins/utils.py + +Defines utility methods used in other plugins +""" + +from .base import Plugin + +import re + + +class Utils(Plugin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for plugin in []: + setattr(self, + '_'+plugin.lower(), + self._registry.get(plugin)(self._state)) + + @staticmethod + def get_html_tag_value(field, text): + match = re.search(f'<[^>]*name=.{field}.[^>]*(?:content|value)=.([^"\']*)', text) + if match.group is None: + match = re.search(f'<[^>]*(?:content|value)=.([^"\']*)[^>]*name=.{field}', text) + return match.group(1) if match else '' diff --git a/test/test_alerts.py b/test/test_alerts.py index 17086b2..47ccb14 100644 --- a/test/test_alerts.py +++ b/test/test_alerts.py @@ -17,27 +17,28 @@ class AlertsTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.alerts = synack.plugins.Alerts(self.state) - self.alerts.db = MagicMock() + self.alerts._db = MagicMock() def test_email_no_tls(self): """Should send a non-TLS encrypted email""" - self.alerts.db.smtp_starttls = False - self.alerts.db.smtp_server = 'smtp.email.com' - self.alerts.db.smtp_port = 587 - self.alerts.db.smtp_username = 'user5' - self.alerts.db.smtp_password = 'password123' + self.alerts._state.smtp_starttls = False + self.alerts._state.smtp_server = 'smtp.email.com' + self.alerts._state.smtp_port = 587 + self.alerts._state.smtp_username = 'user5' + self.alerts._state.smtp_password = 'password123' with patch('smtplib.SMTP') as mock_smtp: self.alerts.email('subject', 'body') mock_smtp.assert_called_with('smtp.email.com', 587) def test_email_tls(self): """Should send a TLS encrypted email""" - self.alerts.db.smtp_starttls = True - self.alerts.db.smtp_server = 'smtp.email.com' - self.alerts.db.smtp_port = 465 - self.alerts.db.smtp_username = 'user5' - self.alerts.db.smtp_password = 'password123' + self.alerts._state.smtp_starttls = True + self.alerts._state.smtp_server = 'smtp.email.com' + self.alerts._state.smtp_port = 465 + self.alerts._state.smtp_username = 'user5' + self.alerts._state.smtp_password = 'password123' with patch('smtplib.SMTP_SSL') as mock_smtp: with patch('email.message.EmailMessage') as mock_msg: with patch('datetime.datetime') as mock_dt: @@ -85,8 +86,13 @@ def test_sanitize_urls(self): def test_slack(self): """Should POST a message to slack""" with patch('requests.post') as mock_post: - self.alerts.db.slack_url = 'https://slack.com' + self.alerts._state.slack_channel = 'myslackchannel' + self.alerts._state.slack_app_token = '1234' self.alerts.slack('this is a test') - mock_post.assert_called_with('https://slack.com', - data='{"text": "this is a test"}', - headers={'Content-Type': 'application/json'}) + mock_post.assert_called_with('https://slack.com/api/chat.postMessage', + data='{"text": "this is a test", "channel": "myslackchannel"}', + headers={ + 'Authorization': 'Bearer 1234', + 'Content-Type': 'application/json' + }, + verify=False) diff --git a/test/test_api.py b/test/test_api.py index 1c4c1b8..62db387 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -17,9 +17,10 @@ class ApiTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.api = synack.plugins.Api(self.state) - self.api.debug = MagicMock() - self.api.db = MagicMock() + self.api._debug = MagicMock() + self.api._db = MagicMock() def test_login_full_path(self): """Login Base URL should prepend and request should be made""" @@ -40,7 +41,7 @@ def test_notification_bad_token(self): """Notifications token should be obtained if it doesn't exist""" self.api.request = MagicMock() self.api.request.return_value.status_code = 422 - self.api.db.notifications_token = "bad_token" + self.api._state.notifications_token = "bad_token" url = 'https://notifications.synack.com/api/v2/test' headers = {"Authorization": "Bearer bad_token"} self.api.notifications('GET', 'test') @@ -51,7 +52,7 @@ def test_notification_bad_token(self): def test_notification_full_path(self): """Notifications Base URL should prepend and request should be made""" self.api.request = MagicMock() - self.api.db.notifications_token = "something" + self.api._state.notifications_token = "something" headers = {"Authorization": "Bearer something"} url = 'http://www.google.com/api/test' self.api.notifications('GET', url) @@ -62,13 +63,13 @@ def test_notification_full_path(self): def test_notification_no_token(self): """Notifications token should be obtained if it doesn't exist""" self.api.request = MagicMock() - self.api.db.notifications_token = "" + self.api._state.notifications_token = "" self.api.notifications('GET', 'test') def test_notification_path(self): """Notifications Base URL should prepend and request should be made""" self.api.request = MagicMock() - self.api.db.notifications_token = "something" + self.api._state.notifications_token = "something" headers = {"Authorization": "Bearer something"} url = 'https://notifications.synack.com/api/v2/test' self.api.notifications('GET', 'test') @@ -78,64 +79,64 @@ def test_notification_path(self): def test_request_full_url(self): """Base URL should not be added if a full url is passed""" - self.api.state.session.get = MagicMock() - self.api.db.use_proxies = False - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.session.get = MagicMock() + self.api._state.use_proxies = False + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco' } - url = 'http://www.google.com/api/test' + url = 'http://www.synack.com/api/test' self.api.request('GET', url) - self.api.state.session.get.assert_called_with(url, - headers=headers, - proxies=None, - params=None, - verify=True) + self.api._state.session.get.assert_called_with(url, + headers=headers, + proxies=None, + params=None, + verify=True) def test_request_get(self): """GET requests should work""" - self.api.state.session.get = MagicMock() - self.api.db.use_proxies = False - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.session.get = MagicMock() + self.api._state.use_proxies = False + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco' } url = 'https://platform.synack.com/api/test' self.api.request('GET', 'test') - self.api.state.session.get.assert_called_with(url, - headers=headers, - proxies=None, - params=None, - verify=True) + self.api._state.session.get.assert_called_with(url, + headers=headers, + proxies=None, + params=None, + verify=True) def test_request_head(self): """HEAD requests should work""" - self.api.state.session.head = MagicMock() - self.api.db.use_proxies = False - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.session.head = MagicMock() + self.api._state.use_proxies = False + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco' } url = 'https://platform.synack.com/api/test' self.api.request('HEAD', 'test') - self.api.state.session.head.assert_called_with(url, - headers=headers, - proxies=None, - params=None, - verify=True) + self.api._state.session.head.assert_called_with(url, + headers=headers, + proxies=None, + params=None, + verify=True) def test_request_header_kwargs(self): """requests should merge in kwargs headers""" - self.api.state.session.get = MagicMock() - self.api.db.use_proxies = False - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.session.get = MagicMock() + self.api._state.use_proxies = False + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco', @@ -143,20 +144,20 @@ def test_request_header_kwargs(self): } url = 'https://platform.synack.com/api/test' self.api.request('GET', 'test', headers={'test': 'test'}) - self.api.state.session.get.assert_called_with(url, - headers=headers, - proxies=None, - params=None, - verify=True) + self.api._state.session.get.assert_called_with(url, + headers=headers, + proxies=None, + params=None, + verify=True) def test_request_logged(self): """All requests should call the logger""" - self.api.state.session.get = MagicMock() - self.api.state.session.get.return_value.status_code = 200 - self.api.state.session.get.return_value.content = "Returned Content" - self.api.db.use_proxies = False - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.session.get = MagicMock() + self.api._state.session.get.return_value.status_code = 200 + self.api._state.session.get.return_value.content = "Returned Content" + self.api._state.use_proxies = False + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco' @@ -167,45 +168,45 @@ def test_request_logged(self): "\n\tQuery: None" + \ "\n\tData: None" + \ "\n\tContent: Returned Content" - self.api.debug.log.assert_called_with("Network Request", message) + self.api._debug.log.assert_called_with("Network Request", message) def test_request_patch(self): """PATCH requests should work""" - self.api.state.session.patch = MagicMock() + self.api._state.session.patch = MagicMock() data = {'test': 'test'} - self.api.db.use_proxies = False - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.use_proxies = False + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" url = 'https://platform.synack.com/api/test' headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco' } self.api.request('PATCH', 'test', data=data) - self.api.state.session.patch.assert_called_with(url, - json=data, - headers=headers, - proxies=None, - verify=True) + self.api._state.session.patch.assert_called_with(url, + json=data, + headers=headers, + proxies=None, + verify=True) def test_request_post(self): """POST requests should work""" - self.api.state.session.post = MagicMock() + self.api._state.session.post = MagicMock() data = {'test': 'test'} - self.api.db.use_proxies = False - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.use_proxies = False + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" url = 'https://platform.synack.com/api/test' headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco' } self.api.request('POST', 'test', data=data) - self.api.state.session.post.assert_called_with(url, - json=data, - headers=headers, - proxies=None, - verify=True) + self.api._state.session.post.assert_called_with(url, + json=data, + headers=headers, + proxies=None, + verify=True) def test_request_proxies(self): """Proxies should be used if set""" @@ -213,38 +214,39 @@ def test_request_proxies(self): 'http': 'http://127.0.0.1:8080', 'https': 'http://127.0.0.1:8080', } - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco' } url = 'https://platform.synack.com/api/test' - self.api.state.session.get = MagicMock() - self.api.db.use_proxies = True - self.api.db.proxies = proxies + self.api._state.session.get = MagicMock() + self.api._state.use_proxies = True + self.api._state.http_proxy = proxies.get('http') + self.api._state.https_proxy = proxies.get('https') self.api.request('GET', 'test') - self.api.state.session.get.assert_called_with(url, - headers=headers, - proxies=proxies, - params=None, - verify=False) + self.api._state.session.get.assert_called_with(url, + headers=headers, + proxies=proxies, + params=None, + verify=False) def test_request_put(self): """PUT requests should work""" - self.api.state.session.put = MagicMock() + self.api._state.session.put = MagicMock() data = {'test': 'test'} - self.api.db.use_proxies = False - self.api.db.user_id = "paco" - self.api.db.api_token = "12345" + self.api._state.use_proxies = False + self.api._state.user_id = "paco" + self.api._state.api_token = "12345" url = 'https://platform.synack.com/api/test' headers = { 'Authorization': 'Bearer 12345', 'user_id': 'paco' } self.api.request('PUT', 'test', data=data) - self.api.state.session.put.assert_called_with(url, - params=data, - headers=headers, - proxies=None, - verify=True) + self.api._state.session.put.assert_called_with(url, + headers=headers, + proxies=None, + params=data, + verify=True) diff --git a/test/test_auth.py b/test/test_auth.py index 6b029bf..cdc3d98 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -5,7 +5,6 @@ import os import pathlib -import pyotp import sys import unittest @@ -19,122 +18,68 @@ class AuthTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.auth = synack.plugins.Auth(self.state) - self.auth.api = MagicMock() - self.auth.db = MagicMock() - self.auth.users = MagicMock() - - def test_build_otp(self): - """Should generate a OTP""" - pyotp.TOTP = MagicMock() - self.auth.db.otp_secret = "123" - self.auth.build_otp() - self.assertEqual(7, pyotp.TOTP.return_value.digits) - self.assertEqual(10, pyotp.TOTP.return_value.interval) - self.assertEqual('synack', pyotp.TOTP.return_value.issuer) - pyotp.TOTP.assert_called_with('123') - pyotp.TOTP.return_value.now.assert_called_with() + self.auth._api = MagicMock() + self.auth._db = MagicMock() + self.auth._users = MagicMock() + self.auth._duo = MagicMock() def test_get_api_token(self): """Should complete the login workflow when check fails""" - self.auth.db.api_token = "" + self.auth._state.api_token = "" self.auth.set_login_script = MagicMock() - self.auth.users.get_profile = MagicMock() - self.auth.users.get_profile.return_value = None + self.auth.get_authentication_response = MagicMock() + self.auth.get_authentication_response.return_value = { + 'duo_auth_url': 'https://duoauth.local' + } + self.auth._users.get_profile = MagicMock() + self.auth._users.get_profile.return_value = None self.auth.get_login_csrf = MagicMock(return_value="csrf_fwlnm") - self.auth.get_login_progress_token = MagicMock() - self.auth.get_login_progress_token.return_value = "pt_rsaemnt" - self.auth.get_login_grant_token = MagicMock(return_value="gt_fwlnm") - self.auth.api.request.return_value.status_code = 200 + self.auth._api.request.return_value.status_code = 200 ret_json = {"access_token": "api_lwfaume"} - self.auth.api.request.return_value.json.return_value = ret_json + self.auth._api.request.return_value.json.return_value = ret_json self.assertEqual("api_lwfaume", self.auth.get_api_token()) self.auth.get_login_csrf.assert_called_with() self.auth.set_login_script.assert_called_with() - self.auth.get_login_progress_token.assert_called_with("csrf_fwlnm") - self.auth.get_login_grant_token.assert_called_with("csrf_fwlnm", - "pt_rsaemnt") + self.auth.get_authentication_response.assert_called_with('csrf_fwlnm') def test_get_api_token_login_success(self): """Should return the database token when check succeeds""" - self.auth.db.api_token = "qweqweqwe" + self.auth._state.api_token = "qweqweqwe" self.auth.set_login_script = MagicMock() - self.auth.users.get_profile = MagicMock() - self.auth.users.get_profile.return_value = {"user_id": "john"} + self.auth._users.get_profile = MagicMock() + self.auth._users.get_profile.return_value = {"user_id": "john"} self.assertEqual("qweqweqwe", self.auth.get_api_token()) - def test_get_login_grant_token(self): - """Should get the grant token from valid authy TOTP""" - self.auth.build_otp = MagicMock(return_value="12345") - self.auth.api.login.return_value.status_code = 200 - self.auth.api.login.return_value.json.return_value = { - "grant_token": "qwfars" - } - headers = { - "X-Csrf-Token": "abcde" - } - data = { - "authy_token": "12345", - "progress_token": "789456123" - } - - returned_gt = self.auth.get_login_grant_token('abcde', '789456123') - self.assertEqual("qwfars", returned_gt) - self.auth.api.login.assert_called_with("POST", - "authenticate", - headers=headers, - data=data) - - def test_get_login_progress_token(self): - """Should get the progress token from valid creds""" - self.auth.api.login.return_value.status_code = 200 - self.auth.api.login.return_value.json.return_value = { - "progress_token": "qwfars" - } - data = { - "email": "bob@bob.com", - "password": "123456" - } - headers = { - "X-CSRF-Token": "abcde" - } - self.auth.db.email = "bob@bob.com" - self.auth.db.password = "123456" - returned_pt = self.auth.get_login_progress_token('abcde') - self.assertEqual("qwfars", returned_pt) - self.auth.api.login.assert_called_with("POST", - "authenticate", - headers=headers, - data=data) - def test_get_notifications_token(self): """Should get the notifications token""" - self.auth.db.notifications_token = "" - self.auth.api.request.return_value.status_code = 200 + self.auth._db.notifications_token = "" + self.auth._api.request.return_value.status_code = 200 ret_value = {"token": "12345"} - self.auth.api.request.return_value.json.return_value = ret_value + self.auth._api.request.return_value.json.return_value = ret_value self.assertEqual("12345", self.auth.get_notifications_token()) - self.assertEqual("12345", self.auth.db.notifications_token) - self.auth.api.request.assert_called_with("GET", - "users/notifications_token") - self.auth.api.request.return_value.json.assert_called_with() + self.assertEqual("12345", self.auth._db.notifications_token) + self.auth._api.request.assert_called_with("GET", + "users/notifications_token") + self.auth._api.request.return_value.json.assert_called_with() def test_login_csrf(self): """Should get the login csrf token""" ret_text = '= 20 characters""" @@ -397,15 +398,15 @@ def test_set_evidences_unsafe(self): "validResponses": [{}, {"value": "uieth8rgyub"}], "listingCodename": "SLAPPYMONKEY" } - self.missions.templates.get_template = MagicMock() - self.missions.templates.get_template.return_value = template + self.missions._templates.get_template = MagicMock() + self.missions._templates.get_template.return_value = template self.missions.get_evidences = MagicMock() self.missions.get_evidences.return_value = curr - self.missions.api.request = MagicMock() - self.missions.api.request.return_value.status_code = 200 - self.missions.api.request.return_value.json.return_value = {} + self.missions._api.request = MagicMock() + self.missions._api.request.return_value.status_code = 200 + self.missions._api.request.return_value.json.return_value = {} self.missions.set_evidences(mission) - self.missions.api.request.assert_not_called() + self.missions._api.request.assert_not_called() def test_set_status(self): """Should interact with a mission""" @@ -424,7 +425,7 @@ def test_set_status(self): "status": "CLAIM", "success": True } - self.missions.api.request.return_value.status_code = 201 + self.missions._api.request.return_value.status_code = 201 self.assertEqual(ret, self.missions.set_status(m, "CLAIM")) data = {"type": "CLAIM"} calls = [ @@ -445,4 +446,4 @@ def test_set_status(self): '/transitions', data=data) ] - self.missions.api.request.has_calls(calls) + self.missions._api.request.has_calls(calls) diff --git a/test/test_notifications.py b/test/test_notifications.py index 24d9be8..ad9e6f6 100644 --- a/test/test_notifications.py +++ b/test/test_notifications.py @@ -17,27 +17,28 @@ class NotificationsTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.notifications = synack.plugins.Notifications(self.state) - self.notifications.api = MagicMock() - self.notifications.db = MagicMock() + self.notifications._api = MagicMock() + self.notifications._db = MagicMock() def test_get(self): """Should get a list of notifications""" - self.notifications.api.notifications.return_value.status_code = 200 - self.notifications.api.notifications.return_value.json.return_value = {"one": "1"} + self.notifications._api.notifications.return_value.status_code = 200 + self.notifications._api.notifications.return_value.json.return_value = {"one": "1"} path = "notifications?meta=1" self.assertEqual({"one": "1"}, self.notifications.get()) - self.notifications.api.notifications.assert_called_with("GET", path) + self.notifications._api.notifications.assert_called_with("GET", path) def test_get_unread_count(self): """Should get the number of unread notifications""" - self.notifications.api.notifications.return_value.status_code = 200 - self.notifications.api.notifications.return_value.json.return_value = {"one": "1"} - self.notifications.db.notifications_token = "good_token" + self.notifications._api.notifications.return_value.status_code = 200 + self.notifications._api.notifications.return_value.json.return_value = {"one": "1"} + self.notifications._state.notifications_token = "good_token" query = { "authorization_token": "good_token" } path = "notifications/unread_count" self.assertEqual({"one": "1"}, self.notifications.get_unread_count()) - self.notifications.api.notifications.assert_called_with("GET", path, - query=query) + self.notifications._api.notifications.assert_called_with("GET", path, + query=query) diff --git a/test/test_scratchspace.py b/test/test_scratchspace.py index c167f89..484e228 100644 --- a/test/test_scratchspace.py +++ b/test/test_scratchspace.py @@ -1,6 +1,6 @@ """test_scratchspace.py -Tests for the plugins/scratchspace.py Db class +Tests for the plugins/scratchspace.py Scratchspace class """ import os @@ -18,17 +18,18 @@ class ScratchspaceTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.scratchspace = synack.plugins.Scratchspace(self.state) def test_build_filepath_codename(self): """Should build the appropriate scratchspace filepath given a codename""" - self.scratchspace.db.scratchspace_dir = pathlib.Path('/tmp') + self.scratchspace._state.scratchspace_dir = pathlib.Path('/tmp') ret = self.scratchspace.build_filepath('test.txt', codename='TIREDTURKEY') self.assertEqual(pathlib.Path('/tmp/TIREDTURKEY/test.txt'), ret) def test_build_filepath_target(self): """Should build the appropriate scratchspace filepath given a filepath""" - self.scratchspace.db.scratchspace_dir = pathlib.Path('/tmp') + self.scratchspace._state.scratchspace_dir = pathlib.Path('/tmp') target = synack.db.models.Target(codename='TIREDTURKEY') ret = self.scratchspace.build_filepath('test.txt', target=target) self.assertEqual(pathlib.Path('/tmp/TIREDTURKEY/test.txt'), ret) @@ -82,9 +83,9 @@ def test_set_download_attachments_codename(self): dest_path = pathlib.Path('/tmp/TIREDTURKEY/burp.txt') self.scratchspace.build_filepath = MagicMock() self.scratchspace.build_filepath.return_value = dest_path - self.scratchspace.api.request = MagicMock() - self.scratchspace.api.request.return_value.status_code = 200 - self.scratchspace.api.request.return_value.content = b'file_content' + self.scratchspace._api.request = MagicMock() + self.scratchspace._api.request.return_value.status_code = 200 + self.scratchspace._api.request.return_value.content = b'file_content' m = mock_open() attachments = [ {'slug': '43i7h', 'filename': 'file1.txt', 'url': 'https://downloads.com/xyzf'} @@ -101,9 +102,9 @@ def test_set_download_attachments_prompt_overwrite(self, input_mock): dest_path = pathlib.Path('/tmp/TIREDTURKEY/burp.txt') self.scratchspace.build_filepath = MagicMock() self.scratchspace.build_filepath.return_value = dest_path - self.scratchspace.api.request = MagicMock() - self.scratchspace.api.request.return_value.status_code = 200 - self.scratchspace.api.request.return_value.content = b'file_content' + self.scratchspace._api.request = MagicMock() + self.scratchspace._api.request.return_value.status_code = 200 + self.scratchspace._api.request.return_value.content = b'file_content' attachments = [ {'slug': '43i7h', 'filename': 'file1.txt', 'url': 'https://downloads.com/xyzf'} ] diff --git a/test/test_state.py b/test/test_state.py index 70460c1..84c64e6 100644 --- a/test/test_state.py +++ b/test/test_state.py @@ -9,6 +9,8 @@ import pathlib import requests +from unittest.mock import MagicMock + sys.path.insert(0, os.path.abspath(os.path.join(__file__, '../../src'))) import synack # noqa: E402 @@ -17,6 +19,7 @@ class StateTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() def test_config_dir(self): self.assertEqual(pathlib.PosixPath, type(self.state.config_dir)) @@ -31,32 +34,32 @@ def test_config_dir(self): self.state._config_dir) def test_debug(self): - self.assertEqual(None, self.state.debug) + self.assertEqual(self.state._db.debug, self.state.debug) self.assertEqual(None, self.state._debug) - self.state.debug = True - self.assertEqual(True, self.state.debug) + self.state._debug = True + self.assertEqual(True, self.state._debug) self.assertEqual(True, self.state._debug) def test_email(self): - self.assertEqual(None, self.state.email) + self.assertEqual(self.state._db.email, self.state.email) self.assertEqual(None, self.state._email) self.state.email = '1@2.com' self.assertEqual('1@2.com', self.state.email) self.assertEqual('1@2.com', self.state._email) def test_http_proxy(self): - self.assertEqual(None, self.state.http_proxy) + self.assertEqual(self.state._db.http_proxy, self.state.http_proxy) self.assertEqual(None, self.state._http_proxy) self.state.http_proxy = 'http://1.1.1.1:1234' self.assertEqual('http://1.1.1.1:1234', self.state._http_proxy) self.assertEqual('http://1.1.1.1:1234', self.state.http_proxy) self.assertEqual(self.state.proxies, { 'http': 'http://1.1.1.1:1234', - 'https': None + 'https': self.state._db.https_proxy }) def test_https_proxy(self): - self.assertEqual(None, self.state.https_proxy) + self.assertEqual(self.state._db.https_proxy, self.state.https_proxy) self.assertEqual(None, self.state._https_proxy) self.state.https_proxy = 'http://1.1.1.1:1234' self.assertEqual('http://1.1.1.1:1234', self.state.https_proxy) @@ -70,14 +73,14 @@ def test_login(self): self.assertEqual(False, self.state._login) def test_otp_secret(self): - self.assertEqual(None, self.state.otp_secret) + self.assertEqual(self.state._db.otp_secret, self.state.otp_secret) self.assertEqual(None, self.state._otp_secret) self.state.otp_secret = '12345' self.assertEqual('12345', self.state.otp_secret) self.assertEqual('12345', self.state._otp_secret) def test_password(self): - self.assertEqual(None, self.state.password) + self.assertEqual(self.state._db.password, self.state.password) self.assertEqual(None, self.state._password) self.state.password = 'password1234' self.assertEqual('password1234', self.state.password) @@ -85,13 +88,13 @@ def test_password(self): def test_proxies(self): self.assertEqual(self.state.proxies, { - 'http': None, - 'https': None + 'http': self.state._db.http_proxy, + 'https': self.state._db.https_proxy }) self.state.http_proxy = 'http://2.2.2.2:1234' self.assertEqual(self.state.proxies, { 'http': 'http://2.2.2.2:1234', - 'https': None + 'https': self.state._db.https_proxy }) self.state.https_proxy = 'http://1.1.1.1:1234' self.assertEqual(self.state.proxies, { @@ -100,7 +103,7 @@ def test_proxies(self): }) def test_scratchspace_dir(self): - self.assertEqual(None, self.state.scratchspace_dir) + self.assertEqual(self.state._db.scratchspace_dir, self.state.scratchspace_dir) self.assertEqual(None, self.state._scratchspace_dir) self.state.scratchspace_dir = "/tmp" self.assertEqual(pathlib.PosixPath, type(self.state.scratchspace_dir)) @@ -114,7 +117,7 @@ def test_session(self): self.assertEqual(requests.sessions.Session, type(self.state._session)) def test_template_dir(self): - self.assertEqual(None, self.state.template_dir) + self.assertEqual(self.state._db.template_dir, self.state.template_dir) self.assertEqual(None, self.state._template_dir) self.state.template_dir = "/tmp" self.assertEqual(pathlib.PosixPath, type(self.state.template_dir)) @@ -124,21 +127,21 @@ def test_template_dir(self): self.state._template_dir) def test_use_proxies(self): - self.assertEqual(None, self.state.use_proxies) + self.assertEqual(self.state._db.use_proxies, self.state.use_proxies) self.assertEqual(None, self.state._use_proxies) self.state.use_proxies = True self.assertEqual(True, self.state.use_proxies) self.assertEqual(True, self.state._use_proxies) def test_user_id(self): - self.assertEqual(None, self.state.user_id) + self.assertEqual(self.state._db.user_id, self.state.user_id) self.assertEqual(None, self.state._user_id) self.state.user_id = '12345' self.assertEqual('12345', self.state.user_id) self.assertEqual('12345', self.state._user_id) def test_use_scratchspace(self): - self.assertEqual(None, self.state.use_scratchspace) + self.assertEqual(self.state._db.use_scratchspace, self.state.use_scratchspace) self.assertEqual(None, self.state._use_scratchspace) self.state.use_scratchspace = True self.assertEqual(True, self.state.use_scratchspace) diff --git a/test/test_targets.py b/test/test_targets.py index 9af76f6..d4e75c0 100644 --- a/test/test_targets.py +++ b/test/test_targets.py @@ -18,30 +18,31 @@ class TargetsTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.targets = synack.plugins.Targets(self.state) - self.targets.api = MagicMock() - self.targets.db = MagicMock() + self.targets._api = MagicMock() + self.targets._db = MagicMock() self.targets.scratchspace = MagicMock() self.maxDiff = None def test_build_codename_from_slug(self): """Should return a codename for a given slug""" ret_targets = [Target(codename="SLOPPYSLUG")] - self.targets.db.find_targets.return_value = ret_targets + self.targets._db.find_targets.return_value = ret_targets self.assertEqual("SLOPPYSLUG", self.targets.build_codename_from_slug("qwfars")) - self.targets.db.find_targets.assert_called_with(slug="qwfars") + self.targets._db.find_targets.assert_called_with(slug="qwfars") def test_build_codename_from_slug_invalid(self): """Should return NONE if non-real slug""" - self.targets.db.find_targets.return_value = [] + self.targets._db.find_targets.return_value = [] self.assertEqual("NONE", self.targets.build_codename_from_slug("qwfars")) - self.targets.db.find_targets.assert_called_with(slug="qwfars") + self.targets._db.find_targets.assert_called_with(slug="qwfars") def test_build_codename_from_slug_no_targets(self): """Should update the targets if empty""" - self.targets.db.find_targets.side_effect = [ + self.targets._db.find_targets.side_effect = [ [], [Target(codename="SLOPPYSLUG")] ] @@ -52,7 +53,7 @@ def test_build_codename_from_slug_no_targets(self): self.targets.get_registered_summary = MagicMock() self.assertEqual("SLOPPYSLUG", self.targets.build_codename_from_slug("qwfars")) - self.targets.db.find_targets.assert_has_calls(calls) + self.targets._db.find_targets.assert_has_calls(calls) self.targets.get_registered_summary.assert_called_with() def test_build_scope_host_db(self): @@ -163,14 +164,14 @@ def test_build_scope_web_db(self): def test_build_slug_from_codename(self): """Should return a slug for a given codename""" ret_targets = [Target(slug="qwerty")] - self.targets.db.find_targets.return_value = ret_targets + self.targets._db.find_targets.return_value = ret_targets self.assertEqual("qwerty", self.targets.build_slug_from_codename("qwerty")) - self.targets.db.find_targets.assert_called_with(codename="qwerty") + self.targets._db.find_targets.assert_called_with(codename="qwerty") def test_build_slug_from_codename_no_targets(self): """Should update the targets if empty""" - self.targets.db.find_targets.side_effect = [ + self.targets._db.find_targets.side_effect = [ [], [Target(slug="qwerty")] ] @@ -182,7 +183,7 @@ def test_build_slug_from_codename_no_targets(self): slug = self.targets.build_slug_from_codename("CHONKEYMONKEY") self.assertEqual("qwerty", slug) - self.targets.db.find_targets.assert_has_calls(calls) + self.targets._db.find_targets.assert_has_calls(calls) self.targets.get_registered_summary.assert_called_with() def test_get_assessments_all_passed(self): @@ -210,32 +211,32 @@ def test_get_assessments_all_passed(self): } ] cat1 = synack.db.models.Category() - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = assessments - self.targets.db.categories = [cat1] + self.targets. _api.request.return_value.status_code = 200 + self.targets. _api.request.return_value.json.return_value = assessments + self.targets._db.categories = [cat1] self.assertEqual([cat1], self.targets.get_assessments()) - self.targets.db.add_categories.assert_called_with(assessments) + self.targets._db.add_categories.assert_called_with(assessments) def test_get_assets(self): """Should return a list of assets for a currently connected target""" self.targets.get_connected = MagicMock() self.targets.get_connected.return_value = {'codename': 'TURBULENTTORTOISE', 'slug': '327h8iw'} - self.targets.db.find_targets.return_value = [Target(slug='327h8iw')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.text = 'rettext' - self.targets.api.request.return_value.json.return_value = 'retjson' + self.targets._db.find_targets.return_value = [Target(slug='327h8iw')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.text = 'rettext' + self.targets._api.request.return_value.json.return_value = 'retjson' self.assertEqual('retjson', self.targets.get_assets()) - self.targets.api.request.assert_called_with('GET', - 'asset/v2/assets?listingUid%5B%5D=327h8iw&scope%5B%5D=in' + - '&scope%5B%5D=discovered&sort%5B%5D=location&active=true' + - '&sortDir=asc&page=1&perPage=5000') + self.targets._api.request.assert_called_with('GET', + 'asset/v2/assets?listingUid%5B%5D=327h8iw&scope%5B%5D=in' + + '&scope%5B%5D=discovered&sort%5B%5D=location&active=true' + + '&sortDir=asc&page=1&perPage=5000') def test_get_assets_non_defaults(self): """Should return a list of assets given information to query""" - self.targets.db.find_targets.return_value = [Target(codename='TURBULENTTORTOISE', slug='327h8iw')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.text = 'rettext' - self.targets.api.request.return_value.json.return_value = 'retjson' + self.targets._db.find_targets.return_value = [Target(codename='TURBULENTTORTOISE', slug='327h8iw')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.text = 'rettext' + self.targets._api.request.return_value.json.return_value = 'retjson' self.assertEqual('retjson', self.targets.get_assets(codename='TURBULENTTORTOISE', asset_type='blah', host_type='cidr', @@ -246,11 +247,11 @@ def test_get_assets_non_defaults(self): page=3, perPage=50, organization_uid='uiehqw')) - self.targets.api.request.assert_called_with('GET', - 'asset/v2/assets?listingUid%5B%5D=327h8iw' + - '&organizationUid%5B%5D=uiehqw&assetType%5B%5D=blah' + - '&hostType%5B%5D=cidr&scope%5B%5D=secret' + - '&sort%5B%5D=loc&active=false&sortDir=desc&page=3&perPage=50') + self.targets._api.request.assert_called_with('GET', + 'asset/v2/assets?listingUid%5B%5D=327h8iw' + + '&organizationUid%5B%5D=uiehqw&assetType%5B%5D=blah' + + '&hostType%5B%5D=cidr&scope%5B%5D=secret' + + '&sort%5B%5D=loc&active=false&sortDir=desc&page=3&perPage=50') def test_get_attachments_current(self): """Should return a list of attachments based on currently selected target""" @@ -265,12 +266,12 @@ def test_get_attachments_current(self): ] self.targets.get_connected = MagicMock() self.targets.get_connected.return_value = {'codename': 'TASTYTACO', 'slug': 'u2ire'} - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = attachments - self.assertEquals(self.targets.get_attachments(), attachments) - self.targets.api.request.assert_called_with('GET', 'targets/u2ire/resources') + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = attachments + self.assertEqual(self.targets.get_attachments(), attachments) + self.targets._api.request.assert_called_with('GET', 'targets/u2ire/resources') def test_get_attachments_slug(self): """Should return a list of attachments given a slug""" @@ -283,12 +284,12 @@ def test_get_attachments_slug(self): "updated_at": 1667849178, } ] - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = attachments - self.assertEquals(self.targets.get_attachments(slug='u2ire'), attachments) - self.targets.api.request.assert_called_with('GET', 'targets/u2ire/resources') + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = attachments + self.assertEqual(self.targets.get_attachments(slug='u2ire'), attachments) + self.targets._api.request.assert_called_with('GET', 'targets/u2ire/resources') def test_get_attachments_target(self): """Should return a list of attachments given a Target""" @@ -301,15 +302,15 @@ def test_get_attachments_target(self): "updated_at": 1667849178, } ] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = attachments - self.assertEquals(self.targets.get_attachments(target=Target(slug='u2ire')), attachments) - self.targets.api.request.assert_called_with('GET', 'targets/u2ire/resources') + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = attachments + self.assertEqual(self.targets.get_attachments(target=Target(slug='u2ire')), attachments) + self.targets._api.request.assert_called_with('GET', 'targets/u2ire/resources') def test_get_connected(self): """Should make a request to get the currently selected target""" - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = { + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = { "slug": "qwfars", "status": "connected" } @@ -324,8 +325,8 @@ def test_get_connected(self): def test_get_connected_disconnected(self): """Should report Not Connected when not connected to a target""" - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = { + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = { "slug": "", "status": "connected" } @@ -352,13 +353,13 @@ def test_get_connections(self): "current_connections": 5 } } - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data - self.assertEquals(self.targets.get_connections(slug='u2ire'), connections) - self.targets.api.request.assert_called_with('GET', 'listing_analytics/connections', - query={"listing_id": "u2ire"}) + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data + self.assertEqual(self.targets.get_connections(slug='u2ire'), connections) + self.targets._api.request.assert_called_with('GET', 'listing_analytics/connections', + query={"listing_id": "u2ire"}) def test_get_connections_no_args(self): """Should return a summary of the lifetime and current connections if no args provided""" @@ -374,36 +375,37 @@ def test_get_connections_no_args(self): "current_connections": 5 } } - self.targets.db.find_targets = MagicMock() + self.targets._db.find_targets = MagicMock() self.targets.get_connected = MagicMock() self.targets.get_connected.return_value = {'codename': 'TIREDTIGER', 'slug': 'u2ire'} - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data - self.assertEquals(self.targets.get_connections(), connections) + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data + self.assertEqual(self.targets.get_connections(), connections) self.targets.get_connected.assert_called_with() - self.targets.api.request.assert_called_with('GET', 'listing_analytics/connections', - query={"listing_id": "u2ire"}) + self.targets._api.request.assert_called_with('GET', 'listing_analytics/connections', + query={"listing_id": "u2ire"}) def test_get_credentials(self): """Should get credentials for a given target""" target = Target(organization="qwewqe", slug="asdasd") - self.targets.db.find_targets = MagicMock() - self.targets.api = MagicMock() - self.targets.db.find_targets.return_value = [target] - self.targets.db.user_id = 'bobby' - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = "json_return" + self.targets._db.find_targets = MagicMock() + self.targets._api = MagicMock() + self.targets._db.find_targets.return_value = [target] + self.targets._db.user_id = 'bobby' + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = "json_return" + self.targets._state.user_id = 'bobby' url = 'asset/v1/organizations/qwewqe/owners/listings/asdasd/users/bobby/credentials' self.assertEqual("json_return", self.targets.get_credentials(codename='SLEEPYSLUG')) - self.targets.api.request.assert_called_with('POST', url) + self.targets._api.request.assert_called_with('POST', url) - def test_get_query(self): + def test_get(self): """Should get a list of targets""" - self.targets.db.categories = [ + self.targets._db.categories = [ Category(id=1, passed_practical=True, passed_written=True), Category(id=2, passed_practical=True, passed_written=True), Category(id=3, passed_practical=False, passed_written=False), @@ -414,37 +416,37 @@ def test_get_query(self): 'filter[industry]': 'all', 'filter[category][]': [1, 2] } - self.targets.api.request.return_value.status_code = 200 + self.targets._api.request.return_value.status_code = 200 results = [ { "codename": "SLEEPYSLUG", "slug": "1o2h8o" } ] - self.targets.api.request.return_value.json.return_value = results + self.targets._api.request.return_value.json.return_value = results self.assertEqual(results, self.targets.get_unregistered()) - self.targets.api.request.assert_called_with("GET", - "targets", - query=query) + self.targets._api.request.assert_called_with("GET", + "targets", + query=query) - def test_get_query_assessments_empty(self): + def test_get_assessments_empty(self): """Should get a list of unregistered targets""" self.targets.get_assessments = MagicMock() - self.targets.db.categories = [] + self.targets._db.categories = [] query = { 'filter[primary]': 'unregistered', 'filter[secondary]': 'all', 'filter[industry]': 'all', 'filter[category][]': [] } - self.targets.api.request.return_value.status_code = 200 + self.targets._api.request.return_value.status_code = 200 results = [] - self.targets.api.request.return_value.json.return_value = results + self.targets._api.request.return_value.json.return_value = results self.assertEqual(results, self.targets.get_unregistered()) self.targets.get_assessments.assert_called_with() - self.targets.api.request.assert_called_with("GET", - "targets", - query=query) + self.targets._api.request.assert_called_with("GET", + "targets", + query=query) def test_get_registered_summary(self): """Should make a request to get basic info about registered targets""" @@ -461,62 +463,38 @@ def test_get_registered_summary(self): "outage_windows": [], "vulnerability_discovery": True } - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = [t1] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = [t1] out = { "qwfars": t1 } path = 'targets/registered_summary' self.assertEqual(out, self.targets.get_registered_summary()) - self.targets.api.request.assert_called_with('GET', path) + self.targets._api.request.assert_called_with('GET', path) def test_get_scope_for_host(self): """Should get the scope for a Host when given Host information""" self.targets.get_scope_host = MagicMock() self.targets.get_scope_host.return_value = 'HostScope' tgt = Target(category=1) - self.targets.db.find_targets.return_value = [tgt] - self.targets.db.categories = [Category(id=1, name='Host')] + self.targets._db.find_targets.return_value = [tgt] + self.targets._db.categories = [Category(id=1, name='Host')] out = self.targets.get_scope(slug='1392g78yr') - self.targets.db.find_targets.assert_called_with(slug='1392g78yr') - self.targets.get_scope_host.assert_called_with(tgt, add_to_db=False) - self.assertEquals(out, 'HostScope') - - def test_get_scope_for_host_add_to_db(self): - """Should get the scope for a Host when given Host information""" - self.targets.get_scope_host = MagicMock() - self.targets.get_scope_host.return_value = 'HostScope' - tgt = Target(category=1) - self.targets.db.find_targets.return_value = [tgt] - self.targets.db.categories = [Category(id=1, name='Host')] - out = self.targets.get_scope(slug='1392g78yr', add_to_db=True) - self.targets.db.find_targets.assert_called_with(slug='1392g78yr') - self.targets.get_scope_host.assert_called_with(tgt, add_to_db=True) - self.assertEquals(out, 'HostScope') + self.targets._db.find_targets.assert_called_with(slug='1392g78yr') + self.targets.get_scope_host.assert_called_with(tgt) + self.assertEqual(out, 'HostScope') def test_get_scope_for_web(self): """Should get the scope for a Host when given Web information""" self.targets.get_scope_web = MagicMock() self.targets.get_scope_web.return_value = 'WebScope' tgt = Target(category=2) - self.targets.db.find_targets.return_value = [tgt] - self.targets.db.categories = [Category(id=2, name='Web Application')] + self.targets._db.find_targets.return_value = [tgt] + self.targets._db.categories = [Category(id=2, name='Web Application')] out = self.targets.get_scope(slug='1392g78yr') - self.targets.db.find_targets.assert_called_with(slug='1392g78yr') - self.targets.get_scope_web.assert_called_with(tgt, add_to_db=False) - self.assertEquals(out, 'WebScope') - - def test_get_scope_for_web_add_to_db(self): - """Should get the scope for a Host when given Web information""" - self.targets.get_scope_web = MagicMock() - self.targets.get_scope_web.return_value = 'WebScope' - tgt = Target(category=2) - self.targets.db.find_targets.return_value = [tgt] - self.targets.db.categories = [Category(id=2, name='Web Application')] - out = self.targets.get_scope(slug='1392g78yr', add_to_db=True) - self.targets.db.find_targets.assert_called_with(slug='1392g78yr') - self.targets.get_scope_web.assert_called_with(tgt, add_to_db=True) - self.assertEquals(out, 'WebScope') + self.targets._db.find_targets.assert_called_with(slug='1392g78yr') + self.targets.get_scope_web.assert_called_with(tgt) + self.assertEqual(out, 'WebScope') def test_get_scope_host(self): """Should get the scope for a Host""" @@ -532,33 +510,10 @@ def test_get_scope_host(self): 'location': '2.2.2.2/32' } ] - self.targets.db.find_targets.return_value = [Target(slug='213h89h3', codename='SASSYSQUIRREL')] + self.targets._db.find_targets.return_value = [Target(slug='213h89h3', codename='SASSYSQUIRREL')] out = self.targets.get_scope_host(codename='SASSYSQUIRREL') self.assertEqual(ips, out) - self.targets.db.find_targets.assert_called_with(codename='SASSYSQUIRREL') - - def test_get_scope_host_add_to_db(self): - """Should get the scope for a Host""" - ips = {'1.1.1.1/32', '2.2.2.2/32'} - self.targets.get_assets = MagicMock() - self.targets.get_assets.return_value = [ - { - 'active': True, - 'location': '1.1.1.1/32' - }, - { - 'active': True, - 'location': '2.2.2.2/32' - } - ] - self.targets.build_scope_host_db = MagicMock() - self.targets.build_scope_host_db.return_value = 'host_db_return_value' - self.targets.db.find_targets.return_value = [Target(slug='213h89h3', codename='SASSYSQUIRREL')] - out = self.targets.get_scope_host(codename='SASSYSQUIRREL', add_to_db=True) - self.assertEqual(ips, out) - self.targets.db.find_targets.assert_called_with(codename='SASSYSQUIRREL') - self.targets.build_scope_host_db.assert_called_with('213h89h3', ips) - self.targets.db.add_ips.assert_called_with('host_db_return_value') + self.targets._db.find_targets.assert_called_with(codename='SASSYSQUIRREL') def test_get_scope_host_current(self): """Should get the scope for the currenly connected Host if not specified""" @@ -576,11 +531,11 @@ def test_get_scope_host_current(self): 'location': '2.2.2.2/32' } ] - self.targets.db.find_targets.return_value = [Target(slug='213h89h3', codename='SASSYSQUIRREL')] + self.targets._db.find_targets.return_value = [Target(slug='213h89h3', codename='SASSYSQUIRREL')] out = self.targets.get_scope_host() self.assertEqual(ips, out) self.targets.get_connected.assert_called_with() - self.targets.db.find_targets.assert_called_with(slug='213h89h3') + self.targets._db.find_targets.assert_called_with(slug='213h89h3') def test_get_scope_host_not_ip(self): """Should get the scope for a Host""" @@ -596,23 +551,24 @@ def test_get_scope_host_not_ip(self): 'location': '8675309' } ] - self.targets.db.find_targets.return_value = [Target(slug='213h89h3', codename='SASSYSQUIRREL')] + self.targets._db.find_targets.return_value = [Target(slug='213h89h3', codename='SASSYSQUIRREL')] out = self.targets.get_scope_host(codename='SASSYSQUIRREL') self.assertEqual(ips, out) - self.targets.db.find_targets.assert_called_with(codename='SASSYSQUIRREL') + self.targets._db.find_targets.assert_called_with(codename='SASSYSQUIRREL') def test_get_scope_no_provided(self): """Should get the scope for the currently connected target if none is specified""" self.targets.get_connected = MagicMock() self.targets.get_connected.return_value = {'slug': 'test'} - self.targets.db.find_targets.return_value = None + self.targets._db.find_targets.return_value = None self.targets.get_scope() self.targets.get_connected.assert_called_with() - self.targets.db.find_targets.assert_called_with(slug='test') + self.targets._db.find_targets.assert_called_with(slug='test') def test_get_scope_web(self): """Should get the scope for a Web Application""" self.targets.build_scope_web_burp = MagicMock() + self.targets.build_scope_web_burp.return_value = 'burp_web_scope' scope = [{ 'listing': 'uewqhuiewq', 'location': 'https://good.things.com', @@ -631,46 +587,18 @@ def test_get_scope_web(self): } ] tgt = Target(slug='213h89h3', organization='93g8eh8', codename='SASSYSQUIRREL') - self.targets.db.find_targets.return_value = [tgt] + self.targets._db.find_targets.return_value = [tgt] + self.targets._state.use_scratchspace = True out = self.targets.get_scope_web(codename='SASSYSQUIRREL') self.assertEqual(scope, out) self.targets.build_scope_web_burp.assert_called_with(scope) - self.targets.db.find_targets.assert_called_with(codename='SASSYSQUIRREL') + self.targets._db.find_targets.assert_called_with(codename='SASSYSQUIRREL') self.targets.get_assets.assert_called_with(target=tgt, active='true', asset_type='webapp') - def test_get_scope_web_add_to_db(self): - """Should get the scope for a Web Application and add it to the database""" - self.targets.build_scope_web_burp = MagicMock() - self.targets.build_scope_web_db = MagicMock() - self.targets.get_assets = MagicMock() - scope = [{ - 'listing': 'uewqhuiewq', - 'location': 'https://good.things.com', - 'rule': '*.good.things.com/*', - 'status': 'in' - }] - self.targets.get_assets = MagicMock() - self.targets.get_assets.return_value = [ - { - 'active': True, - 'listings': [{'listingUid': 'uewqhuiewq', 'scope': 'in'}], - 'location': 'https://good.things.com (https://good.things.com)', - 'scopeRules': [ - {'rule': '*.good.things.com/*'} - ] - } - ] - tgt = Target(slug='213h89h3', organization='93g8eh8', codename='SASSYSQUIRREL') - self.targets.db.find_targets.return_value = [tgt] - out = self.targets.get_scope_web(codename='SASSYSQUIRREL', add_to_db=True) - self.assertEqual(scope, out) - self.targets.build_scope_web_burp.assert_called_with(scope) - self.targets.db.find_targets.assert_called_with(codename='SASSYSQUIRREL') - self.targets.db.add_urls.assert_called_with(self.targets.build_scope_web_db.return_value) - def test_get_scope_web_current(self): """Should get the scope for the currently connected Web Application if not specified""" self.targets.build_scope_web_burp = MagicMock() + self.targets.build_scope_web_burp.return_value = 'burp_formatted_scope' scope = [{ 'listing': 'uewqhuiewq', 'location': 'https://good.things.com', @@ -691,12 +619,13 @@ def test_get_scope_web_current(self): } ] tgt = Target(slug='213h89h3', organization='93g8eh8', codename='SASSYSQUIRREL') - self.targets.db.find_targets.return_value = [tgt] + self.targets._db.find_targets.return_value = [tgt] + self.targets._state.use_scratchspace = True out = self.targets.get_scope_web() self.assertEqual(scope, out) self.targets.build_scope_web_burp.assert_called_with(scope) self.targets.get_connected.assert_called_with() - self.targets.db.find_targets.assert_called_with(slug='93g8eg8') + self.targets._db.find_targets.assert_called_with(slug='93g8eg8') self.targets.get_assets.assert_called_with(target=tgt, active='true', asset_type='webapp') def test_get_submissions(self): @@ -715,13 +644,13 @@ def test_get_submissions(self): ] }] } - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data - self.assertEquals(self.targets.get_submissions(slug='u2ire'), return_data["value"]) - self.targets.api.request.assert_called_with('GET', 'listing_analytics/categories', - query={"listing_id": "u2ire", "status": "accepted"}) + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data + self.assertEqual(self.targets.get_submissions(slug='u2ire'), return_data["value"]) + self.targets._api.request.assert_called_with('GET', 'listing_analytics/categories', + query={"listing_id": "u2ire", "status": "accepted"}) def test_get_submissions_invalid_status(self): """Should return an empty dictionary if status is invalid""" @@ -739,11 +668,11 @@ def test_get_submissions_invalid_status(self): ] }] } - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data - self.assertEquals(self.targets.get_submissions(slug='u2ire', status="bad_status"), []) + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data + self.assertEqual(self.targets.get_submissions(slug='u2ire', status="bad_status"), []) def test_get_submissions_no_slug(self): """Should return info on currently connected target if slug not provided""" @@ -761,15 +690,15 @@ def test_get_submissions_no_slug(self): ] }] } - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data self.targets.get_connected = MagicMock() self.targets.get_connected.return_value = {"slug": "u2ire"} - self.assertEquals(self.targets.get_submissions(), return_data["value"]) - self.targets.api.request.assert_called_with('GET', 'listing_analytics/categories', - query={"listing_id": "u2ire", "status": "accepted"}) + self.assertEqual(self.targets.get_submissions(), return_data["value"]) + self.targets._api.request.assert_called_with('GET', 'listing_analytics/categories', + query={"listing_id": "u2ire", "status": "accepted"}) def test_get_submissions_rejected(self): """Should return the accepted vulnerabilities for a target given a slug""" @@ -787,13 +716,13 @@ def test_get_submissions_rejected(self): ] }] } - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data - self.assertEquals(self.targets.get_submissions(status="rejected", slug='u2ire'), return_data["value"]) - self.targets.api.request.assert_called_with('GET', 'listing_analytics/categories', - query={"listing_id": "u2ire", "status": "rejected"}) + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data + self.assertEqual(self.targets.get_submissions(status="rejected", slug='u2ire'), return_data["value"]) + self.targets._api.request.assert_called_with('GET', 'listing_analytics/categories', + query={"listing_id": "u2ire", "status": "rejected"}) def test_get_submissions_summary(self): """Should return the amount of lifetime submissions given a slug""" @@ -802,13 +731,13 @@ def test_get_submissions_summary(self): "type": "submissions", "value": 35 } - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data - self.assertEquals(self.targets.get_submissions_summary(slug='u2ire'), 35) - self.targets.api.request.assert_called_with('GET', 'listing_analytics/submissions', - query={"listing_id": "u2ire"}) + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data + self.assertEqual(self.targets.get_submissions_summary(slug='u2ire'), 35) + self.targets._api.request.assert_called_with('GET', 'listing_analytics/submissions', + query={"listing_id": "u2ire"}) def test_get_submissions_summary_hours(self): """Should return the amount of submissions in the last x hours given a slug""" @@ -817,13 +746,13 @@ def test_get_submissions_summary_hours(self): "type": "submissions", "value": 5 } - self.targets.db.find_targets = MagicMock() - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data - self.assertEquals(self.targets.get_submissions_summary(hours_ago=48, slug='u2ire'), 5) - self.targets.api.request.assert_called_with('GET', 'listing_analytics/submissions', - query={"listing_id": "u2ire", "period": "48h"}) + self.targets._db.find_targets = MagicMock() + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data + self.assertEqual(self.targets.get_submissions_summary(hours_ago=48, slug='u2ire'), 5) + self.targets._api.request.assert_called_with('GET', 'listing_analytics/submissions', + query={"listing_id": "u2ire", "period": "48h"}) def test_get_submissions_summary_no_slug(self): """Should return the amount of lifetime submissions for current connected when no slug""" @@ -832,25 +761,25 @@ def test_get_submissions_summary_no_slug(self): "type": "submissions", "value": 35 } - self.targets.db.find_targets = MagicMock() + self.targets._db.find_targets = MagicMock() self.targets.get_connected = MagicMock() self.targets.get_connected.return_value = {'slug': 'u2ire'} - self.targets.db.find_targets.return_value = [Target(slug='u2ire')] - self.targets.api.request.return_value.status_code = 200 - self.targets.api.request.return_value.json.return_value = return_data - self.assertEquals(self.targets.get_submissions_summary(), 35) - self.targets.api.request.assert_called_with('GET', 'listing_analytics/submissions', - query={"listing_id": "u2ire"}) + self.targets._db.find_targets.return_value = [Target(slug='u2ire')] + self.targets._api.request.return_value.status_code = 200 + self.targets._api.request.return_value.json.return_value = return_data + self.assertEqual(self.targets.get_submissions_summary(), 35) + self.targets._api.request.assert_called_with('GET', 'listing_analytics/submissions', + query={"listing_id": "u2ire"}) def test_get_unregistered(self): """Should query for unregistered targets""" results = [ {'codename': 'SLEEPYSLUG', 'slug': '1283hi'} ] - self.targets.get_query = MagicMock() - self.targets.get_query.return_value = results - self.assertEquals(results, self.targets.get_unregistered()) - self.targets.get_query.assert_called_with(status='unregistered') + self.targets.get = MagicMock() + self.targets.get.return_value = results + self.assertEqual(results, self.targets.get_unregistered()) + self.targets.get.assert_called_with(status='unregistered') def test_get_upcoming(self): """Should query for upcoming targets""" @@ -861,41 +790,41 @@ def test_get_upcoming(self): 'sorting[field]': 'upcomingStartDate', 'sorting[direction]': 'asc' } - self.targets.get_query = MagicMock() - self.targets.get_query.return_value = results - self.assertEquals(results, self.targets.get_upcoming()) - self.targets.get_query.assert_called_with(status='upcoming', query_changes=query_changes) + self.targets.get = MagicMock() + self.targets.get.return_value = results + self.assertEqual(results, self.targets.get_upcoming()) + self.targets.get.assert_called_with(status='upcoming', query_changes=query_changes) def test_set_connected(self): """Should connect to a given target provided kwargs""" - self.targets.db.find_targets.return_value = [Target(slug='28h93iw')] - self.targets.api.request.return_value.status_code = 200 + self.targets._db.find_targets.return_value = [Target(slug='28h93iw')] + self.targets._api.request.return_value.status_code = 200 self.targets.get_connected = MagicMock() self.targets.set_connected(slug='28h93iw') - self.targets.api.request.assert_called_with('PUT', - 'launchpoint', - data={'listing_id': '28h93iw'}) + self.targets._api.request.assert_called_with('PUT', + 'launchpoint', + data={'listing_id': '28h93iw'}) self.targets.get_connected.assert_called_with() def test_set_connected_disconnect(self): """Should disconnect from target if none specified""" - self.targets.api.request.return_value.status_code = 200 + self.targets._api.request.return_value.status_code = 200 self.targets.get_connected = MagicMock() self.targets.set_connected() - self.targets.api.request.assert_called_with('PUT', - 'launchpoint', - data={'listing_id': ''}) + self.targets._api.request.assert_called_with('PUT', + 'launchpoint', + data={'listing_id': ''}) self.targets.get_connected.assert_called_with() def test_set_connected_target(self): """Should connect to a given target provided a target""" target = Target(slug='28h93iw') - self.targets.api.request.return_value.status_code = 200 + self.targets._api.request.return_value.status_code = 200 self.targets.get_connected = MagicMock() self.targets.set_connected(target) - self.targets.api.request.assert_called_with('PUT', - 'launchpoint', - data={'listing_id': '28h93iw'}) + self.targets._api.request.assert_called_with('PUT', + 'launchpoint', + data={'listing_id': '28h93iw'}) self.targets.get_connected.assert_called_with() def test_set_registered(self): @@ -920,9 +849,9 @@ def test_set_registered(self): data='{"ResearcherListing":{"terms":1}}') ] self.targets.get_unregistered.return_value = unreg - self.targets.api.request.return_value.status_code = 200 + self.targets._api.request.return_value.status_code = 200 self.assertEqual(unreg, self.targets.set_registered()) - self.targets.api.request.assert_has_calls(calls) + self.targets._api.request.assert_has_calls(calls) def test_set_registered_many(self): """Should call itself again if it has determined the page was full""" @@ -935,5 +864,5 @@ def test_set_registered_many(self): for i in range(0, 15): unreg.append(t) self.targets.get_unregistered.side_effect = [unreg, [t, t]] - self.targets.api.request.return_value.status_code = 200 + self.targets._api.request.return_value.status_code = 200 self.assertEqual(17, len(self.targets.set_registered())) diff --git a/test/test_templates.py b/test/test_templates.py index 15586bb..91cce69 100644 --- a/test/test_templates.py +++ b/test/test_templates.py @@ -19,8 +19,9 @@ class TemplatesTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.templates = synack.plugins.Templates(self.state) - self.templates.db = MagicMock() + self.templates._db = MagicMock() def test_build_filepath_from_evidences(self): """Should return path from evidences json""" @@ -35,7 +36,7 @@ def test_build_filepath_from_evidences(self): 'asset': 'web', 'title': 'Mission' } - self.templates.db.template_dir = pathlib.Path('/tmp') + self.templates._state.template_dir = pathlib.Path('/tmp') self.assertEqual('/tmp/mission/web/mission.txt', self.templates.build_filepath(mission)) @@ -54,7 +55,7 @@ def test_build_filepath_from_mission(self): ], 'title': 'Mission' } - self.templates.db.template_dir = pathlib.Path('/tmp') + self.templates._state.template_dir = pathlib.Path('/tmp') self.assertEqual('/tmp/mission/web/mission.txt', self.templates.build_filepath(mission)) @@ -75,16 +76,16 @@ def test_build_filepath_non_exist_and_generic_ok(self): } with patch('pathlib.Path.exists') as mock_exists: mock_exists.side_effect = [False, True] - self.templates.db.template_dir = pathlib.Path('/tmp') + self.templates._state.template_dir = pathlib.Path('/tmp') self.assertEqual('/tmp/mission/web/generic.txt', self.templates.build_filepath(mission, generic_ok=True)) def test_build_safe_name(self): """Should convert complex missions names to something simpler""" - self.templates.alerts = MagicMock() - self.templates.alerts.sanitize.return_value = "S!oME_RaNdOm___MISSION!" + self.templates._alerts = MagicMock() + self.templates._alerts.sanitize.return_value = "S!oME_RaNdOm___MISSION!" one = self.templates.build_safe_name("S!oME_RaNdOm___MISSION!") - self.templates.alerts.sanitize.assert_called_with("S!oME_RaNdOm___MISSION!") + self.templates._alerts.sanitize.assert_called_with("S!oME_RaNdOm___MISSION!") one_out = "s_ome_random_mission_" self.assertEqual(one_out, one) @@ -111,22 +112,22 @@ def test_build_sections(self): def test_build_text_replaced_variables(self): """Should replace variables in text given text and Target info""" - self.templates.db.find_targets = MagicMock() + self.templates._db.find_targets = MagicMock() tgts = [Target(codename='SNEAKYSASQUATCH', slug='38h24iu')] - self.templates.db.find_targets.return_value = tgts + self.templates._db.find_targets.return_value = tgts input_text = "The target is {{ TARGET_CODENAME }}" expected_output = "The target is SNEAKYSASQUATCH" - self.assertEquals(self.templates.build_replace_variables(input_text, target=tgts[0]), expected_output) + self.assertEqual(self.templates.build_replace_variables(input_text, target=tgts[0]), expected_output) def test_build_text_replaced_variables_codename(self): """Should replace variables in text given text and codename""" - self.templates.db.find_targets = MagicMock() + self.templates._db.find_targets = MagicMock() tgts = [Target(codename='SNEAKYSASQUATCH', slug='38h24iu')] - self.templates.db.find_targets.return_value = tgts + self.templates._db.find_targets.return_value = tgts input_text = "The target is {{ TARGET_CODENAME }}" expected_output = "The target is SNEAKYSASQUATCH" actual_output = self.templates.build_replace_variables(input_text, codename='SLEEPYSASQUATCH') - self.assertEquals(actual_output, expected_output) + self.assertEqual(actual_output, expected_output) def test_get_file(self): self.templates.build_filepath = MagicMock() diff --git a/test/test_transactions.py b/test/test_transactions.py index c3d7a7c..134d76e 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -19,8 +19,9 @@ class TransactionsTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.transactions = synack.plugins.Transactions(self.state) - self.transactions.api = MagicMock() + self.transactions._api = MagicMock() def test_get_balance(self): """Should get the balance of your synack account""" @@ -28,9 +29,9 @@ def test_get_balance(self): "total_balance": "10.0", "pending_payout": "0.0" }''' - self.transactions.api.request.return_value.headers = {'x-balance': bal} - self.transactions.api.request.return_value.status_code = 200 + self.transactions._api.request.return_value.headers = {'x-balance': bal} + self.transactions._api.request.return_value.status_code = 200 ret = self.transactions.get_balance() self.assertEqual(ret, json.loads(bal)) - self.transactions.api.request.assert_called_with('HEAD', - 'transactions') + self.transactions._api.request.assert_called_with('HEAD', + 'transactions') diff --git a/test/test_users.py b/test/test_users.py index 0d1d55b..7242a3a 100644 --- a/test/test_users.py +++ b/test/test_users.py @@ -17,22 +17,23 @@ class UsersTestCase(unittest.TestCase): def setUp(self): self.state = synack._state.State() + self.state._db = MagicMock() self.users = synack.plugins.Users(self.state) - self.users.api = MagicMock() - self.users.db = MagicMock() + self.users._api = MagicMock() + self.users._db = MagicMock() def test_get_profile(self): """Should get info about me""" - self.users.api.request.return_value.status_code = 200 - self.users.api.request.return_value.json.return_value = {"one": "1"} + self.users._api.request.return_value.status_code = 200 + self.users._api.request.return_value.json.return_value = {"one": "1"} self.assertEqual({"one": "1"}, self.users.get_profile()) - self.users.api.request.assert_called_with("GET", - "profiles/me") + self.users._api.request.assert_called_with("GET", + "profiles/me") def test_get_profile_other(self): """Should get info about someone else""" - self.users.api.request.return_value.status_code = 200 - self.users.api.request.return_value.json.return_value = {"one": "1"} + self.users._api.request.return_value.status_code = 200 + self.users._api.request.return_value.json.return_value = {"one": "1"} self.assertEqual({"one": "1"}, self.users.get_profile("lngvmkpj")) - self.users.api.request.assert_called_with("GET", - "profiles/lngvmkpj") + self.users._api.request.assert_called_with("GET", + "profiles/lngvmkpj")