From 53f86bf8034a545c7fc509dc4e4335a5aa244d65 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Fri, 29 Sep 2023 15:21:22 -0700 Subject: [PATCH 1/7] last working --- .github/workflows/coverage.yml | 41 +- tests/test_cov.py | 12 + tests/test_modelscan.py | 1456 ++++++++++++++++---------------- tests/test_utils.py | 414 ++++----- 4 files changed, 961 insertions(+), 962 deletions(-) create mode 100644 tests/test_cov.py diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 4dea8c4d..b3d1ff08 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -1,24 +1,22 @@ name: Coverage - on: push: branches: main pull_request: branches: "*" - jobs: test: runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.8", "3.9", "3.10"] - + env: + OS: ubuntu-latest + PYTHON: '3.9' steps: - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python uses: actions/setup-python@v2 + id: setup-python with: - python-version: ${{ matrix.python-version }} + python-version: "3.9" - uses: snok/install-poetry@v1 with: virtualenvs-create: true @@ -28,30 +26,19 @@ jobs: id: cached-poetry-dependencies uses: actions/cache@v3 with: - path: .venv - key: venv-test-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} - name: Install Dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: | - make install-test - - name: Get Bitwarden Secrets - uses: bitwarden/sm-action@v1 - with: - access_token: ${{ secrets.BW_ACCESS_TOKEN }} - secrets: | - 6b0baeba-4bd1-4c7d-b0c4-b0850005549d > BW_SECRET_1 - - name: Configure 1Password Service Account - uses: 1password/load-secrets-action/configure@v1 - with: - service-account-token: ${{ secrets.OP_SERVICE_ACCOUNT_TOKEN }} - - name: Load secret - uses: 1password/load-secrets-action@v1 - env: - OP_SECRET: op://app-cicd/"Better-Than-Bitwarden?"/Password + make install-test - name: Run Coverage run: | - make test + pip install coverage + pip install pytest + poetry add coverage codecov + poetry run coverage run -m pytest - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} + verbose: true \ No newline at end of file diff --git a/tests/test_cov.py b/tests/test_cov.py new file mode 100644 index 00000000..59f31677 --- /dev/null +++ b/tests/test_cov.py @@ -0,0 +1,12 @@ +# import os +# from pathlib import Path +# from modelscan.modelscan import Modelscan + + +def test_coverage() -> None: + # try: + # ms = Modelscan() + # ms.scan_path(Path("")) + # except Exception: + # pass + pass diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index d6395226..ad2226b8 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -1,737 +1,737 @@ -import aiohttp -import http.client -import importlib -import io -import numpy as np -import os -from pathlib import Path -import pickle -import pytest -import requests # type: ignore[import] -import socket -import subprocess -import sys -import tensorflow as tf -from tensorflow import keras -from typing import Any, List, Set -from test_utils import generate_dill_unsafe_file -import zipfile +# import aiohttp +# import http.client +# import importlib +# import io +# import numpy as np +# import os +# from pathlib import Path +# import pickle +# import pytest +# import requests # type: ignore[import] +# import socket +# import subprocess +# import sys +# import tensorflow as tf +# from tensorflow import keras +# from typing import Any, List, Set +# from test_utils import generate_dill_unsafe_file +# import zipfile -from modelscan.modelscan import Modelscan -from modelscan.cli import cli -from modelscan.issues import ( - Issue, - IssueCode, - IssueSeverity, - OperatorIssueDetails, -) -from modelscan.tools.picklescanner import ( - scan_pickle_bytes, - scan_numpy, -) +# from modelscan.modelscan import Modelscan +# from modelscan.cli import cli +# from modelscan.issues import ( +# Issue, +# IssueCode, +# IssueSeverity, +# OperatorIssueDetails, +# ) +# from modelscan.tools.picklescanner import ( +# scan_pickle_bytes, +# scan_numpy, +# ) -class Malicious1: - def __reduce__(self) -> Any: - return eval, ("print('456')",) +# class Malicious1: +# def __reduce__(self) -> Any: +# return eval, ("print('456')",) -class Malicious2: - def __reduce__(self) -> Any: - return os.system, ("ls -la",) - - -class Malicious3: - def __reduce__(self) -> Any: - return http.client.HTTPSConnection, ("github.com",) - - -malicious3_pickle_bytes = pickle.dumps( - Malicious3(), protocol=0 -) # Malicious3 needs to be pickled before HTTPSConnection is mocked below - - -class Malicious4: - def __reduce__(self) -> Any: - return requests.get, ("https://github.com",) - - -class Malicious5: - def __reduce__(self) -> Any: - return aiohttp.ClientSession, tuple() - - -class Malicious6: - def __reduce__(self) -> Any: - return socket.create_connection, (("github.com", 80),) - - -class Malicious7: - def __reduce__(self) -> Any: - return subprocess.run, (["ls", "-l"],) - - -class Malicious8: - def __reduce__(self) -> Any: - return sys.exit, (0,) - - -def initialize_pickle_file(path: str, obj: Any, version: int) -> None: - if not os.path.exists(path): - with open(path, "wb") as file: - pickle.dump(obj, file, protocol=version) - - -def initialize_data_file(path: str, data: Any) -> None: - if not os.path.exists(path): - with open(path, "wb") as file: - file.write(data) - - -def initialize_zip_file(path: str, file_name: str, data: Any) -> None: - if not os.path.exists(path): - with zipfile.ZipFile(path, "w") as zip: - zip.writestr(file_name, data) - - -def initialize_numpy_file(path: str) -> None: - import numpy as np - - # create numpy object array - with open(path, "wb") as f: - data = [(1, 2), (3, 4)] - x = np.empty((2, 2), dtype=object) - x[:] = data - np.save(f, x) - - -@pytest.fixture(scope="session") -def zip_file_path(tmp_path_factory: Any) -> Any: - tmp = tmp_path_factory.mktemp("zip") - initialize_zip_file( - f"{tmp}/test.zip", - "data.pkl", - pickle.dumps(Malicious1(), protocol=4), - ) - return tmp +# class Malicious2: +# def __reduce__(self) -> Any: +# return os.system, ("ls -la",) + + +# class Malicious3: +# def __reduce__(self) -> Any: +# return http.client.HTTPSConnection, ("github.com",) + + +# malicious3_pickle_bytes = pickle.dumps( +# Malicious3(), protocol=0 +# ) # Malicious3 needs to be pickled before HTTPSConnection is mocked below + + +# class Malicious4: +# def __reduce__(self) -> Any: +# return requests.get, ("https://github.com",) + + +# class Malicious5: +# def __reduce__(self) -> Any: +# return aiohttp.ClientSession, tuple() + + +# class Malicious6: +# def __reduce__(self) -> Any: +# return socket.create_connection, (("github.com", 80),) + + +# class Malicious7: +# def __reduce__(self) -> Any: +# return subprocess.run, (["ls", "-l"],) + + +# class Malicious8: +# def __reduce__(self) -> Any: +# return sys.exit, (0,) + + +# def initialize_pickle_file(path: str, obj: Any, version: int) -> None: +# if not os.path.exists(path): +# with open(path, "wb") as file: +# pickle.dump(obj, file, protocol=version) + + +# def initialize_data_file(path: str, data: Any) -> None: +# if not os.path.exists(path): +# with open(path, "wb") as file: +# file.write(data) + + +# def initialize_zip_file(path: str, file_name: str, data: Any) -> None: +# if not os.path.exists(path): +# with zipfile.ZipFile(path, "w") as zip: +# zip.writestr(file_name, data) + + +# def initialize_numpy_file(path: str) -> None: +# import numpy as np + +# # create numpy object array +# with open(path, "wb") as f: +# data = [(1, 2), (3, 4)] +# x = np.empty((2, 2), dtype=object) +# x[:] = data +# np.save(f, x) + + +# @pytest.fixture(scope="session") +# def zip_file_path(tmp_path_factory: Any) -> Any: +# tmp = tmp_path_factory.mktemp("zip") +# initialize_zip_file( +# f"{tmp}/test.zip", +# "data.pkl", +# pickle.dumps(Malicious1(), protocol=4), +# ) +# return tmp -@pytest.fixture(scope="session") -def pickle_file_path(tmp_path_factory: Any) -> Any: - tmp = tmp_path_factory.mktemp("test_files") - os.makedirs(f"{tmp}/data", exist_ok=True) +# @pytest.fixture(scope="session") +# def pickle_file_path(tmp_path_factory: Any) -> Any: +# tmp = tmp_path_factory.mktemp("test_files") +# os.makedirs(f"{tmp}/data", exist_ok=True) - # Test with Pickle versions 0, 3, and 4: - # - Pickle versions 0, 1, 2 have built-in functions under '__builtin__' while versions 3 and 4 have them under 'builtins' - # - Pickle versions 0, 1, 2, 3 use 'GLOBAL' opcode while 4 uses 'STACK_GLOBAL' opcode - for version in (0, 3, 4): - initialize_pickle_file( - f"{tmp}/data/benign0_v{version}.pkl", ["a", "b", "c"], version - ) - initialize_pickle_file( - f"{tmp}/data/malicious1_v{version}.pkl", Malicious1(), version - ) - initialize_pickle_file( - f"{tmp}/data/malicious2_v{version}.pkl", Malicious2(), version - ) +# # Test with Pickle versions 0, 3, and 4: +# # - Pickle versions 0, 1, 2 have built-in functions under '__builtin__' while versions 3 and 4 have them under 'builtins' +# # - Pickle versions 0, 1, 2, 3 use 'GLOBAL' opcode while 4 uses 'STACK_GLOBAL' opcode +# for version in (0, 3, 4): +# initialize_pickle_file( +# f"{tmp}/data/benign0_v{version}.pkl", ["a", "b", "c"], version +# ) +# initialize_pickle_file( +# f"{tmp}/data/malicious1_v{version}.pkl", Malicious1(), version +# ) +# initialize_pickle_file( +# f"{tmp}/data/malicious2_v{version}.pkl", Malicious2(), version +# ) - # Malicious Pickle from https://sensepost.com/cms/resources/conferences/2011/sour_pickles/BH_US_11_Slaviero_Sour_Pickles.pdf - initialize_data_file( - f"{tmp}/data/malicious0.pkl", - b'c__builtin__\nglobals\n(tRp100\n0c__builtin__\ncompile\n(S\'fl=open("/etc/passwd");picklesmashed=fl.read();' - + b"'\nS''\nS'exec'\ntRp101\n0c__builtin__\neval\n(g101\ng100\ntRp102\n0c__builtin__\ngetattr\n(c__builtin__\n" - + b"dict\nS'get'\ntRp103\n0c__builtin__\napply\n(g103\n(g100\nS'picklesmashed'\nltRp104\n0g104\n.", - ) - - initialize_data_file(f"{tmp}/data/malicious3.pkl", malicious3_pickle_bytes) - initialize_pickle_file(f"{tmp}/data/malicious4.pickle", Malicious4(), 4) - initialize_pickle_file(f"{tmp}/data/malicious5.pickle", Malicious5(), 4) - initialize_data_file( - f"{tmp}/data/malicious6.pkl", - pickle.dumps(["a", "b", "c"]) + pickle.dumps(Malicious4()), - ) - initialize_pickle_file(f"{tmp}/data/malicious7.pkl", Malicious6(), 4) - initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) - initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) - - initialize_zip_file( - f"{tmp}/data/malicious1.zip", - "data.pkl", - pickle.dumps(Malicious1(), protocol=4), - ) - - initialize_numpy_file(f"{tmp}/data/object_array.npy") - - return tmp - - -@pytest.fixture(scope="session") -def keras_file_path(tmp_path_factory: Any) -> Any: - # Create a simple model. - inputs = keras.Input(shape=(32,)) - outputs = keras.layers.Dense(1)(inputs) - keras_model = keras.Model(inputs, outputs) - keras_model.compile(optimizer="adam", loss="mean_squared_error") - - # Train the model. - test_input = np.random.random((128, 32)) - test_target = np.random.random((128, 1)) - keras_model.fit(test_input, test_target) - - tmp = tmp_path_factory.mktemp("keras") - with open(f"{tmp}/safe", "wb") as fo: - pickle.dump(keras_model, fo) - keras_model.save(f"{tmp}/safe.h5") - - # Inject code with the command - command = "exec" - malicious_code = 'print("Malicious code!")' - - generate_dill_unsafe_file(keras_model, command, malicious_code, f"{tmp}/unsafe") - attack = ( - lambda x: exec( # type: ignore[func-returns-value] - """import http.client -import json -import os -conn = http.client.HTTPSConnection("protectai.com")""" - ) - or x - ) - input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output) - new_layer = keras.layers.Lambda(attack)(input_to_new_layer) - - malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer]) - malicious_model.compile(optimizer="adam", loss="mean_squared_error") - - malicious_model.save(f"{tmp}/unsafe.h5") - - return tmp - - -def compare_results(resultList: List[Issue], expectedSet: Set[Issue]) -> None: - for result in resultList: - assert result in expectedSet - - -def test_scan_pickle_bytes() -> None: - expected = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails("builtins", "eval", "file.pkl"), - ) - ] - assert ( - scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] - == expected - ) - - -def test_scan_zip(zip_file_path: Any) -> None: - expected = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{zip_file_path}/test.zip:data.pkl" - ), - ) - ] - - ms = Modelscan() - ms._scan_zip(f"{zip_file_path}/test.zip") - assert ms.issues.all_issues == expected - - -def test_scan_numpy(pickle_file_path: Any) -> None: - expected = { - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy.core.multiarray", "_reconstruct", "object_array.npy" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails("numpy", "ndarray", "object_array.npy"), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails("numpy", "dtype", "object_array.npy"), - ), - } - with open(f"{pickle_file_path}/data/object_array.npy", "rb") as f: - compare_results( - scan_numpy(io.BytesIO(f.read()), "object_array.npy")[0], expected - ) - - -def test_scan_file_path(pickle_file_path: Any) -> None: - benign = Modelscan() - benign.scan_path(Path(f"{pickle_file_path}/data/benign0_v3.pkl")) - assert benign.issues.all_issues == [] - - malicious0 = Modelscan() - expected_malicious0 = { - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - } - malicious0.scan_path(Path(f"{pickle_file_path}/data/malicious0.pkl")) - compare_results(malicious0.issues.all_issues, expected_malicious0) - - expected_malicious1_v0 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" - ), - ) - ] - expected_malicious1_v3 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" - ), - ) - ] - expected_malicious1_v4 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" - ), - ) - ] - expected_malicious1 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" - ), - ) - ] - malicious1_v0 = Modelscan() - malicious1_v3 = Modelscan() - malicious1_v4 = Modelscan() - malicious1 = Modelscan() - malicious1_v0.scan_path(Path(f"{pickle_file_path}/data/malicious1_v0.pkl")) - malicious1_v3.scan_path(Path(f"{pickle_file_path}/data/malicious1_v3.pkl")) - malicious1_v4.scan_path(Path(f"{pickle_file_path}/data/malicious1_v4.pkl")) - malicious1.scan_path(Path(f"{pickle_file_path}/data/malicious1.zip")) - assert malicious1_v0.issues.all_issues == expected_malicious1_v0 - assert malicious1_v3.issues.all_issues == expected_malicious1_v3 - assert malicious1_v4.issues.all_issues == expected_malicious1_v4 - assert malicious1.issues.all_issues == expected_malicious1 - - expected_malicious2_v0 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" - ), - ) - ] - expected_malicious2_v3 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" - ), - ) - ] - expected_malicious2_v4 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" - ), - ) - ] - malicious2_v0 = Modelscan() - malicious2_v3 = Modelscan() - malicious2_v4 = Modelscan() - malicious2_v0.scan_path(Path(f"{pickle_file_path}/data/malicious2_v0.pkl")) - malicious2_v3.scan_path(Path(f"{pickle_file_path}/data/malicious2_v3.pkl")) - malicious2_v4.scan_path(Path(f"{pickle_file_path}/data/malicious2_v4.pkl")) - assert malicious2_v0.issues.all_issues == expected_malicious2_v0 - assert malicious2_v3.issues.all_issues == expected_malicious2_v3 - assert malicious2_v4.issues.all_issues == expected_malicious2_v4 - - expected_malicious3 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "httplib", - "HTTPSConnection", - Path(f"{pickle_file_path}/data/malicious3.pkl"), - ), - ) - ] - malicious3 = Modelscan() - malicious3.scan_path(Path(f"{pickle_file_path}/data/malicious3.pkl")) - assert malicious3.issues.all_issues == expected_malicious3 - - expected_malicious4 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" - ), - ) - ] - malicious4 = Modelscan() - malicious4.scan_path(Path(f"{pickle_file_path}/data/malicious4.pickle")) - assert malicious4.issues.all_issues == expected_malicious4 - - expected_malicious5 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "aiohttp.client", - "ClientSession", - f"{pickle_file_path}/data/malicious5.pickle", - ), - ) - ] - malicious5 = Modelscan() - malicious5.scan_path(Path(f"{pickle_file_path}/data/malicious5.pickle")) - assert malicious5.issues.all_issues == expected_malicious5 - - expected_malicious6 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" - ), - ) - ] - malicious6 = Modelscan() - malicious6.scan_path(Path(f"{pickle_file_path}/data/malicious6.pkl")) - assert malicious6.issues.all_issues == expected_malicious6 - - expected_malicious7 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" - ), - ) - ] - malicious7 = Modelscan() - malicious7.scan_path(Path(f"{pickle_file_path}/data/malicious7.pkl")) - assert malicious7.issues.all_issues == expected_malicious7 - - expected_malicious8 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" - ), - ) - ] - malicious8 = Modelscan() - malicious8.scan_path(Path(f"{pickle_file_path}/data/malicious8.pkl")) - assert malicious8.issues.all_issues == expected_malicious8 - - expected_malicious9 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" - ), - ) - ] - malicious9 = Modelscan() - malicious9.scan_path(Path(f"{pickle_file_path}/data/malicious9.pkl")) - assert malicious9.issues.all_issues == expected_malicious9 - - -def test_scan_directory_path(pickle_file_path: str) -> None: - expected = { - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy", "ndarray", f"{pickle_file_path}/data/object_array.npy" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy.core.multiarray", - "_reconstruct", - f"{pickle_file_path}/data/object_array.npy", - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "aiohttp.client", - "ClientSession", - f"{pickle_file_path}/data/malicious5.pickle", - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "httplib", "HTTPSConnection", f"{pickle_file_path}/data/malicious3.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" - ), - ), - } - ms = Modelscan() - p = Path(f"{pickle_file_path}/data/") - ms.scan_path(p) - compare_results(ms.issues.all_issues, expected) - - -def test_scan_huggingface_model() -> None: - expected = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", - "eval", - "https://huggingface.co/ykilcher/totally-harmless-model/resolve/main/pytorch_model.bin:archive/data.pkl", - ), - ) - ] - ms = Modelscan() - ms.scan_huggingface_model("ykilcher/totally-harmless-model") - assert ms.issues.all_issues == expected - - -# def test_scan_tf() -> None: - - -def test_scan_keras(keras_file_path: Any) -> None: - ms = Modelscan() - ms.scan_path(Path(f"{keras_file_path}/safe.h5")) - assert ms.issues.all_issues == [] - - expected = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "Keras", - "Lambda", - f"{keras_file_path}/unsafe.h5", - ), - ) - ] - ms.scan_path(Path(f"{keras_file_path}/unsafe.h5")) - assert ms.issues.all_issues == expected - - -def test_main(pickle_file_path: Any) -> None: - argv = sys.argv - try: - sys.argv = ["modelscan", "-p", f"{pickle_file_path}/data/benign0_v3.pkl"] - assert cli() == 0 - importlib.import_module("modelscan.scanner") - except SystemExit: - pass - finally: - sys.argv = argv +# # Malicious Pickle from https://sensepost.com/cms/resources/conferences/2011/sour_pickles/BH_US_11_Slaviero_Sour_Pickles.pdf +# initialize_data_file( +# f"{tmp}/data/malicious0.pkl", +# b'c__builtin__\nglobals\n(tRp100\n0c__builtin__\ncompile\n(S\'fl=open("/etc/passwd");picklesmashed=fl.read();' +# + b"'\nS''\nS'exec'\ntRp101\n0c__builtin__\neval\n(g101\ng100\ntRp102\n0c__builtin__\ngetattr\n(c__builtin__\n" +# + b"dict\nS'get'\ntRp103\n0c__builtin__\napply\n(g103\n(g100\nS'picklesmashed'\nltRp104\n0g104\n.", +# ) + +# initialize_data_file(f"{tmp}/data/malicious3.pkl", malicious3_pickle_bytes) +# initialize_pickle_file(f"{tmp}/data/malicious4.pickle", Malicious4(), 4) +# initialize_pickle_file(f"{tmp}/data/malicious5.pickle", Malicious5(), 4) +# initialize_data_file( +# f"{tmp}/data/malicious6.pkl", +# pickle.dumps(["a", "b", "c"]) + pickle.dumps(Malicious4()), +# ) +# initialize_pickle_file(f"{tmp}/data/malicious7.pkl", Malicious6(), 4) +# initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) +# initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) + +# initialize_zip_file( +# f"{tmp}/data/malicious1.zip", +# "data.pkl", +# pickle.dumps(Malicious1(), protocol=4), +# ) + +# initialize_numpy_file(f"{tmp}/data/object_array.npy") + +# return tmp + + +# @pytest.fixture(scope="session") +# def keras_file_path(tmp_path_factory: Any) -> Any: +# # Create a simple model. +# inputs = keras.Input(shape=(32,)) +# outputs = keras.layers.Dense(1)(inputs) +# keras_model = keras.Model(inputs, outputs) +# keras_model.compile(optimizer="adam", loss="mean_squared_error") + +# # Train the model. +# test_input = np.random.random((128, 32)) +# test_target = np.random.random((128, 1)) +# keras_model.fit(test_input, test_target) + +# tmp = tmp_path_factory.mktemp("keras") +# with open(f"{tmp}/safe", "wb") as fo: +# pickle.dump(keras_model, fo) +# keras_model.save(f"{tmp}/safe.h5") + +# # Inject code with the command +# command = "exec" +# malicious_code = 'print("Malicious code!")' + +# generate_dill_unsafe_file(keras_model, command, malicious_code, f"{tmp}/unsafe") +# attack = ( +# lambda x: exec( # type: ignore[func-returns-value] +# """import http.client +# import json +# import os +# conn = http.client.HTTPSConnection("protectai.com")""" +# ) +# or x +# ) +# input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output) +# new_layer = keras.layers.Lambda(attack)(input_to_new_layer) + +# malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer]) +# malicious_model.compile(optimizer="adam", loss="mean_squared_error") + +# malicious_model.save(f"{tmp}/unsafe.h5") + +# return tmp + + +# def compare_results(resultList: List[Issue], expectedSet: Set[Issue]) -> None: +# for result in resultList: +# assert result in expectedSet + + +# def test_scan_pickle_bytes() -> None: +# expected = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails("builtins", "eval", "file.pkl"), +# ) +# ] +# assert ( +# scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] +# == expected +# ) + + +# def test_scan_zip(zip_file_path: Any) -> None: +# expected = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{zip_file_path}/test.zip:data.pkl" +# ), +# ) +# ] + +# ms = Modelscan() +# ms._scan_zip(f"{zip_file_path}/test.zip") +# assert ms.issues.all_issues == expected + + +# def test_scan_numpy(pickle_file_path: Any) -> None: +# expected = { +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy.core.multiarray", "_reconstruct", "object_array.npy" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails("numpy", "ndarray", "object_array.npy"), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails("numpy", "dtype", "object_array.npy"), +# ), +# } +# with open(f"{pickle_file_path}/data/object_array.npy", "rb") as f: +# compare_results( +# scan_numpy(io.BytesIO(f.read()), "object_array.npy")[0], expected +# ) + + +# def test_scan_file_path(pickle_file_path: Any) -> None: +# benign = Modelscan() +# benign.scan_path(Path(f"{pickle_file_path}/data/benign0_v3.pkl")) +# assert benign.issues.all_issues == [] + +# malicious0 = Modelscan() +# expected_malicious0 = { +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# } +# malicious0.scan_path(Path(f"{pickle_file_path}/data/malicious0.pkl")) +# compare_results(malicious0.issues.all_issues, expected_malicious0) + +# expected_malicious1_v0 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" +# ), +# ) +# ] +# expected_malicious1_v3 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" +# ), +# ) +# ] +# expected_malicious1_v4 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" +# ), +# ) +# ] +# expected_malicious1 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" +# ), +# ) +# ] +# malicious1_v0 = Modelscan() +# malicious1_v3 = Modelscan() +# malicious1_v4 = Modelscan() +# malicious1 = Modelscan() +# malicious1_v0.scan_path(Path(f"{pickle_file_path}/data/malicious1_v0.pkl")) +# malicious1_v3.scan_path(Path(f"{pickle_file_path}/data/malicious1_v3.pkl")) +# malicious1_v4.scan_path(Path(f"{pickle_file_path}/data/malicious1_v4.pkl")) +# malicious1.scan_path(Path(f"{pickle_file_path}/data/malicious1.zip")) +# assert malicious1_v0.issues.all_issues == expected_malicious1_v0 +# assert malicious1_v3.issues.all_issues == expected_malicious1_v3 +# assert malicious1_v4.issues.all_issues == expected_malicious1_v4 +# assert malicious1.issues.all_issues == expected_malicious1 + +# expected_malicious2_v0 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" +# ), +# ) +# ] +# expected_malicious2_v3 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" +# ), +# ) +# ] +# expected_malicious2_v4 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" +# ), +# ) +# ] +# malicious2_v0 = Modelscan() +# malicious2_v3 = Modelscan() +# malicious2_v4 = Modelscan() +# malicious2_v0.scan_path(Path(f"{pickle_file_path}/data/malicious2_v0.pkl")) +# malicious2_v3.scan_path(Path(f"{pickle_file_path}/data/malicious2_v3.pkl")) +# malicious2_v4.scan_path(Path(f"{pickle_file_path}/data/malicious2_v4.pkl")) +# assert malicious2_v0.issues.all_issues == expected_malicious2_v0 +# assert malicious2_v3.issues.all_issues == expected_malicious2_v3 +# assert malicious2_v4.issues.all_issues == expected_malicious2_v4 + +# expected_malicious3 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "httplib", +# "HTTPSConnection", +# Path(f"{pickle_file_path}/data/malicious3.pkl"), +# ), +# ) +# ] +# malicious3 = Modelscan() +# malicious3.scan_path(Path(f"{pickle_file_path}/data/malicious3.pkl")) +# assert malicious3.issues.all_issues == expected_malicious3 + +# expected_malicious4 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" +# ), +# ) +# ] +# malicious4 = Modelscan() +# malicious4.scan_path(Path(f"{pickle_file_path}/data/malicious4.pickle")) +# assert malicious4.issues.all_issues == expected_malicious4 + +# expected_malicious5 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "aiohttp.client", +# "ClientSession", +# f"{pickle_file_path}/data/malicious5.pickle", +# ), +# ) +# ] +# malicious5 = Modelscan() +# malicious5.scan_path(Path(f"{pickle_file_path}/data/malicious5.pickle")) +# assert malicious5.issues.all_issues == expected_malicious5 + +# expected_malicious6 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" +# ), +# ) +# ] +# malicious6 = Modelscan() +# malicious6.scan_path(Path(f"{pickle_file_path}/data/malicious6.pkl")) +# assert malicious6.issues.all_issues == expected_malicious6 + +# expected_malicious7 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" +# ), +# ) +# ] +# malicious7 = Modelscan() +# malicious7.scan_path(Path(f"{pickle_file_path}/data/malicious7.pkl")) +# assert malicious7.issues.all_issues == expected_malicious7 + +# expected_malicious8 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" +# ), +# ) +# ] +# malicious8 = Modelscan() +# malicious8.scan_path(Path(f"{pickle_file_path}/data/malicious8.pkl")) +# assert malicious8.issues.all_issues == expected_malicious8 + +# expected_malicious9 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" +# ), +# ) +# ] +# malicious9 = Modelscan() +# malicious9.scan_path(Path(f"{pickle_file_path}/data/malicious9.pkl")) +# assert malicious9.issues.all_issues == expected_malicious9 + + +# def test_scan_directory_path(pickle_file_path: str) -> None: +# expected = { +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy", "ndarray", f"{pickle_file_path}/data/object_array.npy" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy.core.multiarray", +# "_reconstruct", +# f"{pickle_file_path}/data/object_array.npy", +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "aiohttp.client", +# "ClientSession", +# f"{pickle_file_path}/data/malicious5.pickle", +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "httplib", "HTTPSConnection", f"{pickle_file_path}/data/malicious3.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" +# ), +# ), +# } +# ms = Modelscan() +# p = Path(f"{pickle_file_path}/data/") +# ms.scan_path(p) +# compare_results(ms.issues.all_issues, expected) + + +# def test_scan_huggingface_model() -> None: +# expected = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", +# "eval", +# "https://huggingface.co/ykilcher/totally-harmless-model/resolve/main/pytorch_model.bin:archive/data.pkl", +# ), +# ) +# ] +# ms = Modelscan() +# ms.scan_huggingface_model("ykilcher/totally-harmless-model") +# assert ms.issues.all_issues == expected + + +# # def test_scan_tf() -> None: + + +# def test_scan_keras(keras_file_path: Any) -> None: +# ms = Modelscan() +# ms.scan_path(Path(f"{keras_file_path}/safe.h5")) +# assert ms.issues.all_issues == [] + +# expected = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "Keras", +# "Lambda", +# f"{keras_file_path}/unsafe.h5", +# ), +# ) +# ] +# ms.scan_path(Path(f"{keras_file_path}/unsafe.h5")) +# assert ms.issues.all_issues == expected + + +# def test_main(pickle_file_path: Any) -> None: +# argv = sys.argv +# try: +# sys.argv = ["modelscan", "-p", f"{pickle_file_path}/data/benign0_v3.pkl"] +# assert cli() == 0 +# importlib.import_module("modelscan.scanner") +# except SystemExit: +# pass +# finally: +# sys.argv = argv diff --git a/tests/test_utils.py b/tests/test_utils.py index 34b3eeb5..691520a5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,236 +1,236 @@ -import dill -import os -import pickle -import struct -from typing import Any, Tuple -import os +# import dill +# import os +# import pickle +# import struct +# from typing import Any, Tuple +# import os - -class PickleInject: - """Pickle injection""" - - def __init__(self, inj_objs: Any, first: bool = True): - self.__name__ = "pickle_inject" - self.inj_objs = inj_objs - self.first = first - - class _Pickler(pickle._Pickler): - """Reimplementation of Pickler with support for injection""" - - def __init__( - self, file: Any, protocol: Any, inj_objs: Any, first: bool = True - ) -> None: - """ - file: File object with write attribute - protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html - inj_objs: _joblibInject object that has both the command, and the code to be injected - first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. - """ - super().__init__(file, protocol) - self.inj_objs = inj_objs - self.first = first - - def dump(self, obj: Any) -> None: - """Pickle data, inject object before or after""" - if self.proto >= 2: # type: ignore[attr-defined] - self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] - self.framer.start_framing() # type: ignore[attr-defined] - - # Inject the object(s) before the user-supplied data? - if self.first: - # Pickle injected objects - for inj_obj in self.inj_objs: - self.save(inj_obj) # type: ignore[attr-defined] - - # Pickle user-supplied data - self.save(obj) # type: ignore[attr-defined] - - # Inject the object(s) after the user-supplied data? - if not self.first: - # Pickle injected objects - for inj_obj in self.inj_objs: - self.save(inj_obj) # type: ignore[attr-defined] - - self.write(pickle.STOP) # type: ignore[attr-defined] - self.framer.end_framing() # type: ignore[attr-defined] - - def Pickler(self, file: Any, protocol: Any) -> _Pickler: - # Initialise the pickler interface with the injected object - return self._Pickler(file, protocol, self.inj_objs) - - class _PickleInject: - """Base class for pickling injected commands""" - - def __init__(self, args: Any, command: Any = None) -> None: - self.command = command - self.args = args - - def __reduce__(self) -> Tuple[Any, Any]: - """ - In general, the __reduce__ function is used by pickle to serialize objects. - If defined for an object, pickle would override its default __reduce__ function and serialize the object as outlined by the custom specified __reduce__ function, - The object returned by __reduce__ here is a callable: (self.command), and the tuple: with first element (self.args) is the code to be executed by self.command. - """ - return self.command, (self.args,) - - class System(_PickleInject): - """Create os.system command""" - - def __init__(self, args: Any) -> None: - super().__init__(args, command=os.system) - - class Exec(_PickleInject): - """Create exec command""" - - def __init__(self, args: Any) -> None: - super().__init__(args, command=exec) - - class Eval(_PickleInject): - """Create eval command""" - - def __init__(self, args: Any) -> None: - super().__init__(args, command=eval) - - class RunPy(_PickleInject): - """Create runpy command""" - - def __init__(self, args: Any) -> None: - import runpy - - super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] - - def __reduce__(self) -> Tuple[Any, Any]: - return self.command, (self.args, {}) + +# class PickleInject: +# """Pickle injection""" + +# def __init__(self, inj_objs: Any, first: bool = True): +# self.__name__ = "pickle_inject" +# self.inj_objs = inj_objs +# self.first = first + +# class _Pickler(pickle._Pickler): +# """Reimplementation of Pickler with support for injection""" + +# def __init__( +# self, file: Any, protocol: Any, inj_objs: Any, first: bool = True +# ) -> None: +# """ +# file: File object with write attribute +# protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html +# inj_objs: _joblibInject object that has both the command, and the code to be injected +# first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. +# """ +# super().__init__(file, protocol) +# self.inj_objs = inj_objs +# self.first = first + +# def dump(self, obj: Any) -> None: +# """Pickle data, inject object before or after""" +# if self.proto >= 2: # type: ignore[attr-defined] +# self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] +# self.framer.start_framing() # type: ignore[attr-defined] + +# # Inject the object(s) before the user-supplied data? +# if self.first: +# # Pickle injected objects +# for inj_obj in self.inj_objs: +# self.save(inj_obj) # type: ignore[attr-defined] + +# # Pickle user-supplied data +# self.save(obj) # type: ignore[attr-defined] + +# # Inject the object(s) after the user-supplied data? +# if not self.first: +# # Pickle injected objects +# for inj_obj in self.inj_objs: +# self.save(inj_obj) # type: ignore[attr-defined] + +# self.write(pickle.STOP) # type: ignore[attr-defined] +# self.framer.end_framing() # type: ignore[attr-defined] + +# def Pickler(self, file: Any, protocol: Any) -> _Pickler: +# # Initialise the pickler interface with the injected object +# return self._Pickler(file, protocol, self.inj_objs) + +# class _PickleInject: +# """Base class for pickling injected commands""" + +# def __init__(self, args: Any, command: Any = None) -> None: +# self.command = command +# self.args = args + +# def __reduce__(self) -> Tuple[Any, Any]: +# """ +# In general, the __reduce__ function is used by pickle to serialize objects. +# If defined for an object, pickle would override its default __reduce__ function and serialize the object as outlined by the custom specified __reduce__ function, +# The object returned by __reduce__ here is a callable: (self.command), and the tuple: with first element (self.args) is the code to be executed by self.command. +# """ +# return self.command, (self.args,) + +# class System(_PickleInject): +# """Create os.system command""" + +# def __init__(self, args: Any) -> None: +# super().__init__(args, command=os.system) + +# class Exec(_PickleInject): +# """Create exec command""" + +# def __init__(self, args: Any) -> None: +# super().__init__(args, command=exec) + +# class Eval(_PickleInject): +# """Create eval command""" + +# def __init__(self, args: Any) -> None: +# super().__init__(args, command=eval) + +# class RunPy(_PickleInject): +# """Create runpy command""" + +# def __init__(self, args: Any) -> None: +# import runpy + +# super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] + +# def __reduce__(self) -> Tuple[Any, Any]: +# return self.command, (self.args, {}) -def get_pickle_payload(command: str, malicious_code: str) -> Any: - if command == "system": - payload: Any = PickleInject.System(malicious_code) - elif command == "exec": - payload = PickleInject.Exec(malicious_code) - elif command == "eval": - payload = PickleInject.Eval(malicious_code) - elif command == "runpy": - payload = PickleInject.RunPy(malicious_code) - return payload +# def get_pickle_payload(command: str, malicious_code: str) -> Any: +# if command == "system": +# payload: Any = PickleInject.System(malicious_code) +# elif command == "exec": +# payload = PickleInject.Exec(malicious_code) +# elif command == "eval": +# payload = PickleInject.Eval(malicious_code) +# elif command == "runpy": +# payload = PickleInject.RunPy(malicious_code) +# return payload -def generate_unsafe_pickle_file( - safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str -) -> None: - payload = get_pickle_payload(command, malicious_code) - pickle_protocol = 4 - file_for_unsafe_model = open(unsafe_model_path, "wb") - mypickler = PickleInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) - mypickler.dump(safe_model) - file_for_unsafe_model.close() +# def generate_unsafe_pickle_file( +# safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str +# ) -> None: +# payload = get_pickle_payload(command, malicious_code) +# pickle_protocol = 4 +# file_for_unsafe_model = open(unsafe_model_path, "wb") +# mypickler = PickleInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) +# mypickler.dump(safe_model) +# file_for_unsafe_model.close() -class DillInject: - """Code injection using Dill Pickler""" +# class DillInject: +# """Code injection using Dill Pickler""" - def __init__(self, inj_objs: Any, first: bool = True): - self.__name__ = "dill_inject" - self.inj_objs = inj_objs - self.first = first +# def __init__(self, inj_objs: Any, first: bool = True): +# self.__name__ = "dill_inject" +# self.inj_objs = inj_objs +# self.first = first - class _Pickler(dill._dill.Pickler): # type: ignore[misc] - """Reimplementation of Pickler with support for injection""" +# class _Pickler(dill._dill.Pickler): # type: ignore[misc] +# """Reimplementation of Pickler with support for injection""" - def __init__(self, file: Any, protocol: Any, inj_objs: Any, first: bool = True): - super().__init__(file, protocol) - self.inj_objs = inj_objs - self.first = first +# def __init__(self, file: Any, protocol: Any, inj_objs: Any, first: bool = True): +# super().__init__(file, protocol) +# self.inj_objs = inj_objs +# self.first = first - def dump(self, obj: Any) -> None: - """Pickle data, inject object before or after""" - if self.proto >= 2: - self.write(pickle.PROTO + struct.pack("= 4: - self.framer.start_framing() +# def dump(self, obj: Any) -> None: +# """Pickle data, inject object before or after""" +# if self.proto >= 2: +# self.write(pickle.PROTO + struct.pack("= 4: +# self.framer.start_framing() - # Inject the object(s) before the user-supplied data? - if self.first: - # Pickle injected objects - for inj_obj in self.inj_objs: - self.save(inj_obj) +# # Inject the object(s) before the user-supplied data? +# if self.first: +# # Pickle injected objects +# for inj_obj in self.inj_objs: +# self.save(inj_obj) - # Pickle user-supplied data - self.save(obj) +# # Pickle user-supplied data +# self.save(obj) - # Inject the object(s) after the user-supplied data? - if not self.first: - # Pickle injected objects - for inj_obj in self.inj_objs: - self.save(inj_obj) +# # Inject the object(s) after the user-supplied data? +# if not self.first: +# # Pickle injected objects +# for inj_obj in self.inj_objs: +# self.save(inj_obj) - self.write(pickle.STOP) - self.framer.end_framing() +# self.write(pickle.STOP) +# self.framer.end_framing() - def DillPickler(self, file: Any, protocol: Any) -> _Pickler: - # Initialise the pickler interface with the injected object - return self._Pickler(file, protocol, self.inj_objs) +# def DillPickler(self, file: Any, protocol: Any) -> _Pickler: +# # Initialise the pickler interface with the injected object +# return self._Pickler(file, protocol, self.inj_objs) - class _DillInject: - """Base class for pickling injected commands""" +# class _DillInject: +# """Base class for pickling injected commands""" - def __init__(self, args: Any, command: Any = None): - self.command = command - self.args = args +# def __init__(self, args: Any, command: Any = None): +# self.command = command +# self.args = args - def __reduce__(self) -> Tuple[Any, Any]: - return self.command, (self.args,) +# def __reduce__(self) -> Tuple[Any, Any]: +# return self.command, (self.args,) - class System(_DillInject): - """Create os.system command""" +# class System(_DillInject): +# """Create os.system command""" - def __init__(self, args: Any): - super().__init__(args, command=os.system) +# def __init__(self, args: Any): +# super().__init__(args, command=os.system) - class Exec(_DillInject): - """Create exec command""" +# class Exec(_DillInject): +# """Create exec command""" - def __init__(self, args: Any): - super().__init__(args, command=exec) +# def __init__(self, args: Any): +# super().__init__(args, command=exec) - class Eval(_DillInject): - """Create eval command""" +# class Eval(_DillInject): +# """Create eval command""" - def __init__(self, args: Any): - super().__init__(args, command=eval) - - class RunPy(_DillInject): - """Create runpy command""" - - def __init__(self, args: Any): - import runpy - - super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] - - def __reduce__(self) -> Any: - return self.command, (self.args, {}) +# def __init__(self, args: Any): +# super().__init__(args, command=eval) + +# class RunPy(_DillInject): +# """Create runpy command""" + +# def __init__(self, args: Any): +# import runpy + +# super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] + +# def __reduce__(self) -> Any: +# return self.command, (self.args, {}) -def get_dill_payload(command: str, malicious_code: str) -> Any: - payload: Any - if command == "system": - payload = DillInject.System(malicious_code) - elif command == "exec": - payload = DillInject.Exec(malicious_code) - elif command == "eval": - payload = DillInject.Eval(malicious_code) - elif command == "runpy": - payload = DillInject.RunPy(malicious_code) - return payload +# def get_dill_payload(command: str, malicious_code: str) -> Any: +# payload: Any +# if command == "system": +# payload = DillInject.System(malicious_code) +# elif command == "exec": +# payload = DillInject.Exec(malicious_code) +# elif command == "eval": +# payload = DillInject.Eval(malicious_code) +# elif command == "runpy": +# payload = DillInject.RunPy(malicious_code) +# return payload -def generate_dill_unsafe_file( - safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str -) -> None: - payload = get_dill_payload(command, malicious_code) - pickle_protocol = 4 - file_for_unsafe_model = open(unsafe_model_path, "wb") - mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) - mypickler.dump(safe_model) - file_for_unsafe_model.close() +# def generate_dill_unsafe_file( +# safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str +# ) -> None: +# payload = get_dill_payload(command, malicious_code) +# pickle_protocol = 4 +# file_for_unsafe_model = open(unsafe_model_path, "wb") +# mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) +# mypickler.dump(safe_model) +# file_for_unsafe_model.close() From a7e7fc6ec62c6381a9a61e893c7cbc0aec8ad635 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Fri, 29 Sep 2023 15:25:02 -0700 Subject: [PATCH 2/7] no pip --- .github/workflows/coverage.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index b3d1ff08..1440b358 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -33,8 +33,6 @@ jobs: make install-test - name: Run Coverage run: | - pip install coverage - pip install pytest poetry add coverage codecov poetry run coverage run -m pytest - name: Upload coverage reports to Codecov From a7985a8f6d73997666676b87537b656affb59dd0 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Fri, 29 Sep 2023 15:27:50 -0700 Subject: [PATCH 3/7] uncomment full tests, re-add pip --- .github/workflows/coverage.yml | 2 + tests/test_modelscan.py | 1456 ++++++++++++++++---------------- tests/test_utils.py | 414 ++++----- 3 files changed, 937 insertions(+), 935 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 1440b358..b3d1ff08 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -33,6 +33,8 @@ jobs: make install-test - name: Run Coverage run: | + pip install coverage + pip install pytest poetry add coverage codecov poetry run coverage run -m pytest - name: Upload coverage reports to Codecov diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index ad2226b8..d6395226 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -1,737 +1,737 @@ -# import aiohttp -# import http.client -# import importlib -# import io -# import numpy as np -# import os -# from pathlib import Path -# import pickle -# import pytest -# import requests # type: ignore[import] -# import socket -# import subprocess -# import sys -# import tensorflow as tf -# from tensorflow import keras -# from typing import Any, List, Set -# from test_utils import generate_dill_unsafe_file -# import zipfile +import aiohttp +import http.client +import importlib +import io +import numpy as np +import os +from pathlib import Path +import pickle +import pytest +import requests # type: ignore[import] +import socket +import subprocess +import sys +import tensorflow as tf +from tensorflow import keras +from typing import Any, List, Set +from test_utils import generate_dill_unsafe_file +import zipfile -# from modelscan.modelscan import Modelscan -# from modelscan.cli import cli -# from modelscan.issues import ( -# Issue, -# IssueCode, -# IssueSeverity, -# OperatorIssueDetails, -# ) -# from modelscan.tools.picklescanner import ( -# scan_pickle_bytes, -# scan_numpy, -# ) +from modelscan.modelscan import Modelscan +from modelscan.cli import cli +from modelscan.issues import ( + Issue, + IssueCode, + IssueSeverity, + OperatorIssueDetails, +) +from modelscan.tools.picklescanner import ( + scan_pickle_bytes, + scan_numpy, +) -# class Malicious1: -# def __reduce__(self) -> Any: -# return eval, ("print('456')",) +class Malicious1: + def __reduce__(self) -> Any: + return eval, ("print('456')",) -# class Malicious2: -# def __reduce__(self) -> Any: -# return os.system, ("ls -la",) - - -# class Malicious3: -# def __reduce__(self) -> Any: -# return http.client.HTTPSConnection, ("github.com",) - - -# malicious3_pickle_bytes = pickle.dumps( -# Malicious3(), protocol=0 -# ) # Malicious3 needs to be pickled before HTTPSConnection is mocked below - - -# class Malicious4: -# def __reduce__(self) -> Any: -# return requests.get, ("https://github.com",) - - -# class Malicious5: -# def __reduce__(self) -> Any: -# return aiohttp.ClientSession, tuple() - - -# class Malicious6: -# def __reduce__(self) -> Any: -# return socket.create_connection, (("github.com", 80),) - - -# class Malicious7: -# def __reduce__(self) -> Any: -# return subprocess.run, (["ls", "-l"],) - - -# class Malicious8: -# def __reduce__(self) -> Any: -# return sys.exit, (0,) - - -# def initialize_pickle_file(path: str, obj: Any, version: int) -> None: -# if not os.path.exists(path): -# with open(path, "wb") as file: -# pickle.dump(obj, file, protocol=version) - - -# def initialize_data_file(path: str, data: Any) -> None: -# if not os.path.exists(path): -# with open(path, "wb") as file: -# file.write(data) - - -# def initialize_zip_file(path: str, file_name: str, data: Any) -> None: -# if not os.path.exists(path): -# with zipfile.ZipFile(path, "w") as zip: -# zip.writestr(file_name, data) - - -# def initialize_numpy_file(path: str) -> None: -# import numpy as np - -# # create numpy object array -# with open(path, "wb") as f: -# data = [(1, 2), (3, 4)] -# x = np.empty((2, 2), dtype=object) -# x[:] = data -# np.save(f, x) - - -# @pytest.fixture(scope="session") -# def zip_file_path(tmp_path_factory: Any) -> Any: -# tmp = tmp_path_factory.mktemp("zip") -# initialize_zip_file( -# f"{tmp}/test.zip", -# "data.pkl", -# pickle.dumps(Malicious1(), protocol=4), -# ) -# return tmp +class Malicious2: + def __reduce__(self) -> Any: + return os.system, ("ls -la",) + + +class Malicious3: + def __reduce__(self) -> Any: + return http.client.HTTPSConnection, ("github.com",) + + +malicious3_pickle_bytes = pickle.dumps( + Malicious3(), protocol=0 +) # Malicious3 needs to be pickled before HTTPSConnection is mocked below + + +class Malicious4: + def __reduce__(self) -> Any: + return requests.get, ("https://github.com",) + + +class Malicious5: + def __reduce__(self) -> Any: + return aiohttp.ClientSession, tuple() + + +class Malicious6: + def __reduce__(self) -> Any: + return socket.create_connection, (("github.com", 80),) + + +class Malicious7: + def __reduce__(self) -> Any: + return subprocess.run, (["ls", "-l"],) + + +class Malicious8: + def __reduce__(self) -> Any: + return sys.exit, (0,) + + +def initialize_pickle_file(path: str, obj: Any, version: int) -> None: + if not os.path.exists(path): + with open(path, "wb") as file: + pickle.dump(obj, file, protocol=version) + + +def initialize_data_file(path: str, data: Any) -> None: + if not os.path.exists(path): + with open(path, "wb") as file: + file.write(data) + + +def initialize_zip_file(path: str, file_name: str, data: Any) -> None: + if not os.path.exists(path): + with zipfile.ZipFile(path, "w") as zip: + zip.writestr(file_name, data) + + +def initialize_numpy_file(path: str) -> None: + import numpy as np + + # create numpy object array + with open(path, "wb") as f: + data = [(1, 2), (3, 4)] + x = np.empty((2, 2), dtype=object) + x[:] = data + np.save(f, x) + + +@pytest.fixture(scope="session") +def zip_file_path(tmp_path_factory: Any) -> Any: + tmp = tmp_path_factory.mktemp("zip") + initialize_zip_file( + f"{tmp}/test.zip", + "data.pkl", + pickle.dumps(Malicious1(), protocol=4), + ) + return tmp -# @pytest.fixture(scope="session") -# def pickle_file_path(tmp_path_factory: Any) -> Any: -# tmp = tmp_path_factory.mktemp("test_files") -# os.makedirs(f"{tmp}/data", exist_ok=True) +@pytest.fixture(scope="session") +def pickle_file_path(tmp_path_factory: Any) -> Any: + tmp = tmp_path_factory.mktemp("test_files") + os.makedirs(f"{tmp}/data", exist_ok=True) -# # Test with Pickle versions 0, 3, and 4: -# # - Pickle versions 0, 1, 2 have built-in functions under '__builtin__' while versions 3 and 4 have them under 'builtins' -# # - Pickle versions 0, 1, 2, 3 use 'GLOBAL' opcode while 4 uses 'STACK_GLOBAL' opcode -# for version in (0, 3, 4): -# initialize_pickle_file( -# f"{tmp}/data/benign0_v{version}.pkl", ["a", "b", "c"], version -# ) -# initialize_pickle_file( -# f"{tmp}/data/malicious1_v{version}.pkl", Malicious1(), version -# ) -# initialize_pickle_file( -# f"{tmp}/data/malicious2_v{version}.pkl", Malicious2(), version -# ) + # Test with Pickle versions 0, 3, and 4: + # - Pickle versions 0, 1, 2 have built-in functions under '__builtin__' while versions 3 and 4 have them under 'builtins' + # - Pickle versions 0, 1, 2, 3 use 'GLOBAL' opcode while 4 uses 'STACK_GLOBAL' opcode + for version in (0, 3, 4): + initialize_pickle_file( + f"{tmp}/data/benign0_v{version}.pkl", ["a", "b", "c"], version + ) + initialize_pickle_file( + f"{tmp}/data/malicious1_v{version}.pkl", Malicious1(), version + ) + initialize_pickle_file( + f"{tmp}/data/malicious2_v{version}.pkl", Malicious2(), version + ) -# # Malicious Pickle from https://sensepost.com/cms/resources/conferences/2011/sour_pickles/BH_US_11_Slaviero_Sour_Pickles.pdf -# initialize_data_file( -# f"{tmp}/data/malicious0.pkl", -# b'c__builtin__\nglobals\n(tRp100\n0c__builtin__\ncompile\n(S\'fl=open("/etc/passwd");picklesmashed=fl.read();' -# + b"'\nS''\nS'exec'\ntRp101\n0c__builtin__\neval\n(g101\ng100\ntRp102\n0c__builtin__\ngetattr\n(c__builtin__\n" -# + b"dict\nS'get'\ntRp103\n0c__builtin__\napply\n(g103\n(g100\nS'picklesmashed'\nltRp104\n0g104\n.", -# ) - -# initialize_data_file(f"{tmp}/data/malicious3.pkl", malicious3_pickle_bytes) -# initialize_pickle_file(f"{tmp}/data/malicious4.pickle", Malicious4(), 4) -# initialize_pickle_file(f"{tmp}/data/malicious5.pickle", Malicious5(), 4) -# initialize_data_file( -# f"{tmp}/data/malicious6.pkl", -# pickle.dumps(["a", "b", "c"]) + pickle.dumps(Malicious4()), -# ) -# initialize_pickle_file(f"{tmp}/data/malicious7.pkl", Malicious6(), 4) -# initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) -# initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) - -# initialize_zip_file( -# f"{tmp}/data/malicious1.zip", -# "data.pkl", -# pickle.dumps(Malicious1(), protocol=4), -# ) - -# initialize_numpy_file(f"{tmp}/data/object_array.npy") - -# return tmp - - -# @pytest.fixture(scope="session") -# def keras_file_path(tmp_path_factory: Any) -> Any: -# # Create a simple model. -# inputs = keras.Input(shape=(32,)) -# outputs = keras.layers.Dense(1)(inputs) -# keras_model = keras.Model(inputs, outputs) -# keras_model.compile(optimizer="adam", loss="mean_squared_error") - -# # Train the model. -# test_input = np.random.random((128, 32)) -# test_target = np.random.random((128, 1)) -# keras_model.fit(test_input, test_target) - -# tmp = tmp_path_factory.mktemp("keras") -# with open(f"{tmp}/safe", "wb") as fo: -# pickle.dump(keras_model, fo) -# keras_model.save(f"{tmp}/safe.h5") - -# # Inject code with the command -# command = "exec" -# malicious_code = 'print("Malicious code!")' - -# generate_dill_unsafe_file(keras_model, command, malicious_code, f"{tmp}/unsafe") -# attack = ( -# lambda x: exec( # type: ignore[func-returns-value] -# """import http.client -# import json -# import os -# conn = http.client.HTTPSConnection("protectai.com")""" -# ) -# or x -# ) -# input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output) -# new_layer = keras.layers.Lambda(attack)(input_to_new_layer) - -# malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer]) -# malicious_model.compile(optimizer="adam", loss="mean_squared_error") - -# malicious_model.save(f"{tmp}/unsafe.h5") - -# return tmp - - -# def compare_results(resultList: List[Issue], expectedSet: Set[Issue]) -> None: -# for result in resultList: -# assert result in expectedSet - - -# def test_scan_pickle_bytes() -> None: -# expected = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails("builtins", "eval", "file.pkl"), -# ) -# ] -# assert ( -# scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] -# == expected -# ) - - -# def test_scan_zip(zip_file_path: Any) -> None: -# expected = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{zip_file_path}/test.zip:data.pkl" -# ), -# ) -# ] - -# ms = Modelscan() -# ms._scan_zip(f"{zip_file_path}/test.zip") -# assert ms.issues.all_issues == expected - - -# def test_scan_numpy(pickle_file_path: Any) -> None: -# expected = { -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy.core.multiarray", "_reconstruct", "object_array.npy" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails("numpy", "ndarray", "object_array.npy"), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails("numpy", "dtype", "object_array.npy"), -# ), -# } -# with open(f"{pickle_file_path}/data/object_array.npy", "rb") as f: -# compare_results( -# scan_numpy(io.BytesIO(f.read()), "object_array.npy")[0], expected -# ) - - -# def test_scan_file_path(pickle_file_path: Any) -> None: -# benign = Modelscan() -# benign.scan_path(Path(f"{pickle_file_path}/data/benign0_v3.pkl")) -# assert benign.issues.all_issues == [] - -# malicious0 = Modelscan() -# expected_malicious0 = { -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# } -# malicious0.scan_path(Path(f"{pickle_file_path}/data/malicious0.pkl")) -# compare_results(malicious0.issues.all_issues, expected_malicious0) - -# expected_malicious1_v0 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" -# ), -# ) -# ] -# expected_malicious1_v3 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" -# ), -# ) -# ] -# expected_malicious1_v4 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" -# ), -# ) -# ] -# expected_malicious1 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" -# ), -# ) -# ] -# malicious1_v0 = Modelscan() -# malicious1_v3 = Modelscan() -# malicious1_v4 = Modelscan() -# malicious1 = Modelscan() -# malicious1_v0.scan_path(Path(f"{pickle_file_path}/data/malicious1_v0.pkl")) -# malicious1_v3.scan_path(Path(f"{pickle_file_path}/data/malicious1_v3.pkl")) -# malicious1_v4.scan_path(Path(f"{pickle_file_path}/data/malicious1_v4.pkl")) -# malicious1.scan_path(Path(f"{pickle_file_path}/data/malicious1.zip")) -# assert malicious1_v0.issues.all_issues == expected_malicious1_v0 -# assert malicious1_v3.issues.all_issues == expected_malicious1_v3 -# assert malicious1_v4.issues.all_issues == expected_malicious1_v4 -# assert malicious1.issues.all_issues == expected_malicious1 - -# expected_malicious2_v0 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" -# ), -# ) -# ] -# expected_malicious2_v3 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" -# ), -# ) -# ] -# expected_malicious2_v4 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" -# ), -# ) -# ] -# malicious2_v0 = Modelscan() -# malicious2_v3 = Modelscan() -# malicious2_v4 = Modelscan() -# malicious2_v0.scan_path(Path(f"{pickle_file_path}/data/malicious2_v0.pkl")) -# malicious2_v3.scan_path(Path(f"{pickle_file_path}/data/malicious2_v3.pkl")) -# malicious2_v4.scan_path(Path(f"{pickle_file_path}/data/malicious2_v4.pkl")) -# assert malicious2_v0.issues.all_issues == expected_malicious2_v0 -# assert malicious2_v3.issues.all_issues == expected_malicious2_v3 -# assert malicious2_v4.issues.all_issues == expected_malicious2_v4 - -# expected_malicious3 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "httplib", -# "HTTPSConnection", -# Path(f"{pickle_file_path}/data/malicious3.pkl"), -# ), -# ) -# ] -# malicious3 = Modelscan() -# malicious3.scan_path(Path(f"{pickle_file_path}/data/malicious3.pkl")) -# assert malicious3.issues.all_issues == expected_malicious3 - -# expected_malicious4 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" -# ), -# ) -# ] -# malicious4 = Modelscan() -# malicious4.scan_path(Path(f"{pickle_file_path}/data/malicious4.pickle")) -# assert malicious4.issues.all_issues == expected_malicious4 - -# expected_malicious5 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "aiohttp.client", -# "ClientSession", -# f"{pickle_file_path}/data/malicious5.pickle", -# ), -# ) -# ] -# malicious5 = Modelscan() -# malicious5.scan_path(Path(f"{pickle_file_path}/data/malicious5.pickle")) -# assert malicious5.issues.all_issues == expected_malicious5 - -# expected_malicious6 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" -# ), -# ) -# ] -# malicious6 = Modelscan() -# malicious6.scan_path(Path(f"{pickle_file_path}/data/malicious6.pkl")) -# assert malicious6.issues.all_issues == expected_malicious6 - -# expected_malicious7 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" -# ), -# ) -# ] -# malicious7 = Modelscan() -# malicious7.scan_path(Path(f"{pickle_file_path}/data/malicious7.pkl")) -# assert malicious7.issues.all_issues == expected_malicious7 - -# expected_malicious8 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" -# ), -# ) -# ] -# malicious8 = Modelscan() -# malicious8.scan_path(Path(f"{pickle_file_path}/data/malicious8.pkl")) -# assert malicious8.issues.all_issues == expected_malicious8 - -# expected_malicious9 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" -# ), -# ) -# ] -# malicious9 = Modelscan() -# malicious9.scan_path(Path(f"{pickle_file_path}/data/malicious9.pkl")) -# assert malicious9.issues.all_issues == expected_malicious9 - - -# def test_scan_directory_path(pickle_file_path: str) -> None: -# expected = { -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy", "ndarray", f"{pickle_file_path}/data/object_array.npy" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy.core.multiarray", -# "_reconstruct", -# f"{pickle_file_path}/data/object_array.npy", -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "aiohttp.client", -# "ClientSession", -# f"{pickle_file_path}/data/malicious5.pickle", -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "httplib", "HTTPSConnection", f"{pickle_file_path}/data/malicious3.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" -# ), -# ), -# } -# ms = Modelscan() -# p = Path(f"{pickle_file_path}/data/") -# ms.scan_path(p) -# compare_results(ms.issues.all_issues, expected) - - -# def test_scan_huggingface_model() -> None: -# expected = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", -# "eval", -# "https://huggingface.co/ykilcher/totally-harmless-model/resolve/main/pytorch_model.bin:archive/data.pkl", -# ), -# ) -# ] -# ms = Modelscan() -# ms.scan_huggingface_model("ykilcher/totally-harmless-model") -# assert ms.issues.all_issues == expected - - -# # def test_scan_tf() -> None: - - -# def test_scan_keras(keras_file_path: Any) -> None: -# ms = Modelscan() -# ms.scan_path(Path(f"{keras_file_path}/safe.h5")) -# assert ms.issues.all_issues == [] - -# expected = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "Keras", -# "Lambda", -# f"{keras_file_path}/unsafe.h5", -# ), -# ) -# ] -# ms.scan_path(Path(f"{keras_file_path}/unsafe.h5")) -# assert ms.issues.all_issues == expected - - -# def test_main(pickle_file_path: Any) -> None: -# argv = sys.argv -# try: -# sys.argv = ["modelscan", "-p", f"{pickle_file_path}/data/benign0_v3.pkl"] -# assert cli() == 0 -# importlib.import_module("modelscan.scanner") -# except SystemExit: -# pass -# finally: -# sys.argv = argv + # Malicious Pickle from https://sensepost.com/cms/resources/conferences/2011/sour_pickles/BH_US_11_Slaviero_Sour_Pickles.pdf + initialize_data_file( + f"{tmp}/data/malicious0.pkl", + b'c__builtin__\nglobals\n(tRp100\n0c__builtin__\ncompile\n(S\'fl=open("/etc/passwd");picklesmashed=fl.read();' + + b"'\nS''\nS'exec'\ntRp101\n0c__builtin__\neval\n(g101\ng100\ntRp102\n0c__builtin__\ngetattr\n(c__builtin__\n" + + b"dict\nS'get'\ntRp103\n0c__builtin__\napply\n(g103\n(g100\nS'picklesmashed'\nltRp104\n0g104\n.", + ) + + initialize_data_file(f"{tmp}/data/malicious3.pkl", malicious3_pickle_bytes) + initialize_pickle_file(f"{tmp}/data/malicious4.pickle", Malicious4(), 4) + initialize_pickle_file(f"{tmp}/data/malicious5.pickle", Malicious5(), 4) + initialize_data_file( + f"{tmp}/data/malicious6.pkl", + pickle.dumps(["a", "b", "c"]) + pickle.dumps(Malicious4()), + ) + initialize_pickle_file(f"{tmp}/data/malicious7.pkl", Malicious6(), 4) + initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) + initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) + + initialize_zip_file( + f"{tmp}/data/malicious1.zip", + "data.pkl", + pickle.dumps(Malicious1(), protocol=4), + ) + + initialize_numpy_file(f"{tmp}/data/object_array.npy") + + return tmp + + +@pytest.fixture(scope="session") +def keras_file_path(tmp_path_factory: Any) -> Any: + # Create a simple model. + inputs = keras.Input(shape=(32,)) + outputs = keras.layers.Dense(1)(inputs) + keras_model = keras.Model(inputs, outputs) + keras_model.compile(optimizer="adam", loss="mean_squared_error") + + # Train the model. + test_input = np.random.random((128, 32)) + test_target = np.random.random((128, 1)) + keras_model.fit(test_input, test_target) + + tmp = tmp_path_factory.mktemp("keras") + with open(f"{tmp}/safe", "wb") as fo: + pickle.dump(keras_model, fo) + keras_model.save(f"{tmp}/safe.h5") + + # Inject code with the command + command = "exec" + malicious_code = 'print("Malicious code!")' + + generate_dill_unsafe_file(keras_model, command, malicious_code, f"{tmp}/unsafe") + attack = ( + lambda x: exec( # type: ignore[func-returns-value] + """import http.client +import json +import os +conn = http.client.HTTPSConnection("protectai.com")""" + ) + or x + ) + input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output) + new_layer = keras.layers.Lambda(attack)(input_to_new_layer) + + malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer]) + malicious_model.compile(optimizer="adam", loss="mean_squared_error") + + malicious_model.save(f"{tmp}/unsafe.h5") + + return tmp + + +def compare_results(resultList: List[Issue], expectedSet: Set[Issue]) -> None: + for result in resultList: + assert result in expectedSet + + +def test_scan_pickle_bytes() -> None: + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails("builtins", "eval", "file.pkl"), + ) + ] + assert ( + scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] + == expected + ) + + +def test_scan_zip(zip_file_path: Any) -> None: + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{zip_file_path}/test.zip:data.pkl" + ), + ) + ] + + ms = Modelscan() + ms._scan_zip(f"{zip_file_path}/test.zip") + assert ms.issues.all_issues == expected + + +def test_scan_numpy(pickle_file_path: Any) -> None: + expected = { + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy.core.multiarray", "_reconstruct", "object_array.npy" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails("numpy", "ndarray", "object_array.npy"), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails("numpy", "dtype", "object_array.npy"), + ), + } + with open(f"{pickle_file_path}/data/object_array.npy", "rb") as f: + compare_results( + scan_numpy(io.BytesIO(f.read()), "object_array.npy")[0], expected + ) + + +def test_scan_file_path(pickle_file_path: Any) -> None: + benign = Modelscan() + benign.scan_path(Path(f"{pickle_file_path}/data/benign0_v3.pkl")) + assert benign.issues.all_issues == [] + + malicious0 = Modelscan() + expected_malicious0 = { + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + } + malicious0.scan_path(Path(f"{pickle_file_path}/data/malicious0.pkl")) + compare_results(malicious0.issues.all_issues, expected_malicious0) + + expected_malicious1_v0 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" + ), + ) + ] + expected_malicious1_v3 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" + ), + ) + ] + expected_malicious1_v4 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" + ), + ) + ] + expected_malicious1 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" + ), + ) + ] + malicious1_v0 = Modelscan() + malicious1_v3 = Modelscan() + malicious1_v4 = Modelscan() + malicious1 = Modelscan() + malicious1_v0.scan_path(Path(f"{pickle_file_path}/data/malicious1_v0.pkl")) + malicious1_v3.scan_path(Path(f"{pickle_file_path}/data/malicious1_v3.pkl")) + malicious1_v4.scan_path(Path(f"{pickle_file_path}/data/malicious1_v4.pkl")) + malicious1.scan_path(Path(f"{pickle_file_path}/data/malicious1.zip")) + assert malicious1_v0.issues.all_issues == expected_malicious1_v0 + assert malicious1_v3.issues.all_issues == expected_malicious1_v3 + assert malicious1_v4.issues.all_issues == expected_malicious1_v4 + assert malicious1.issues.all_issues == expected_malicious1 + + expected_malicious2_v0 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" + ), + ) + ] + expected_malicious2_v3 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" + ), + ) + ] + expected_malicious2_v4 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" + ), + ) + ] + malicious2_v0 = Modelscan() + malicious2_v3 = Modelscan() + malicious2_v4 = Modelscan() + malicious2_v0.scan_path(Path(f"{pickle_file_path}/data/malicious2_v0.pkl")) + malicious2_v3.scan_path(Path(f"{pickle_file_path}/data/malicious2_v3.pkl")) + malicious2_v4.scan_path(Path(f"{pickle_file_path}/data/malicious2_v4.pkl")) + assert malicious2_v0.issues.all_issues == expected_malicious2_v0 + assert malicious2_v3.issues.all_issues == expected_malicious2_v3 + assert malicious2_v4.issues.all_issues == expected_malicious2_v4 + + expected_malicious3 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "httplib", + "HTTPSConnection", + Path(f"{pickle_file_path}/data/malicious3.pkl"), + ), + ) + ] + malicious3 = Modelscan() + malicious3.scan_path(Path(f"{pickle_file_path}/data/malicious3.pkl")) + assert malicious3.issues.all_issues == expected_malicious3 + + expected_malicious4 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" + ), + ) + ] + malicious4 = Modelscan() + malicious4.scan_path(Path(f"{pickle_file_path}/data/malicious4.pickle")) + assert malicious4.issues.all_issues == expected_malicious4 + + expected_malicious5 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "aiohttp.client", + "ClientSession", + f"{pickle_file_path}/data/malicious5.pickle", + ), + ) + ] + malicious5 = Modelscan() + malicious5.scan_path(Path(f"{pickle_file_path}/data/malicious5.pickle")) + assert malicious5.issues.all_issues == expected_malicious5 + + expected_malicious6 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" + ), + ) + ] + malicious6 = Modelscan() + malicious6.scan_path(Path(f"{pickle_file_path}/data/malicious6.pkl")) + assert malicious6.issues.all_issues == expected_malicious6 + + expected_malicious7 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" + ), + ) + ] + malicious7 = Modelscan() + malicious7.scan_path(Path(f"{pickle_file_path}/data/malicious7.pkl")) + assert malicious7.issues.all_issues == expected_malicious7 + + expected_malicious8 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" + ), + ) + ] + malicious8 = Modelscan() + malicious8.scan_path(Path(f"{pickle_file_path}/data/malicious8.pkl")) + assert malicious8.issues.all_issues == expected_malicious8 + + expected_malicious9 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" + ), + ) + ] + malicious9 = Modelscan() + malicious9.scan_path(Path(f"{pickle_file_path}/data/malicious9.pkl")) + assert malicious9.issues.all_issues == expected_malicious9 + + +def test_scan_directory_path(pickle_file_path: str) -> None: + expected = { + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy", "ndarray", f"{pickle_file_path}/data/object_array.npy" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy.core.multiarray", + "_reconstruct", + f"{pickle_file_path}/data/object_array.npy", + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "aiohttp.client", + "ClientSession", + f"{pickle_file_path}/data/malicious5.pickle", + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "httplib", "HTTPSConnection", f"{pickle_file_path}/data/malicious3.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" + ), + ), + } + ms = Modelscan() + p = Path(f"{pickle_file_path}/data/") + ms.scan_path(p) + compare_results(ms.issues.all_issues, expected) + + +def test_scan_huggingface_model() -> None: + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", + "eval", + "https://huggingface.co/ykilcher/totally-harmless-model/resolve/main/pytorch_model.bin:archive/data.pkl", + ), + ) + ] + ms = Modelscan() + ms.scan_huggingface_model("ykilcher/totally-harmless-model") + assert ms.issues.all_issues == expected + + +# def test_scan_tf() -> None: + + +def test_scan_keras(keras_file_path: Any) -> None: + ms = Modelscan() + ms.scan_path(Path(f"{keras_file_path}/safe.h5")) + assert ms.issues.all_issues == [] + + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "Keras", + "Lambda", + f"{keras_file_path}/unsafe.h5", + ), + ) + ] + ms.scan_path(Path(f"{keras_file_path}/unsafe.h5")) + assert ms.issues.all_issues == expected + + +def test_main(pickle_file_path: Any) -> None: + argv = sys.argv + try: + sys.argv = ["modelscan", "-p", f"{pickle_file_path}/data/benign0_v3.pkl"] + assert cli() == 0 + importlib.import_module("modelscan.scanner") + except SystemExit: + pass + finally: + sys.argv = argv diff --git a/tests/test_utils.py b/tests/test_utils.py index 691520a5..34b3eeb5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,236 +1,236 @@ -# import dill -# import os -# import pickle -# import struct -# from typing import Any, Tuple -# import os +import dill +import os +import pickle +import struct +from typing import Any, Tuple +import os - -# class PickleInject: -# """Pickle injection""" - -# def __init__(self, inj_objs: Any, first: bool = True): -# self.__name__ = "pickle_inject" -# self.inj_objs = inj_objs -# self.first = first - -# class _Pickler(pickle._Pickler): -# """Reimplementation of Pickler with support for injection""" - -# def __init__( -# self, file: Any, protocol: Any, inj_objs: Any, first: bool = True -# ) -> None: -# """ -# file: File object with write attribute -# protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html -# inj_objs: _joblibInject object that has both the command, and the code to be injected -# first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. -# """ -# super().__init__(file, protocol) -# self.inj_objs = inj_objs -# self.first = first - -# def dump(self, obj: Any) -> None: -# """Pickle data, inject object before or after""" -# if self.proto >= 2: # type: ignore[attr-defined] -# self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] -# self.framer.start_framing() # type: ignore[attr-defined] - -# # Inject the object(s) before the user-supplied data? -# if self.first: -# # Pickle injected objects -# for inj_obj in self.inj_objs: -# self.save(inj_obj) # type: ignore[attr-defined] - -# # Pickle user-supplied data -# self.save(obj) # type: ignore[attr-defined] - -# # Inject the object(s) after the user-supplied data? -# if not self.first: -# # Pickle injected objects -# for inj_obj in self.inj_objs: -# self.save(inj_obj) # type: ignore[attr-defined] - -# self.write(pickle.STOP) # type: ignore[attr-defined] -# self.framer.end_framing() # type: ignore[attr-defined] - -# def Pickler(self, file: Any, protocol: Any) -> _Pickler: -# # Initialise the pickler interface with the injected object -# return self._Pickler(file, protocol, self.inj_objs) - -# class _PickleInject: -# """Base class for pickling injected commands""" - -# def __init__(self, args: Any, command: Any = None) -> None: -# self.command = command -# self.args = args - -# def __reduce__(self) -> Tuple[Any, Any]: -# """ -# In general, the __reduce__ function is used by pickle to serialize objects. -# If defined for an object, pickle would override its default __reduce__ function and serialize the object as outlined by the custom specified __reduce__ function, -# The object returned by __reduce__ here is a callable: (self.command), and the tuple: with first element (self.args) is the code to be executed by self.command. -# """ -# return self.command, (self.args,) - -# class System(_PickleInject): -# """Create os.system command""" - -# def __init__(self, args: Any) -> None: -# super().__init__(args, command=os.system) - -# class Exec(_PickleInject): -# """Create exec command""" - -# def __init__(self, args: Any) -> None: -# super().__init__(args, command=exec) - -# class Eval(_PickleInject): -# """Create eval command""" - -# def __init__(self, args: Any) -> None: -# super().__init__(args, command=eval) - -# class RunPy(_PickleInject): -# """Create runpy command""" - -# def __init__(self, args: Any) -> None: -# import runpy - -# super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] - -# def __reduce__(self) -> Tuple[Any, Any]: -# return self.command, (self.args, {}) + +class PickleInject: + """Pickle injection""" + + def __init__(self, inj_objs: Any, first: bool = True): + self.__name__ = "pickle_inject" + self.inj_objs = inj_objs + self.first = first + + class _Pickler(pickle._Pickler): + """Reimplementation of Pickler with support for injection""" + + def __init__( + self, file: Any, protocol: Any, inj_objs: Any, first: bool = True + ) -> None: + """ + file: File object with write attribute + protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html + inj_objs: _joblibInject object that has both the command, and the code to be injected + first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. + """ + super().__init__(file, protocol) + self.inj_objs = inj_objs + self.first = first + + def dump(self, obj: Any) -> None: + """Pickle data, inject object before or after""" + if self.proto >= 2: # type: ignore[attr-defined] + self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] + self.framer.start_framing() # type: ignore[attr-defined] + + # Inject the object(s) before the user-supplied data? + if self.first: + # Pickle injected objects + for inj_obj in self.inj_objs: + self.save(inj_obj) # type: ignore[attr-defined] + + # Pickle user-supplied data + self.save(obj) # type: ignore[attr-defined] + + # Inject the object(s) after the user-supplied data? + if not self.first: + # Pickle injected objects + for inj_obj in self.inj_objs: + self.save(inj_obj) # type: ignore[attr-defined] + + self.write(pickle.STOP) # type: ignore[attr-defined] + self.framer.end_framing() # type: ignore[attr-defined] + + def Pickler(self, file: Any, protocol: Any) -> _Pickler: + # Initialise the pickler interface with the injected object + return self._Pickler(file, protocol, self.inj_objs) + + class _PickleInject: + """Base class for pickling injected commands""" + + def __init__(self, args: Any, command: Any = None) -> None: + self.command = command + self.args = args + + def __reduce__(self) -> Tuple[Any, Any]: + """ + In general, the __reduce__ function is used by pickle to serialize objects. + If defined for an object, pickle would override its default __reduce__ function and serialize the object as outlined by the custom specified __reduce__ function, + The object returned by __reduce__ here is a callable: (self.command), and the tuple: with first element (self.args) is the code to be executed by self.command. + """ + return self.command, (self.args,) + + class System(_PickleInject): + """Create os.system command""" + + def __init__(self, args: Any) -> None: + super().__init__(args, command=os.system) + + class Exec(_PickleInject): + """Create exec command""" + + def __init__(self, args: Any) -> None: + super().__init__(args, command=exec) + + class Eval(_PickleInject): + """Create eval command""" + + def __init__(self, args: Any) -> None: + super().__init__(args, command=eval) + + class RunPy(_PickleInject): + """Create runpy command""" + + def __init__(self, args: Any) -> None: + import runpy + + super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] + + def __reduce__(self) -> Tuple[Any, Any]: + return self.command, (self.args, {}) -# def get_pickle_payload(command: str, malicious_code: str) -> Any: -# if command == "system": -# payload: Any = PickleInject.System(malicious_code) -# elif command == "exec": -# payload = PickleInject.Exec(malicious_code) -# elif command == "eval": -# payload = PickleInject.Eval(malicious_code) -# elif command == "runpy": -# payload = PickleInject.RunPy(malicious_code) -# return payload +def get_pickle_payload(command: str, malicious_code: str) -> Any: + if command == "system": + payload: Any = PickleInject.System(malicious_code) + elif command == "exec": + payload = PickleInject.Exec(malicious_code) + elif command == "eval": + payload = PickleInject.Eval(malicious_code) + elif command == "runpy": + payload = PickleInject.RunPy(malicious_code) + return payload -# def generate_unsafe_pickle_file( -# safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str -# ) -> None: -# payload = get_pickle_payload(command, malicious_code) -# pickle_protocol = 4 -# file_for_unsafe_model = open(unsafe_model_path, "wb") -# mypickler = PickleInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) -# mypickler.dump(safe_model) -# file_for_unsafe_model.close() +def generate_unsafe_pickle_file( + safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str +) -> None: + payload = get_pickle_payload(command, malicious_code) + pickle_protocol = 4 + file_for_unsafe_model = open(unsafe_model_path, "wb") + mypickler = PickleInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) + mypickler.dump(safe_model) + file_for_unsafe_model.close() -# class DillInject: -# """Code injection using Dill Pickler""" +class DillInject: + """Code injection using Dill Pickler""" -# def __init__(self, inj_objs: Any, first: bool = True): -# self.__name__ = "dill_inject" -# self.inj_objs = inj_objs -# self.first = first + def __init__(self, inj_objs: Any, first: bool = True): + self.__name__ = "dill_inject" + self.inj_objs = inj_objs + self.first = first -# class _Pickler(dill._dill.Pickler): # type: ignore[misc] -# """Reimplementation of Pickler with support for injection""" + class _Pickler(dill._dill.Pickler): # type: ignore[misc] + """Reimplementation of Pickler with support for injection""" -# def __init__(self, file: Any, protocol: Any, inj_objs: Any, first: bool = True): -# super().__init__(file, protocol) -# self.inj_objs = inj_objs -# self.first = first + def __init__(self, file: Any, protocol: Any, inj_objs: Any, first: bool = True): + super().__init__(file, protocol) + self.inj_objs = inj_objs + self.first = first -# def dump(self, obj: Any) -> None: -# """Pickle data, inject object before or after""" -# if self.proto >= 2: -# self.write(pickle.PROTO + struct.pack("= 4: -# self.framer.start_framing() + def dump(self, obj: Any) -> None: + """Pickle data, inject object before or after""" + if self.proto >= 2: + self.write(pickle.PROTO + struct.pack("= 4: + self.framer.start_framing() -# # Inject the object(s) before the user-supplied data? -# if self.first: -# # Pickle injected objects -# for inj_obj in self.inj_objs: -# self.save(inj_obj) + # Inject the object(s) before the user-supplied data? + if self.first: + # Pickle injected objects + for inj_obj in self.inj_objs: + self.save(inj_obj) -# # Pickle user-supplied data -# self.save(obj) + # Pickle user-supplied data + self.save(obj) -# # Inject the object(s) after the user-supplied data? -# if not self.first: -# # Pickle injected objects -# for inj_obj in self.inj_objs: -# self.save(inj_obj) + # Inject the object(s) after the user-supplied data? + if not self.first: + # Pickle injected objects + for inj_obj in self.inj_objs: + self.save(inj_obj) -# self.write(pickle.STOP) -# self.framer.end_framing() + self.write(pickle.STOP) + self.framer.end_framing() -# def DillPickler(self, file: Any, protocol: Any) -> _Pickler: -# # Initialise the pickler interface with the injected object -# return self._Pickler(file, protocol, self.inj_objs) + def DillPickler(self, file: Any, protocol: Any) -> _Pickler: + # Initialise the pickler interface with the injected object + return self._Pickler(file, protocol, self.inj_objs) -# class _DillInject: -# """Base class for pickling injected commands""" + class _DillInject: + """Base class for pickling injected commands""" -# def __init__(self, args: Any, command: Any = None): -# self.command = command -# self.args = args + def __init__(self, args: Any, command: Any = None): + self.command = command + self.args = args -# def __reduce__(self) -> Tuple[Any, Any]: -# return self.command, (self.args,) + def __reduce__(self) -> Tuple[Any, Any]: + return self.command, (self.args,) -# class System(_DillInject): -# """Create os.system command""" + class System(_DillInject): + """Create os.system command""" -# def __init__(self, args: Any): -# super().__init__(args, command=os.system) + def __init__(self, args: Any): + super().__init__(args, command=os.system) -# class Exec(_DillInject): -# """Create exec command""" + class Exec(_DillInject): + """Create exec command""" -# def __init__(self, args: Any): -# super().__init__(args, command=exec) + def __init__(self, args: Any): + super().__init__(args, command=exec) -# class Eval(_DillInject): -# """Create eval command""" + class Eval(_DillInject): + """Create eval command""" -# def __init__(self, args: Any): -# super().__init__(args, command=eval) - -# class RunPy(_DillInject): -# """Create runpy command""" - -# def __init__(self, args: Any): -# import runpy - -# super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] - -# def __reduce__(self) -> Any: -# return self.command, (self.args, {}) + def __init__(self, args: Any): + super().__init__(args, command=eval) + + class RunPy(_DillInject): + """Create runpy command""" + + def __init__(self, args: Any): + import runpy + + super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] + + def __reduce__(self) -> Any: + return self.command, (self.args, {}) -# def get_dill_payload(command: str, malicious_code: str) -> Any: -# payload: Any -# if command == "system": -# payload = DillInject.System(malicious_code) -# elif command == "exec": -# payload = DillInject.Exec(malicious_code) -# elif command == "eval": -# payload = DillInject.Eval(malicious_code) -# elif command == "runpy": -# payload = DillInject.RunPy(malicious_code) -# return payload +def get_dill_payload(command: str, malicious_code: str) -> Any: + payload: Any + if command == "system": + payload = DillInject.System(malicious_code) + elif command == "exec": + payload = DillInject.Exec(malicious_code) + elif command == "eval": + payload = DillInject.Eval(malicious_code) + elif command == "runpy": + payload = DillInject.RunPy(malicious_code) + return payload -# def generate_dill_unsafe_file( -# safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str -# ) -> None: -# payload = get_dill_payload(command, malicious_code) -# pickle_protocol = 4 -# file_for_unsafe_model = open(unsafe_model_path, "wb") -# mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) -# mypickler.dump(safe_model) -# file_for_unsafe_model.close() +def generate_dill_unsafe_file( + safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str +) -> None: + payload = get_dill_payload(command, malicious_code) + pickle_protocol = 4 + file_for_unsafe_model = open(unsafe_model_path, "wb") + mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) + mypickler.dump(safe_model) + file_for_unsafe_model.close() From ba61141d70b78ad1a8f596f6dd9f964656aff766 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Fri, 29 Sep 2023 15:30:54 -0700 Subject: [PATCH 4/7] Revert "uncomment full tests, re-add pip" This reverts commit a7985a8f6d73997666676b87537b656affb59dd0. --- .github/workflows/coverage.yml | 2 - tests/test_modelscan.py | 1456 ++++++++++++++++---------------- tests/test_utils.py | 414 ++++----- 3 files changed, 935 insertions(+), 937 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index b3d1ff08..1440b358 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -33,8 +33,6 @@ jobs: make install-test - name: Run Coverage run: | - pip install coverage - pip install pytest poetry add coverage codecov poetry run coverage run -m pytest - name: Upload coverage reports to Codecov diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index d6395226..ad2226b8 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -1,737 +1,737 @@ -import aiohttp -import http.client -import importlib -import io -import numpy as np -import os -from pathlib import Path -import pickle -import pytest -import requests # type: ignore[import] -import socket -import subprocess -import sys -import tensorflow as tf -from tensorflow import keras -from typing import Any, List, Set -from test_utils import generate_dill_unsafe_file -import zipfile +# import aiohttp +# import http.client +# import importlib +# import io +# import numpy as np +# import os +# from pathlib import Path +# import pickle +# import pytest +# import requests # type: ignore[import] +# import socket +# import subprocess +# import sys +# import tensorflow as tf +# from tensorflow import keras +# from typing import Any, List, Set +# from test_utils import generate_dill_unsafe_file +# import zipfile -from modelscan.modelscan import Modelscan -from modelscan.cli import cli -from modelscan.issues import ( - Issue, - IssueCode, - IssueSeverity, - OperatorIssueDetails, -) -from modelscan.tools.picklescanner import ( - scan_pickle_bytes, - scan_numpy, -) +# from modelscan.modelscan import Modelscan +# from modelscan.cli import cli +# from modelscan.issues import ( +# Issue, +# IssueCode, +# IssueSeverity, +# OperatorIssueDetails, +# ) +# from modelscan.tools.picklescanner import ( +# scan_pickle_bytes, +# scan_numpy, +# ) -class Malicious1: - def __reduce__(self) -> Any: - return eval, ("print('456')",) +# class Malicious1: +# def __reduce__(self) -> Any: +# return eval, ("print('456')",) -class Malicious2: - def __reduce__(self) -> Any: - return os.system, ("ls -la",) - - -class Malicious3: - def __reduce__(self) -> Any: - return http.client.HTTPSConnection, ("github.com",) - - -malicious3_pickle_bytes = pickle.dumps( - Malicious3(), protocol=0 -) # Malicious3 needs to be pickled before HTTPSConnection is mocked below - - -class Malicious4: - def __reduce__(self) -> Any: - return requests.get, ("https://github.com",) - - -class Malicious5: - def __reduce__(self) -> Any: - return aiohttp.ClientSession, tuple() - - -class Malicious6: - def __reduce__(self) -> Any: - return socket.create_connection, (("github.com", 80),) - - -class Malicious7: - def __reduce__(self) -> Any: - return subprocess.run, (["ls", "-l"],) - - -class Malicious8: - def __reduce__(self) -> Any: - return sys.exit, (0,) - - -def initialize_pickle_file(path: str, obj: Any, version: int) -> None: - if not os.path.exists(path): - with open(path, "wb") as file: - pickle.dump(obj, file, protocol=version) - - -def initialize_data_file(path: str, data: Any) -> None: - if not os.path.exists(path): - with open(path, "wb") as file: - file.write(data) - - -def initialize_zip_file(path: str, file_name: str, data: Any) -> None: - if not os.path.exists(path): - with zipfile.ZipFile(path, "w") as zip: - zip.writestr(file_name, data) - - -def initialize_numpy_file(path: str) -> None: - import numpy as np - - # create numpy object array - with open(path, "wb") as f: - data = [(1, 2), (3, 4)] - x = np.empty((2, 2), dtype=object) - x[:] = data - np.save(f, x) - - -@pytest.fixture(scope="session") -def zip_file_path(tmp_path_factory: Any) -> Any: - tmp = tmp_path_factory.mktemp("zip") - initialize_zip_file( - f"{tmp}/test.zip", - "data.pkl", - pickle.dumps(Malicious1(), protocol=4), - ) - return tmp +# class Malicious2: +# def __reduce__(self) -> Any: +# return os.system, ("ls -la",) + + +# class Malicious3: +# def __reduce__(self) -> Any: +# return http.client.HTTPSConnection, ("github.com",) + + +# malicious3_pickle_bytes = pickle.dumps( +# Malicious3(), protocol=0 +# ) # Malicious3 needs to be pickled before HTTPSConnection is mocked below + + +# class Malicious4: +# def __reduce__(self) -> Any: +# return requests.get, ("https://github.com",) + + +# class Malicious5: +# def __reduce__(self) -> Any: +# return aiohttp.ClientSession, tuple() + + +# class Malicious6: +# def __reduce__(self) -> Any: +# return socket.create_connection, (("github.com", 80),) + + +# class Malicious7: +# def __reduce__(self) -> Any: +# return subprocess.run, (["ls", "-l"],) + + +# class Malicious8: +# def __reduce__(self) -> Any: +# return sys.exit, (0,) + + +# def initialize_pickle_file(path: str, obj: Any, version: int) -> None: +# if not os.path.exists(path): +# with open(path, "wb") as file: +# pickle.dump(obj, file, protocol=version) + + +# def initialize_data_file(path: str, data: Any) -> None: +# if not os.path.exists(path): +# with open(path, "wb") as file: +# file.write(data) + + +# def initialize_zip_file(path: str, file_name: str, data: Any) -> None: +# if not os.path.exists(path): +# with zipfile.ZipFile(path, "w") as zip: +# zip.writestr(file_name, data) + + +# def initialize_numpy_file(path: str) -> None: +# import numpy as np + +# # create numpy object array +# with open(path, "wb") as f: +# data = [(1, 2), (3, 4)] +# x = np.empty((2, 2), dtype=object) +# x[:] = data +# np.save(f, x) + + +# @pytest.fixture(scope="session") +# def zip_file_path(tmp_path_factory: Any) -> Any: +# tmp = tmp_path_factory.mktemp("zip") +# initialize_zip_file( +# f"{tmp}/test.zip", +# "data.pkl", +# pickle.dumps(Malicious1(), protocol=4), +# ) +# return tmp -@pytest.fixture(scope="session") -def pickle_file_path(tmp_path_factory: Any) -> Any: - tmp = tmp_path_factory.mktemp("test_files") - os.makedirs(f"{tmp}/data", exist_ok=True) +# @pytest.fixture(scope="session") +# def pickle_file_path(tmp_path_factory: Any) -> Any: +# tmp = tmp_path_factory.mktemp("test_files") +# os.makedirs(f"{tmp}/data", exist_ok=True) - # Test with Pickle versions 0, 3, and 4: - # - Pickle versions 0, 1, 2 have built-in functions under '__builtin__' while versions 3 and 4 have them under 'builtins' - # - Pickle versions 0, 1, 2, 3 use 'GLOBAL' opcode while 4 uses 'STACK_GLOBAL' opcode - for version in (0, 3, 4): - initialize_pickle_file( - f"{tmp}/data/benign0_v{version}.pkl", ["a", "b", "c"], version - ) - initialize_pickle_file( - f"{tmp}/data/malicious1_v{version}.pkl", Malicious1(), version - ) - initialize_pickle_file( - f"{tmp}/data/malicious2_v{version}.pkl", Malicious2(), version - ) +# # Test with Pickle versions 0, 3, and 4: +# # - Pickle versions 0, 1, 2 have built-in functions under '__builtin__' while versions 3 and 4 have them under 'builtins' +# # - Pickle versions 0, 1, 2, 3 use 'GLOBAL' opcode while 4 uses 'STACK_GLOBAL' opcode +# for version in (0, 3, 4): +# initialize_pickle_file( +# f"{tmp}/data/benign0_v{version}.pkl", ["a", "b", "c"], version +# ) +# initialize_pickle_file( +# f"{tmp}/data/malicious1_v{version}.pkl", Malicious1(), version +# ) +# initialize_pickle_file( +# f"{tmp}/data/malicious2_v{version}.pkl", Malicious2(), version +# ) - # Malicious Pickle from https://sensepost.com/cms/resources/conferences/2011/sour_pickles/BH_US_11_Slaviero_Sour_Pickles.pdf - initialize_data_file( - f"{tmp}/data/malicious0.pkl", - b'c__builtin__\nglobals\n(tRp100\n0c__builtin__\ncompile\n(S\'fl=open("/etc/passwd");picklesmashed=fl.read();' - + b"'\nS''\nS'exec'\ntRp101\n0c__builtin__\neval\n(g101\ng100\ntRp102\n0c__builtin__\ngetattr\n(c__builtin__\n" - + b"dict\nS'get'\ntRp103\n0c__builtin__\napply\n(g103\n(g100\nS'picklesmashed'\nltRp104\n0g104\n.", - ) - - initialize_data_file(f"{tmp}/data/malicious3.pkl", malicious3_pickle_bytes) - initialize_pickle_file(f"{tmp}/data/malicious4.pickle", Malicious4(), 4) - initialize_pickle_file(f"{tmp}/data/malicious5.pickle", Malicious5(), 4) - initialize_data_file( - f"{tmp}/data/malicious6.pkl", - pickle.dumps(["a", "b", "c"]) + pickle.dumps(Malicious4()), - ) - initialize_pickle_file(f"{tmp}/data/malicious7.pkl", Malicious6(), 4) - initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) - initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) - - initialize_zip_file( - f"{tmp}/data/malicious1.zip", - "data.pkl", - pickle.dumps(Malicious1(), protocol=4), - ) - - initialize_numpy_file(f"{tmp}/data/object_array.npy") - - return tmp - - -@pytest.fixture(scope="session") -def keras_file_path(tmp_path_factory: Any) -> Any: - # Create a simple model. - inputs = keras.Input(shape=(32,)) - outputs = keras.layers.Dense(1)(inputs) - keras_model = keras.Model(inputs, outputs) - keras_model.compile(optimizer="adam", loss="mean_squared_error") - - # Train the model. - test_input = np.random.random((128, 32)) - test_target = np.random.random((128, 1)) - keras_model.fit(test_input, test_target) - - tmp = tmp_path_factory.mktemp("keras") - with open(f"{tmp}/safe", "wb") as fo: - pickle.dump(keras_model, fo) - keras_model.save(f"{tmp}/safe.h5") - - # Inject code with the command - command = "exec" - malicious_code = 'print("Malicious code!")' - - generate_dill_unsafe_file(keras_model, command, malicious_code, f"{tmp}/unsafe") - attack = ( - lambda x: exec( # type: ignore[func-returns-value] - """import http.client -import json -import os -conn = http.client.HTTPSConnection("protectai.com")""" - ) - or x - ) - input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output) - new_layer = keras.layers.Lambda(attack)(input_to_new_layer) - - malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer]) - malicious_model.compile(optimizer="adam", loss="mean_squared_error") - - malicious_model.save(f"{tmp}/unsafe.h5") - - return tmp - - -def compare_results(resultList: List[Issue], expectedSet: Set[Issue]) -> None: - for result in resultList: - assert result in expectedSet - - -def test_scan_pickle_bytes() -> None: - expected = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails("builtins", "eval", "file.pkl"), - ) - ] - assert ( - scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] - == expected - ) - - -def test_scan_zip(zip_file_path: Any) -> None: - expected = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{zip_file_path}/test.zip:data.pkl" - ), - ) - ] - - ms = Modelscan() - ms._scan_zip(f"{zip_file_path}/test.zip") - assert ms.issues.all_issues == expected - - -def test_scan_numpy(pickle_file_path: Any) -> None: - expected = { - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy.core.multiarray", "_reconstruct", "object_array.npy" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails("numpy", "ndarray", "object_array.npy"), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails("numpy", "dtype", "object_array.npy"), - ), - } - with open(f"{pickle_file_path}/data/object_array.npy", "rb") as f: - compare_results( - scan_numpy(io.BytesIO(f.read()), "object_array.npy")[0], expected - ) - - -def test_scan_file_path(pickle_file_path: Any) -> None: - benign = Modelscan() - benign.scan_path(Path(f"{pickle_file_path}/data/benign0_v3.pkl")) - assert benign.issues.all_issues == [] - - malicious0 = Modelscan() - expected_malicious0 = { - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - } - malicious0.scan_path(Path(f"{pickle_file_path}/data/malicious0.pkl")) - compare_results(malicious0.issues.all_issues, expected_malicious0) - - expected_malicious1_v0 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" - ), - ) - ] - expected_malicious1_v3 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" - ), - ) - ] - expected_malicious1_v4 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" - ), - ) - ] - expected_malicious1 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" - ), - ) - ] - malicious1_v0 = Modelscan() - malicious1_v3 = Modelscan() - malicious1_v4 = Modelscan() - malicious1 = Modelscan() - malicious1_v0.scan_path(Path(f"{pickle_file_path}/data/malicious1_v0.pkl")) - malicious1_v3.scan_path(Path(f"{pickle_file_path}/data/malicious1_v3.pkl")) - malicious1_v4.scan_path(Path(f"{pickle_file_path}/data/malicious1_v4.pkl")) - malicious1.scan_path(Path(f"{pickle_file_path}/data/malicious1.zip")) - assert malicious1_v0.issues.all_issues == expected_malicious1_v0 - assert malicious1_v3.issues.all_issues == expected_malicious1_v3 - assert malicious1_v4.issues.all_issues == expected_malicious1_v4 - assert malicious1.issues.all_issues == expected_malicious1 - - expected_malicious2_v0 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" - ), - ) - ] - expected_malicious2_v3 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" - ), - ) - ] - expected_malicious2_v4 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" - ), - ) - ] - malicious2_v0 = Modelscan() - malicious2_v3 = Modelscan() - malicious2_v4 = Modelscan() - malicious2_v0.scan_path(Path(f"{pickle_file_path}/data/malicious2_v0.pkl")) - malicious2_v3.scan_path(Path(f"{pickle_file_path}/data/malicious2_v3.pkl")) - malicious2_v4.scan_path(Path(f"{pickle_file_path}/data/malicious2_v4.pkl")) - assert malicious2_v0.issues.all_issues == expected_malicious2_v0 - assert malicious2_v3.issues.all_issues == expected_malicious2_v3 - assert malicious2_v4.issues.all_issues == expected_malicious2_v4 - - expected_malicious3 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "httplib", - "HTTPSConnection", - Path(f"{pickle_file_path}/data/malicious3.pkl"), - ), - ) - ] - malicious3 = Modelscan() - malicious3.scan_path(Path(f"{pickle_file_path}/data/malicious3.pkl")) - assert malicious3.issues.all_issues == expected_malicious3 - - expected_malicious4 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" - ), - ) - ] - malicious4 = Modelscan() - malicious4.scan_path(Path(f"{pickle_file_path}/data/malicious4.pickle")) - assert malicious4.issues.all_issues == expected_malicious4 - - expected_malicious5 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "aiohttp.client", - "ClientSession", - f"{pickle_file_path}/data/malicious5.pickle", - ), - ) - ] - malicious5 = Modelscan() - malicious5.scan_path(Path(f"{pickle_file_path}/data/malicious5.pickle")) - assert malicious5.issues.all_issues == expected_malicious5 - - expected_malicious6 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" - ), - ) - ] - malicious6 = Modelscan() - malicious6.scan_path(Path(f"{pickle_file_path}/data/malicious6.pkl")) - assert malicious6.issues.all_issues == expected_malicious6 - - expected_malicious7 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" - ), - ) - ] - malicious7 = Modelscan() - malicious7.scan_path(Path(f"{pickle_file_path}/data/malicious7.pkl")) - assert malicious7.issues.all_issues == expected_malicious7 - - expected_malicious8 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" - ), - ) - ] - malicious8 = Modelscan() - malicious8.scan_path(Path(f"{pickle_file_path}/data/malicious8.pkl")) - assert malicious8.issues.all_issues == expected_malicious8 - - expected_malicious9 = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" - ), - ) - ] - malicious9 = Modelscan() - malicious9.scan_path(Path(f"{pickle_file_path}/data/malicious9.pkl")) - assert malicious9.issues.all_issues == expected_malicious9 - - -def test_scan_directory_path(pickle_file_path: str) -> None: - expected = { - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy", "ndarray", f"{pickle_file_path}/data/object_array.npy" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "numpy.core.multiarray", - "_reconstruct", - f"{pickle_file_path}/data/object_array.npy", - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "aiohttp.client", - "ClientSession", - f"{pickle_file_path}/data/malicious5.pickle", - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.HIGH, - OperatorIssueDetails( - "httplib", "HTTPSConnection", f"{pickle_file_path}/data/malicious3.pkl" - ), - ), - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" - ), - ), - } - ms = Modelscan() - p = Path(f"{pickle_file_path}/data/") - ms.scan_path(p) - compare_results(ms.issues.all_issues, expected) - - -def test_scan_huggingface_model() -> None: - expected = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.CRITICAL, - OperatorIssueDetails( - "__builtin__", - "eval", - "https://huggingface.co/ykilcher/totally-harmless-model/resolve/main/pytorch_model.bin:archive/data.pkl", - ), - ) - ] - ms = Modelscan() - ms.scan_huggingface_model("ykilcher/totally-harmless-model") - assert ms.issues.all_issues == expected - - -# def test_scan_tf() -> None: - - -def test_scan_keras(keras_file_path: Any) -> None: - ms = Modelscan() - ms.scan_path(Path(f"{keras_file_path}/safe.h5")) - assert ms.issues.all_issues == [] - - expected = [ - Issue( - IssueCode.UNSAFE_OPERATOR, - IssueSeverity.MEDIUM, - OperatorIssueDetails( - "Keras", - "Lambda", - f"{keras_file_path}/unsafe.h5", - ), - ) - ] - ms.scan_path(Path(f"{keras_file_path}/unsafe.h5")) - assert ms.issues.all_issues == expected - - -def test_main(pickle_file_path: Any) -> None: - argv = sys.argv - try: - sys.argv = ["modelscan", "-p", f"{pickle_file_path}/data/benign0_v3.pkl"] - assert cli() == 0 - importlib.import_module("modelscan.scanner") - except SystemExit: - pass - finally: - sys.argv = argv +# # Malicious Pickle from https://sensepost.com/cms/resources/conferences/2011/sour_pickles/BH_US_11_Slaviero_Sour_Pickles.pdf +# initialize_data_file( +# f"{tmp}/data/malicious0.pkl", +# b'c__builtin__\nglobals\n(tRp100\n0c__builtin__\ncompile\n(S\'fl=open("/etc/passwd");picklesmashed=fl.read();' +# + b"'\nS''\nS'exec'\ntRp101\n0c__builtin__\neval\n(g101\ng100\ntRp102\n0c__builtin__\ngetattr\n(c__builtin__\n" +# + b"dict\nS'get'\ntRp103\n0c__builtin__\napply\n(g103\n(g100\nS'picklesmashed'\nltRp104\n0g104\n.", +# ) + +# initialize_data_file(f"{tmp}/data/malicious3.pkl", malicious3_pickle_bytes) +# initialize_pickle_file(f"{tmp}/data/malicious4.pickle", Malicious4(), 4) +# initialize_pickle_file(f"{tmp}/data/malicious5.pickle", Malicious5(), 4) +# initialize_data_file( +# f"{tmp}/data/malicious6.pkl", +# pickle.dumps(["a", "b", "c"]) + pickle.dumps(Malicious4()), +# ) +# initialize_pickle_file(f"{tmp}/data/malicious7.pkl", Malicious6(), 4) +# initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) +# initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) + +# initialize_zip_file( +# f"{tmp}/data/malicious1.zip", +# "data.pkl", +# pickle.dumps(Malicious1(), protocol=4), +# ) + +# initialize_numpy_file(f"{tmp}/data/object_array.npy") + +# return tmp + + +# @pytest.fixture(scope="session") +# def keras_file_path(tmp_path_factory: Any) -> Any: +# # Create a simple model. +# inputs = keras.Input(shape=(32,)) +# outputs = keras.layers.Dense(1)(inputs) +# keras_model = keras.Model(inputs, outputs) +# keras_model.compile(optimizer="adam", loss="mean_squared_error") + +# # Train the model. +# test_input = np.random.random((128, 32)) +# test_target = np.random.random((128, 1)) +# keras_model.fit(test_input, test_target) + +# tmp = tmp_path_factory.mktemp("keras") +# with open(f"{tmp}/safe", "wb") as fo: +# pickle.dump(keras_model, fo) +# keras_model.save(f"{tmp}/safe.h5") + +# # Inject code with the command +# command = "exec" +# malicious_code = 'print("Malicious code!")' + +# generate_dill_unsafe_file(keras_model, command, malicious_code, f"{tmp}/unsafe") +# attack = ( +# lambda x: exec( # type: ignore[func-returns-value] +# """import http.client +# import json +# import os +# conn = http.client.HTTPSConnection("protectai.com")""" +# ) +# or x +# ) +# input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output) +# new_layer = keras.layers.Lambda(attack)(input_to_new_layer) + +# malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer]) +# malicious_model.compile(optimizer="adam", loss="mean_squared_error") + +# malicious_model.save(f"{tmp}/unsafe.h5") + +# return tmp + + +# def compare_results(resultList: List[Issue], expectedSet: Set[Issue]) -> None: +# for result in resultList: +# assert result in expectedSet + + +# def test_scan_pickle_bytes() -> None: +# expected = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails("builtins", "eval", "file.pkl"), +# ) +# ] +# assert ( +# scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] +# == expected +# ) + + +# def test_scan_zip(zip_file_path: Any) -> None: +# expected = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{zip_file_path}/test.zip:data.pkl" +# ), +# ) +# ] + +# ms = Modelscan() +# ms._scan_zip(f"{zip_file_path}/test.zip") +# assert ms.issues.all_issues == expected + + +# def test_scan_numpy(pickle_file_path: Any) -> None: +# expected = { +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy.core.multiarray", "_reconstruct", "object_array.npy" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails("numpy", "ndarray", "object_array.npy"), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails("numpy", "dtype", "object_array.npy"), +# ), +# } +# with open(f"{pickle_file_path}/data/object_array.npy", "rb") as f: +# compare_results( +# scan_numpy(io.BytesIO(f.read()), "object_array.npy")[0], expected +# ) + + +# def test_scan_file_path(pickle_file_path: Any) -> None: +# benign = Modelscan() +# benign.scan_path(Path(f"{pickle_file_path}/data/benign0_v3.pkl")) +# assert benign.issues.all_issues == [] + +# malicious0 = Modelscan() +# expected_malicious0 = { +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# } +# malicious0.scan_path(Path(f"{pickle_file_path}/data/malicious0.pkl")) +# compare_results(malicious0.issues.all_issues, expected_malicious0) + +# expected_malicious1_v0 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" +# ), +# ) +# ] +# expected_malicious1_v3 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" +# ), +# ) +# ] +# expected_malicious1_v4 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" +# ), +# ) +# ] +# expected_malicious1 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" +# ), +# ) +# ] +# malicious1_v0 = Modelscan() +# malicious1_v3 = Modelscan() +# malicious1_v4 = Modelscan() +# malicious1 = Modelscan() +# malicious1_v0.scan_path(Path(f"{pickle_file_path}/data/malicious1_v0.pkl")) +# malicious1_v3.scan_path(Path(f"{pickle_file_path}/data/malicious1_v3.pkl")) +# malicious1_v4.scan_path(Path(f"{pickle_file_path}/data/malicious1_v4.pkl")) +# malicious1.scan_path(Path(f"{pickle_file_path}/data/malicious1.zip")) +# assert malicious1_v0.issues.all_issues == expected_malicious1_v0 +# assert malicious1_v3.issues.all_issues == expected_malicious1_v3 +# assert malicious1_v4.issues.all_issues == expected_malicious1_v4 +# assert malicious1.issues.all_issues == expected_malicious1 + +# expected_malicious2_v0 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" +# ), +# ) +# ] +# expected_malicious2_v3 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" +# ), +# ) +# ] +# expected_malicious2_v4 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" +# ), +# ) +# ] +# malicious2_v0 = Modelscan() +# malicious2_v3 = Modelscan() +# malicious2_v4 = Modelscan() +# malicious2_v0.scan_path(Path(f"{pickle_file_path}/data/malicious2_v0.pkl")) +# malicious2_v3.scan_path(Path(f"{pickle_file_path}/data/malicious2_v3.pkl")) +# malicious2_v4.scan_path(Path(f"{pickle_file_path}/data/malicious2_v4.pkl")) +# assert malicious2_v0.issues.all_issues == expected_malicious2_v0 +# assert malicious2_v3.issues.all_issues == expected_malicious2_v3 +# assert malicious2_v4.issues.all_issues == expected_malicious2_v4 + +# expected_malicious3 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "httplib", +# "HTTPSConnection", +# Path(f"{pickle_file_path}/data/malicious3.pkl"), +# ), +# ) +# ] +# malicious3 = Modelscan() +# malicious3.scan_path(Path(f"{pickle_file_path}/data/malicious3.pkl")) +# assert malicious3.issues.all_issues == expected_malicious3 + +# expected_malicious4 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" +# ), +# ) +# ] +# malicious4 = Modelscan() +# malicious4.scan_path(Path(f"{pickle_file_path}/data/malicious4.pickle")) +# assert malicious4.issues.all_issues == expected_malicious4 + +# expected_malicious5 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "aiohttp.client", +# "ClientSession", +# f"{pickle_file_path}/data/malicious5.pickle", +# ), +# ) +# ] +# malicious5 = Modelscan() +# malicious5.scan_path(Path(f"{pickle_file_path}/data/malicious5.pickle")) +# assert malicious5.issues.all_issues == expected_malicious5 + +# expected_malicious6 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" +# ), +# ) +# ] +# malicious6 = Modelscan() +# malicious6.scan_path(Path(f"{pickle_file_path}/data/malicious6.pkl")) +# assert malicious6.issues.all_issues == expected_malicious6 + +# expected_malicious7 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" +# ), +# ) +# ] +# malicious7 = Modelscan() +# malicious7.scan_path(Path(f"{pickle_file_path}/data/malicious7.pkl")) +# assert malicious7.issues.all_issues == expected_malicious7 + +# expected_malicious8 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" +# ), +# ) +# ] +# malicious8 = Modelscan() +# malicious8.scan_path(Path(f"{pickle_file_path}/data/malicious8.pkl")) +# assert malicious8.issues.all_issues == expected_malicious8 + +# expected_malicious9 = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" +# ), +# ) +# ] +# malicious9 = Modelscan() +# malicious9.scan_path(Path(f"{pickle_file_path}/data/malicious9.pkl")) +# assert malicious9.issues.all_issues == expected_malicious9 + + +# def test_scan_directory_path(pickle_file_path: str) -> None: +# expected = { +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy", "ndarray", f"{pickle_file_path}/data/object_array.npy" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "numpy.core.multiarray", +# "_reconstruct", +# f"{pickle_file_path}/data/object_array.npy", +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "aiohttp.client", +# "ClientSession", +# f"{pickle_file_path}/data/malicious5.pickle", +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.HIGH, +# OperatorIssueDetails( +# "httplib", "HTTPSConnection", f"{pickle_file_path}/data/malicious3.pkl" +# ), +# ), +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" +# ), +# ), +# } +# ms = Modelscan() +# p = Path(f"{pickle_file_path}/data/") +# ms.scan_path(p) +# compare_results(ms.issues.all_issues, expected) + + +# def test_scan_huggingface_model() -> None: +# expected = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.CRITICAL, +# OperatorIssueDetails( +# "__builtin__", +# "eval", +# "https://huggingface.co/ykilcher/totally-harmless-model/resolve/main/pytorch_model.bin:archive/data.pkl", +# ), +# ) +# ] +# ms = Modelscan() +# ms.scan_huggingface_model("ykilcher/totally-harmless-model") +# assert ms.issues.all_issues == expected + + +# # def test_scan_tf() -> None: + + +# def test_scan_keras(keras_file_path: Any) -> None: +# ms = Modelscan() +# ms.scan_path(Path(f"{keras_file_path}/safe.h5")) +# assert ms.issues.all_issues == [] + +# expected = [ +# Issue( +# IssueCode.UNSAFE_OPERATOR, +# IssueSeverity.MEDIUM, +# OperatorIssueDetails( +# "Keras", +# "Lambda", +# f"{keras_file_path}/unsafe.h5", +# ), +# ) +# ] +# ms.scan_path(Path(f"{keras_file_path}/unsafe.h5")) +# assert ms.issues.all_issues == expected + + +# def test_main(pickle_file_path: Any) -> None: +# argv = sys.argv +# try: +# sys.argv = ["modelscan", "-p", f"{pickle_file_path}/data/benign0_v3.pkl"] +# assert cli() == 0 +# importlib.import_module("modelscan.scanner") +# except SystemExit: +# pass +# finally: +# sys.argv = argv diff --git a/tests/test_utils.py b/tests/test_utils.py index 34b3eeb5..691520a5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,236 +1,236 @@ -import dill -import os -import pickle -import struct -from typing import Any, Tuple -import os +# import dill +# import os +# import pickle +# import struct +# from typing import Any, Tuple +# import os - -class PickleInject: - """Pickle injection""" - - def __init__(self, inj_objs: Any, first: bool = True): - self.__name__ = "pickle_inject" - self.inj_objs = inj_objs - self.first = first - - class _Pickler(pickle._Pickler): - """Reimplementation of Pickler with support for injection""" - - def __init__( - self, file: Any, protocol: Any, inj_objs: Any, first: bool = True - ) -> None: - """ - file: File object with write attribute - protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html - inj_objs: _joblibInject object that has both the command, and the code to be injected - first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. - """ - super().__init__(file, protocol) - self.inj_objs = inj_objs - self.first = first - - def dump(self, obj: Any) -> None: - """Pickle data, inject object before or after""" - if self.proto >= 2: # type: ignore[attr-defined] - self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] - self.framer.start_framing() # type: ignore[attr-defined] - - # Inject the object(s) before the user-supplied data? - if self.first: - # Pickle injected objects - for inj_obj in self.inj_objs: - self.save(inj_obj) # type: ignore[attr-defined] - - # Pickle user-supplied data - self.save(obj) # type: ignore[attr-defined] - - # Inject the object(s) after the user-supplied data? - if not self.first: - # Pickle injected objects - for inj_obj in self.inj_objs: - self.save(inj_obj) # type: ignore[attr-defined] - - self.write(pickle.STOP) # type: ignore[attr-defined] - self.framer.end_framing() # type: ignore[attr-defined] - - def Pickler(self, file: Any, protocol: Any) -> _Pickler: - # Initialise the pickler interface with the injected object - return self._Pickler(file, protocol, self.inj_objs) - - class _PickleInject: - """Base class for pickling injected commands""" - - def __init__(self, args: Any, command: Any = None) -> None: - self.command = command - self.args = args - - def __reduce__(self) -> Tuple[Any, Any]: - """ - In general, the __reduce__ function is used by pickle to serialize objects. - If defined for an object, pickle would override its default __reduce__ function and serialize the object as outlined by the custom specified __reduce__ function, - The object returned by __reduce__ here is a callable: (self.command), and the tuple: with first element (self.args) is the code to be executed by self.command. - """ - return self.command, (self.args,) - - class System(_PickleInject): - """Create os.system command""" - - def __init__(self, args: Any) -> None: - super().__init__(args, command=os.system) - - class Exec(_PickleInject): - """Create exec command""" - - def __init__(self, args: Any) -> None: - super().__init__(args, command=exec) - - class Eval(_PickleInject): - """Create eval command""" - - def __init__(self, args: Any) -> None: - super().__init__(args, command=eval) - - class RunPy(_PickleInject): - """Create runpy command""" - - def __init__(self, args: Any) -> None: - import runpy - - super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] - - def __reduce__(self) -> Tuple[Any, Any]: - return self.command, (self.args, {}) + +# class PickleInject: +# """Pickle injection""" + +# def __init__(self, inj_objs: Any, first: bool = True): +# self.__name__ = "pickle_inject" +# self.inj_objs = inj_objs +# self.first = first + +# class _Pickler(pickle._Pickler): +# """Reimplementation of Pickler with support for injection""" + +# def __init__( +# self, file: Any, protocol: Any, inj_objs: Any, first: bool = True +# ) -> None: +# """ +# file: File object with write attribute +# protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html +# inj_objs: _joblibInject object that has both the command, and the code to be injected +# first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. +# """ +# super().__init__(file, protocol) +# self.inj_objs = inj_objs +# self.first = first + +# def dump(self, obj: Any) -> None: +# """Pickle data, inject object before or after""" +# if self.proto >= 2: # type: ignore[attr-defined] +# self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] +# self.framer.start_framing() # type: ignore[attr-defined] + +# # Inject the object(s) before the user-supplied data? +# if self.first: +# # Pickle injected objects +# for inj_obj in self.inj_objs: +# self.save(inj_obj) # type: ignore[attr-defined] + +# # Pickle user-supplied data +# self.save(obj) # type: ignore[attr-defined] + +# # Inject the object(s) after the user-supplied data? +# if not self.first: +# # Pickle injected objects +# for inj_obj in self.inj_objs: +# self.save(inj_obj) # type: ignore[attr-defined] + +# self.write(pickle.STOP) # type: ignore[attr-defined] +# self.framer.end_framing() # type: ignore[attr-defined] + +# def Pickler(self, file: Any, protocol: Any) -> _Pickler: +# # Initialise the pickler interface with the injected object +# return self._Pickler(file, protocol, self.inj_objs) + +# class _PickleInject: +# """Base class for pickling injected commands""" + +# def __init__(self, args: Any, command: Any = None) -> None: +# self.command = command +# self.args = args + +# def __reduce__(self) -> Tuple[Any, Any]: +# """ +# In general, the __reduce__ function is used by pickle to serialize objects. +# If defined for an object, pickle would override its default __reduce__ function and serialize the object as outlined by the custom specified __reduce__ function, +# The object returned by __reduce__ here is a callable: (self.command), and the tuple: with first element (self.args) is the code to be executed by self.command. +# """ +# return self.command, (self.args,) + +# class System(_PickleInject): +# """Create os.system command""" + +# def __init__(self, args: Any) -> None: +# super().__init__(args, command=os.system) + +# class Exec(_PickleInject): +# """Create exec command""" + +# def __init__(self, args: Any) -> None: +# super().__init__(args, command=exec) + +# class Eval(_PickleInject): +# """Create eval command""" + +# def __init__(self, args: Any) -> None: +# super().__init__(args, command=eval) + +# class RunPy(_PickleInject): +# """Create runpy command""" + +# def __init__(self, args: Any) -> None: +# import runpy + +# super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] + +# def __reduce__(self) -> Tuple[Any, Any]: +# return self.command, (self.args, {}) -def get_pickle_payload(command: str, malicious_code: str) -> Any: - if command == "system": - payload: Any = PickleInject.System(malicious_code) - elif command == "exec": - payload = PickleInject.Exec(malicious_code) - elif command == "eval": - payload = PickleInject.Eval(malicious_code) - elif command == "runpy": - payload = PickleInject.RunPy(malicious_code) - return payload +# def get_pickle_payload(command: str, malicious_code: str) -> Any: +# if command == "system": +# payload: Any = PickleInject.System(malicious_code) +# elif command == "exec": +# payload = PickleInject.Exec(malicious_code) +# elif command == "eval": +# payload = PickleInject.Eval(malicious_code) +# elif command == "runpy": +# payload = PickleInject.RunPy(malicious_code) +# return payload -def generate_unsafe_pickle_file( - safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str -) -> None: - payload = get_pickle_payload(command, malicious_code) - pickle_protocol = 4 - file_for_unsafe_model = open(unsafe_model_path, "wb") - mypickler = PickleInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) - mypickler.dump(safe_model) - file_for_unsafe_model.close() +# def generate_unsafe_pickle_file( +# safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str +# ) -> None: +# payload = get_pickle_payload(command, malicious_code) +# pickle_protocol = 4 +# file_for_unsafe_model = open(unsafe_model_path, "wb") +# mypickler = PickleInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) +# mypickler.dump(safe_model) +# file_for_unsafe_model.close() -class DillInject: - """Code injection using Dill Pickler""" +# class DillInject: +# """Code injection using Dill Pickler""" - def __init__(self, inj_objs: Any, first: bool = True): - self.__name__ = "dill_inject" - self.inj_objs = inj_objs - self.first = first +# def __init__(self, inj_objs: Any, first: bool = True): +# self.__name__ = "dill_inject" +# self.inj_objs = inj_objs +# self.first = first - class _Pickler(dill._dill.Pickler): # type: ignore[misc] - """Reimplementation of Pickler with support for injection""" +# class _Pickler(dill._dill.Pickler): # type: ignore[misc] +# """Reimplementation of Pickler with support for injection""" - def __init__(self, file: Any, protocol: Any, inj_objs: Any, first: bool = True): - super().__init__(file, protocol) - self.inj_objs = inj_objs - self.first = first +# def __init__(self, file: Any, protocol: Any, inj_objs: Any, first: bool = True): +# super().__init__(file, protocol) +# self.inj_objs = inj_objs +# self.first = first - def dump(self, obj: Any) -> None: - """Pickle data, inject object before or after""" - if self.proto >= 2: - self.write(pickle.PROTO + struct.pack("= 4: - self.framer.start_framing() +# def dump(self, obj: Any) -> None: +# """Pickle data, inject object before or after""" +# if self.proto >= 2: +# self.write(pickle.PROTO + struct.pack("= 4: +# self.framer.start_framing() - # Inject the object(s) before the user-supplied data? - if self.first: - # Pickle injected objects - for inj_obj in self.inj_objs: - self.save(inj_obj) +# # Inject the object(s) before the user-supplied data? +# if self.first: +# # Pickle injected objects +# for inj_obj in self.inj_objs: +# self.save(inj_obj) - # Pickle user-supplied data - self.save(obj) +# # Pickle user-supplied data +# self.save(obj) - # Inject the object(s) after the user-supplied data? - if not self.first: - # Pickle injected objects - for inj_obj in self.inj_objs: - self.save(inj_obj) +# # Inject the object(s) after the user-supplied data? +# if not self.first: +# # Pickle injected objects +# for inj_obj in self.inj_objs: +# self.save(inj_obj) - self.write(pickle.STOP) - self.framer.end_framing() +# self.write(pickle.STOP) +# self.framer.end_framing() - def DillPickler(self, file: Any, protocol: Any) -> _Pickler: - # Initialise the pickler interface with the injected object - return self._Pickler(file, protocol, self.inj_objs) +# def DillPickler(self, file: Any, protocol: Any) -> _Pickler: +# # Initialise the pickler interface with the injected object +# return self._Pickler(file, protocol, self.inj_objs) - class _DillInject: - """Base class for pickling injected commands""" +# class _DillInject: +# """Base class for pickling injected commands""" - def __init__(self, args: Any, command: Any = None): - self.command = command - self.args = args +# def __init__(self, args: Any, command: Any = None): +# self.command = command +# self.args = args - def __reduce__(self) -> Tuple[Any, Any]: - return self.command, (self.args,) +# def __reduce__(self) -> Tuple[Any, Any]: +# return self.command, (self.args,) - class System(_DillInject): - """Create os.system command""" +# class System(_DillInject): +# """Create os.system command""" - def __init__(self, args: Any): - super().__init__(args, command=os.system) +# def __init__(self, args: Any): +# super().__init__(args, command=os.system) - class Exec(_DillInject): - """Create exec command""" +# class Exec(_DillInject): +# """Create exec command""" - def __init__(self, args: Any): - super().__init__(args, command=exec) +# def __init__(self, args: Any): +# super().__init__(args, command=exec) - class Eval(_DillInject): - """Create eval command""" +# class Eval(_DillInject): +# """Create eval command""" - def __init__(self, args: Any): - super().__init__(args, command=eval) - - class RunPy(_DillInject): - """Create runpy command""" - - def __init__(self, args: Any): - import runpy - - super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] - - def __reduce__(self) -> Any: - return self.command, (self.args, {}) +# def __init__(self, args: Any): +# super().__init__(args, command=eval) + +# class RunPy(_DillInject): +# """Create runpy command""" + +# def __init__(self, args: Any): +# import runpy + +# super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] + +# def __reduce__(self) -> Any: +# return self.command, (self.args, {}) -def get_dill_payload(command: str, malicious_code: str) -> Any: - payload: Any - if command == "system": - payload = DillInject.System(malicious_code) - elif command == "exec": - payload = DillInject.Exec(malicious_code) - elif command == "eval": - payload = DillInject.Eval(malicious_code) - elif command == "runpy": - payload = DillInject.RunPy(malicious_code) - return payload +# def get_dill_payload(command: str, malicious_code: str) -> Any: +# payload: Any +# if command == "system": +# payload = DillInject.System(malicious_code) +# elif command == "exec": +# payload = DillInject.Exec(malicious_code) +# elif command == "eval": +# payload = DillInject.Eval(malicious_code) +# elif command == "runpy": +# payload = DillInject.RunPy(malicious_code) +# return payload -def generate_dill_unsafe_file( - safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str -) -> None: - payload = get_dill_payload(command, malicious_code) - pickle_protocol = 4 - file_for_unsafe_model = open(unsafe_model_path, "wb") - mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) - mypickler.dump(safe_model) - file_for_unsafe_model.close() +# def generate_dill_unsafe_file( +# safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str +# ) -> None: +# payload = get_dill_payload(command, malicious_code) +# pickle_protocol = 4 +# file_for_unsafe_model = open(unsafe_model_path, "wb") +# mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) +# mypickler.dump(safe_model) +# file_for_unsafe_model.close() From c5457408d143ffc8664fa71124aef1f5978bdf18 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Fri, 29 Sep 2023 15:31:00 -0700 Subject: [PATCH 5/7] Revert "no pip" This reverts commit a7e7fc6ec62c6381a9a61e893c7cbc0aec8ad635. --- .github/workflows/coverage.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 1440b358..b3d1ff08 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -33,6 +33,8 @@ jobs: make install-test - name: Run Coverage run: | + pip install coverage + pip install pytest poetry add coverage codecov poetry run coverage run -m pytest - name: Upload coverage reports to Codecov From c363e3514c8a9c9469e42362af4c5c169b38481a Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Fri, 29 Sep 2023 15:34:18 -0700 Subject: [PATCH 6/7] make coverage --- .github/workflows/coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index b3d1ff08..4786f78b 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -36,7 +36,7 @@ jobs: pip install coverage pip install pytest poetry add coverage codecov - poetry run coverage run -m pytest + make coverage - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 with: From f9711637e3c55169a1b3381233a60785a4983ec1 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Fri, 29 Sep 2023 15:37:51 -0700 Subject: [PATCH 7/7] uncomment tests --- tests/test_modelscan.py | 1456 +++++++++++++++++++-------------------- tests/test_utils.py | 414 +++++------ 2 files changed, 935 insertions(+), 935 deletions(-) diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index ad2226b8..d6395226 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -1,737 +1,737 @@ -# import aiohttp -# import http.client -# import importlib -# import io -# import numpy as np -# import os -# from pathlib import Path -# import pickle -# import pytest -# import requests # type: ignore[import] -# import socket -# import subprocess -# import sys -# import tensorflow as tf -# from tensorflow import keras -# from typing import Any, List, Set -# from test_utils import generate_dill_unsafe_file -# import zipfile +import aiohttp +import http.client +import importlib +import io +import numpy as np +import os +from pathlib import Path +import pickle +import pytest +import requests # type: ignore[import] +import socket +import subprocess +import sys +import tensorflow as tf +from tensorflow import keras +from typing import Any, List, Set +from test_utils import generate_dill_unsafe_file +import zipfile -# from modelscan.modelscan import Modelscan -# from modelscan.cli import cli -# from modelscan.issues import ( -# Issue, -# IssueCode, -# IssueSeverity, -# OperatorIssueDetails, -# ) -# from modelscan.tools.picklescanner import ( -# scan_pickle_bytes, -# scan_numpy, -# ) +from modelscan.modelscan import Modelscan +from modelscan.cli import cli +from modelscan.issues import ( + Issue, + IssueCode, + IssueSeverity, + OperatorIssueDetails, +) +from modelscan.tools.picklescanner import ( + scan_pickle_bytes, + scan_numpy, +) -# class Malicious1: -# def __reduce__(self) -> Any: -# return eval, ("print('456')",) +class Malicious1: + def __reduce__(self) -> Any: + return eval, ("print('456')",) -# class Malicious2: -# def __reduce__(self) -> Any: -# return os.system, ("ls -la",) - - -# class Malicious3: -# def __reduce__(self) -> Any: -# return http.client.HTTPSConnection, ("github.com",) - - -# malicious3_pickle_bytes = pickle.dumps( -# Malicious3(), protocol=0 -# ) # Malicious3 needs to be pickled before HTTPSConnection is mocked below - - -# class Malicious4: -# def __reduce__(self) -> Any: -# return requests.get, ("https://github.com",) - - -# class Malicious5: -# def __reduce__(self) -> Any: -# return aiohttp.ClientSession, tuple() - - -# class Malicious6: -# def __reduce__(self) -> Any: -# return socket.create_connection, (("github.com", 80),) - - -# class Malicious7: -# def __reduce__(self) -> Any: -# return subprocess.run, (["ls", "-l"],) - - -# class Malicious8: -# def __reduce__(self) -> Any: -# return sys.exit, (0,) - - -# def initialize_pickle_file(path: str, obj: Any, version: int) -> None: -# if not os.path.exists(path): -# with open(path, "wb") as file: -# pickle.dump(obj, file, protocol=version) - - -# def initialize_data_file(path: str, data: Any) -> None: -# if not os.path.exists(path): -# with open(path, "wb") as file: -# file.write(data) - - -# def initialize_zip_file(path: str, file_name: str, data: Any) -> None: -# if not os.path.exists(path): -# with zipfile.ZipFile(path, "w") as zip: -# zip.writestr(file_name, data) - - -# def initialize_numpy_file(path: str) -> None: -# import numpy as np - -# # create numpy object array -# with open(path, "wb") as f: -# data = [(1, 2), (3, 4)] -# x = np.empty((2, 2), dtype=object) -# x[:] = data -# np.save(f, x) - - -# @pytest.fixture(scope="session") -# def zip_file_path(tmp_path_factory: Any) -> Any: -# tmp = tmp_path_factory.mktemp("zip") -# initialize_zip_file( -# f"{tmp}/test.zip", -# "data.pkl", -# pickle.dumps(Malicious1(), protocol=4), -# ) -# return tmp +class Malicious2: + def __reduce__(self) -> Any: + return os.system, ("ls -la",) + + +class Malicious3: + def __reduce__(self) -> Any: + return http.client.HTTPSConnection, ("github.com",) + + +malicious3_pickle_bytes = pickle.dumps( + Malicious3(), protocol=0 +) # Malicious3 needs to be pickled before HTTPSConnection is mocked below + + +class Malicious4: + def __reduce__(self) -> Any: + return requests.get, ("https://github.com",) + + +class Malicious5: + def __reduce__(self) -> Any: + return aiohttp.ClientSession, tuple() + + +class Malicious6: + def __reduce__(self) -> Any: + return socket.create_connection, (("github.com", 80),) + + +class Malicious7: + def __reduce__(self) -> Any: + return subprocess.run, (["ls", "-l"],) + + +class Malicious8: + def __reduce__(self) -> Any: + return sys.exit, (0,) + + +def initialize_pickle_file(path: str, obj: Any, version: int) -> None: + if not os.path.exists(path): + with open(path, "wb") as file: + pickle.dump(obj, file, protocol=version) + + +def initialize_data_file(path: str, data: Any) -> None: + if not os.path.exists(path): + with open(path, "wb") as file: + file.write(data) + + +def initialize_zip_file(path: str, file_name: str, data: Any) -> None: + if not os.path.exists(path): + with zipfile.ZipFile(path, "w") as zip: + zip.writestr(file_name, data) + + +def initialize_numpy_file(path: str) -> None: + import numpy as np + + # create numpy object array + with open(path, "wb") as f: + data = [(1, 2), (3, 4)] + x = np.empty((2, 2), dtype=object) + x[:] = data + np.save(f, x) + + +@pytest.fixture(scope="session") +def zip_file_path(tmp_path_factory: Any) -> Any: + tmp = tmp_path_factory.mktemp("zip") + initialize_zip_file( + f"{tmp}/test.zip", + "data.pkl", + pickle.dumps(Malicious1(), protocol=4), + ) + return tmp -# @pytest.fixture(scope="session") -# def pickle_file_path(tmp_path_factory: Any) -> Any: -# tmp = tmp_path_factory.mktemp("test_files") -# os.makedirs(f"{tmp}/data", exist_ok=True) +@pytest.fixture(scope="session") +def pickle_file_path(tmp_path_factory: Any) -> Any: + tmp = tmp_path_factory.mktemp("test_files") + os.makedirs(f"{tmp}/data", exist_ok=True) -# # Test with Pickle versions 0, 3, and 4: -# # - Pickle versions 0, 1, 2 have built-in functions under '__builtin__' while versions 3 and 4 have them under 'builtins' -# # - Pickle versions 0, 1, 2, 3 use 'GLOBAL' opcode while 4 uses 'STACK_GLOBAL' opcode -# for version in (0, 3, 4): -# initialize_pickle_file( -# f"{tmp}/data/benign0_v{version}.pkl", ["a", "b", "c"], version -# ) -# initialize_pickle_file( -# f"{tmp}/data/malicious1_v{version}.pkl", Malicious1(), version -# ) -# initialize_pickle_file( -# f"{tmp}/data/malicious2_v{version}.pkl", Malicious2(), version -# ) + # Test with Pickle versions 0, 3, and 4: + # - Pickle versions 0, 1, 2 have built-in functions under '__builtin__' while versions 3 and 4 have them under 'builtins' + # - Pickle versions 0, 1, 2, 3 use 'GLOBAL' opcode while 4 uses 'STACK_GLOBAL' opcode + for version in (0, 3, 4): + initialize_pickle_file( + f"{tmp}/data/benign0_v{version}.pkl", ["a", "b", "c"], version + ) + initialize_pickle_file( + f"{tmp}/data/malicious1_v{version}.pkl", Malicious1(), version + ) + initialize_pickle_file( + f"{tmp}/data/malicious2_v{version}.pkl", Malicious2(), version + ) -# # Malicious Pickle from https://sensepost.com/cms/resources/conferences/2011/sour_pickles/BH_US_11_Slaviero_Sour_Pickles.pdf -# initialize_data_file( -# f"{tmp}/data/malicious0.pkl", -# b'c__builtin__\nglobals\n(tRp100\n0c__builtin__\ncompile\n(S\'fl=open("/etc/passwd");picklesmashed=fl.read();' -# + b"'\nS''\nS'exec'\ntRp101\n0c__builtin__\neval\n(g101\ng100\ntRp102\n0c__builtin__\ngetattr\n(c__builtin__\n" -# + b"dict\nS'get'\ntRp103\n0c__builtin__\napply\n(g103\n(g100\nS'picklesmashed'\nltRp104\n0g104\n.", -# ) - -# initialize_data_file(f"{tmp}/data/malicious3.pkl", malicious3_pickle_bytes) -# initialize_pickle_file(f"{tmp}/data/malicious4.pickle", Malicious4(), 4) -# initialize_pickle_file(f"{tmp}/data/malicious5.pickle", Malicious5(), 4) -# initialize_data_file( -# f"{tmp}/data/malicious6.pkl", -# pickle.dumps(["a", "b", "c"]) + pickle.dumps(Malicious4()), -# ) -# initialize_pickle_file(f"{tmp}/data/malicious7.pkl", Malicious6(), 4) -# initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) -# initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) - -# initialize_zip_file( -# f"{tmp}/data/malicious1.zip", -# "data.pkl", -# pickle.dumps(Malicious1(), protocol=4), -# ) - -# initialize_numpy_file(f"{tmp}/data/object_array.npy") - -# return tmp - - -# @pytest.fixture(scope="session") -# def keras_file_path(tmp_path_factory: Any) -> Any: -# # Create a simple model. -# inputs = keras.Input(shape=(32,)) -# outputs = keras.layers.Dense(1)(inputs) -# keras_model = keras.Model(inputs, outputs) -# keras_model.compile(optimizer="adam", loss="mean_squared_error") - -# # Train the model. -# test_input = np.random.random((128, 32)) -# test_target = np.random.random((128, 1)) -# keras_model.fit(test_input, test_target) - -# tmp = tmp_path_factory.mktemp("keras") -# with open(f"{tmp}/safe", "wb") as fo: -# pickle.dump(keras_model, fo) -# keras_model.save(f"{tmp}/safe.h5") - -# # Inject code with the command -# command = "exec" -# malicious_code = 'print("Malicious code!")' - -# generate_dill_unsafe_file(keras_model, command, malicious_code, f"{tmp}/unsafe") -# attack = ( -# lambda x: exec( # type: ignore[func-returns-value] -# """import http.client -# import json -# import os -# conn = http.client.HTTPSConnection("protectai.com")""" -# ) -# or x -# ) -# input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output) -# new_layer = keras.layers.Lambda(attack)(input_to_new_layer) - -# malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer]) -# malicious_model.compile(optimizer="adam", loss="mean_squared_error") - -# malicious_model.save(f"{tmp}/unsafe.h5") - -# return tmp - - -# def compare_results(resultList: List[Issue], expectedSet: Set[Issue]) -> None: -# for result in resultList: -# assert result in expectedSet - - -# def test_scan_pickle_bytes() -> None: -# expected = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails("builtins", "eval", "file.pkl"), -# ) -# ] -# assert ( -# scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] -# == expected -# ) - - -# def test_scan_zip(zip_file_path: Any) -> None: -# expected = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{zip_file_path}/test.zip:data.pkl" -# ), -# ) -# ] - -# ms = Modelscan() -# ms._scan_zip(f"{zip_file_path}/test.zip") -# assert ms.issues.all_issues == expected - - -# def test_scan_numpy(pickle_file_path: Any) -> None: -# expected = { -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy.core.multiarray", "_reconstruct", "object_array.npy" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails("numpy", "ndarray", "object_array.npy"), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails("numpy", "dtype", "object_array.npy"), -# ), -# } -# with open(f"{pickle_file_path}/data/object_array.npy", "rb") as f: -# compare_results( -# scan_numpy(io.BytesIO(f.read()), "object_array.npy")[0], expected -# ) - - -# def test_scan_file_path(pickle_file_path: Any) -> None: -# benign = Modelscan() -# benign.scan_path(Path(f"{pickle_file_path}/data/benign0_v3.pkl")) -# assert benign.issues.all_issues == [] - -# malicious0 = Modelscan() -# expected_malicious0 = { -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# } -# malicious0.scan_path(Path(f"{pickle_file_path}/data/malicious0.pkl")) -# compare_results(malicious0.issues.all_issues, expected_malicious0) - -# expected_malicious1_v0 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" -# ), -# ) -# ] -# expected_malicious1_v3 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" -# ), -# ) -# ] -# expected_malicious1_v4 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" -# ), -# ) -# ] -# expected_malicious1 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" -# ), -# ) -# ] -# malicious1_v0 = Modelscan() -# malicious1_v3 = Modelscan() -# malicious1_v4 = Modelscan() -# malicious1 = Modelscan() -# malicious1_v0.scan_path(Path(f"{pickle_file_path}/data/malicious1_v0.pkl")) -# malicious1_v3.scan_path(Path(f"{pickle_file_path}/data/malicious1_v3.pkl")) -# malicious1_v4.scan_path(Path(f"{pickle_file_path}/data/malicious1_v4.pkl")) -# malicious1.scan_path(Path(f"{pickle_file_path}/data/malicious1.zip")) -# assert malicious1_v0.issues.all_issues == expected_malicious1_v0 -# assert malicious1_v3.issues.all_issues == expected_malicious1_v3 -# assert malicious1_v4.issues.all_issues == expected_malicious1_v4 -# assert malicious1.issues.all_issues == expected_malicious1 - -# expected_malicious2_v0 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" -# ), -# ) -# ] -# expected_malicious2_v3 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" -# ), -# ) -# ] -# expected_malicious2_v4 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" -# ), -# ) -# ] -# malicious2_v0 = Modelscan() -# malicious2_v3 = Modelscan() -# malicious2_v4 = Modelscan() -# malicious2_v0.scan_path(Path(f"{pickle_file_path}/data/malicious2_v0.pkl")) -# malicious2_v3.scan_path(Path(f"{pickle_file_path}/data/malicious2_v3.pkl")) -# malicious2_v4.scan_path(Path(f"{pickle_file_path}/data/malicious2_v4.pkl")) -# assert malicious2_v0.issues.all_issues == expected_malicious2_v0 -# assert malicious2_v3.issues.all_issues == expected_malicious2_v3 -# assert malicious2_v4.issues.all_issues == expected_malicious2_v4 - -# expected_malicious3 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "httplib", -# "HTTPSConnection", -# Path(f"{pickle_file_path}/data/malicious3.pkl"), -# ), -# ) -# ] -# malicious3 = Modelscan() -# malicious3.scan_path(Path(f"{pickle_file_path}/data/malicious3.pkl")) -# assert malicious3.issues.all_issues == expected_malicious3 - -# expected_malicious4 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" -# ), -# ) -# ] -# malicious4 = Modelscan() -# malicious4.scan_path(Path(f"{pickle_file_path}/data/malicious4.pickle")) -# assert malicious4.issues.all_issues == expected_malicious4 - -# expected_malicious5 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "aiohttp.client", -# "ClientSession", -# f"{pickle_file_path}/data/malicious5.pickle", -# ), -# ) -# ] -# malicious5 = Modelscan() -# malicious5.scan_path(Path(f"{pickle_file_path}/data/malicious5.pickle")) -# assert malicious5.issues.all_issues == expected_malicious5 - -# expected_malicious6 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" -# ), -# ) -# ] -# malicious6 = Modelscan() -# malicious6.scan_path(Path(f"{pickle_file_path}/data/malicious6.pkl")) -# assert malicious6.issues.all_issues == expected_malicious6 - -# expected_malicious7 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" -# ), -# ) -# ] -# malicious7 = Modelscan() -# malicious7.scan_path(Path(f"{pickle_file_path}/data/malicious7.pkl")) -# assert malicious7.issues.all_issues == expected_malicious7 - -# expected_malicious8 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" -# ), -# ) -# ] -# malicious8 = Modelscan() -# malicious8.scan_path(Path(f"{pickle_file_path}/data/malicious8.pkl")) -# assert malicious8.issues.all_issues == expected_malicious8 - -# expected_malicious9 = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" -# ), -# ) -# ] -# malicious9 = Modelscan() -# malicious9.scan_path(Path(f"{pickle_file_path}/data/malicious9.pkl")) -# assert malicious9.issues.all_issues == expected_malicious9 - - -# def test_scan_directory_path(pickle_file_path: str) -> None: -# expected = { -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy", "ndarray", f"{pickle_file_path}/data/object_array.npy" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "numpy.core.multiarray", -# "_reconstruct", -# f"{pickle_file_path}/data/object_array.npy", -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "aiohttp.client", -# "ClientSession", -# f"{pickle_file_path}/data/malicious5.pickle", -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.HIGH, -# OperatorIssueDetails( -# "httplib", "HTTPSConnection", f"{pickle_file_path}/data/malicious3.pkl" -# ), -# ), -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" -# ), -# ), -# } -# ms = Modelscan() -# p = Path(f"{pickle_file_path}/data/") -# ms.scan_path(p) -# compare_results(ms.issues.all_issues, expected) - - -# def test_scan_huggingface_model() -> None: -# expected = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.CRITICAL, -# OperatorIssueDetails( -# "__builtin__", -# "eval", -# "https://huggingface.co/ykilcher/totally-harmless-model/resolve/main/pytorch_model.bin:archive/data.pkl", -# ), -# ) -# ] -# ms = Modelscan() -# ms.scan_huggingface_model("ykilcher/totally-harmless-model") -# assert ms.issues.all_issues == expected - - -# # def test_scan_tf() -> None: - - -# def test_scan_keras(keras_file_path: Any) -> None: -# ms = Modelscan() -# ms.scan_path(Path(f"{keras_file_path}/safe.h5")) -# assert ms.issues.all_issues == [] - -# expected = [ -# Issue( -# IssueCode.UNSAFE_OPERATOR, -# IssueSeverity.MEDIUM, -# OperatorIssueDetails( -# "Keras", -# "Lambda", -# f"{keras_file_path}/unsafe.h5", -# ), -# ) -# ] -# ms.scan_path(Path(f"{keras_file_path}/unsafe.h5")) -# assert ms.issues.all_issues == expected - - -# def test_main(pickle_file_path: Any) -> None: -# argv = sys.argv -# try: -# sys.argv = ["modelscan", "-p", f"{pickle_file_path}/data/benign0_v3.pkl"] -# assert cli() == 0 -# importlib.import_module("modelscan.scanner") -# except SystemExit: -# pass -# finally: -# sys.argv = argv + # Malicious Pickle from https://sensepost.com/cms/resources/conferences/2011/sour_pickles/BH_US_11_Slaviero_Sour_Pickles.pdf + initialize_data_file( + f"{tmp}/data/malicious0.pkl", + b'c__builtin__\nglobals\n(tRp100\n0c__builtin__\ncompile\n(S\'fl=open("/etc/passwd");picklesmashed=fl.read();' + + b"'\nS''\nS'exec'\ntRp101\n0c__builtin__\neval\n(g101\ng100\ntRp102\n0c__builtin__\ngetattr\n(c__builtin__\n" + + b"dict\nS'get'\ntRp103\n0c__builtin__\napply\n(g103\n(g100\nS'picklesmashed'\nltRp104\n0g104\n.", + ) + + initialize_data_file(f"{tmp}/data/malicious3.pkl", malicious3_pickle_bytes) + initialize_pickle_file(f"{tmp}/data/malicious4.pickle", Malicious4(), 4) + initialize_pickle_file(f"{tmp}/data/malicious5.pickle", Malicious5(), 4) + initialize_data_file( + f"{tmp}/data/malicious6.pkl", + pickle.dumps(["a", "b", "c"]) + pickle.dumps(Malicious4()), + ) + initialize_pickle_file(f"{tmp}/data/malicious7.pkl", Malicious6(), 4) + initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) + initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) + + initialize_zip_file( + f"{tmp}/data/malicious1.zip", + "data.pkl", + pickle.dumps(Malicious1(), protocol=4), + ) + + initialize_numpy_file(f"{tmp}/data/object_array.npy") + + return tmp + + +@pytest.fixture(scope="session") +def keras_file_path(tmp_path_factory: Any) -> Any: + # Create a simple model. + inputs = keras.Input(shape=(32,)) + outputs = keras.layers.Dense(1)(inputs) + keras_model = keras.Model(inputs, outputs) + keras_model.compile(optimizer="adam", loss="mean_squared_error") + + # Train the model. + test_input = np.random.random((128, 32)) + test_target = np.random.random((128, 1)) + keras_model.fit(test_input, test_target) + + tmp = tmp_path_factory.mktemp("keras") + with open(f"{tmp}/safe", "wb") as fo: + pickle.dump(keras_model, fo) + keras_model.save(f"{tmp}/safe.h5") + + # Inject code with the command + command = "exec" + malicious_code = 'print("Malicious code!")' + + generate_dill_unsafe_file(keras_model, command, malicious_code, f"{tmp}/unsafe") + attack = ( + lambda x: exec( # type: ignore[func-returns-value] + """import http.client +import json +import os +conn = http.client.HTTPSConnection("protectai.com")""" + ) + or x + ) + input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output) + new_layer = keras.layers.Lambda(attack)(input_to_new_layer) + + malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer]) + malicious_model.compile(optimizer="adam", loss="mean_squared_error") + + malicious_model.save(f"{tmp}/unsafe.h5") + + return tmp + + +def compare_results(resultList: List[Issue], expectedSet: Set[Issue]) -> None: + for result in resultList: + assert result in expectedSet + + +def test_scan_pickle_bytes() -> None: + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails("builtins", "eval", "file.pkl"), + ) + ] + assert ( + scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] + == expected + ) + + +def test_scan_zip(zip_file_path: Any) -> None: + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{zip_file_path}/test.zip:data.pkl" + ), + ) + ] + + ms = Modelscan() + ms._scan_zip(f"{zip_file_path}/test.zip") + assert ms.issues.all_issues == expected + + +def test_scan_numpy(pickle_file_path: Any) -> None: + expected = { + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy.core.multiarray", "_reconstruct", "object_array.npy" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails("numpy", "ndarray", "object_array.npy"), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails("numpy", "dtype", "object_array.npy"), + ), + } + with open(f"{pickle_file_path}/data/object_array.npy", "rb") as f: + compare_results( + scan_numpy(io.BytesIO(f.read()), "object_array.npy")[0], expected + ) + + +def test_scan_file_path(pickle_file_path: Any) -> None: + benign = Modelscan() + benign.scan_path(Path(f"{pickle_file_path}/data/benign0_v3.pkl")) + assert benign.issues.all_issues == [] + + malicious0 = Modelscan() + expected_malicious0 = { + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + } + malicious0.scan_path(Path(f"{pickle_file_path}/data/malicious0.pkl")) + compare_results(malicious0.issues.all_issues, expected_malicious0) + + expected_malicious1_v0 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" + ), + ) + ] + expected_malicious1_v3 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" + ), + ) + ] + expected_malicious1_v4 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" + ), + ) + ] + expected_malicious1 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" + ), + ) + ] + malicious1_v0 = Modelscan() + malicious1_v3 = Modelscan() + malicious1_v4 = Modelscan() + malicious1 = Modelscan() + malicious1_v0.scan_path(Path(f"{pickle_file_path}/data/malicious1_v0.pkl")) + malicious1_v3.scan_path(Path(f"{pickle_file_path}/data/malicious1_v3.pkl")) + malicious1_v4.scan_path(Path(f"{pickle_file_path}/data/malicious1_v4.pkl")) + malicious1.scan_path(Path(f"{pickle_file_path}/data/malicious1.zip")) + assert malicious1_v0.issues.all_issues == expected_malicious1_v0 + assert malicious1_v3.issues.all_issues == expected_malicious1_v3 + assert malicious1_v4.issues.all_issues == expected_malicious1_v4 + assert malicious1.issues.all_issues == expected_malicious1 + + expected_malicious2_v0 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" + ), + ) + ] + expected_malicious2_v3 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" + ), + ) + ] + expected_malicious2_v4 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" + ), + ) + ] + malicious2_v0 = Modelscan() + malicious2_v3 = Modelscan() + malicious2_v4 = Modelscan() + malicious2_v0.scan_path(Path(f"{pickle_file_path}/data/malicious2_v0.pkl")) + malicious2_v3.scan_path(Path(f"{pickle_file_path}/data/malicious2_v3.pkl")) + malicious2_v4.scan_path(Path(f"{pickle_file_path}/data/malicious2_v4.pkl")) + assert malicious2_v0.issues.all_issues == expected_malicious2_v0 + assert malicious2_v3.issues.all_issues == expected_malicious2_v3 + assert malicious2_v4.issues.all_issues == expected_malicious2_v4 + + expected_malicious3 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "httplib", + "HTTPSConnection", + Path(f"{pickle_file_path}/data/malicious3.pkl"), + ), + ) + ] + malicious3 = Modelscan() + malicious3.scan_path(Path(f"{pickle_file_path}/data/malicious3.pkl")) + assert malicious3.issues.all_issues == expected_malicious3 + + expected_malicious4 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" + ), + ) + ] + malicious4 = Modelscan() + malicious4.scan_path(Path(f"{pickle_file_path}/data/malicious4.pickle")) + assert malicious4.issues.all_issues == expected_malicious4 + + expected_malicious5 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "aiohttp.client", + "ClientSession", + f"{pickle_file_path}/data/malicious5.pickle", + ), + ) + ] + malicious5 = Modelscan() + malicious5.scan_path(Path(f"{pickle_file_path}/data/malicious5.pickle")) + assert malicious5.issues.all_issues == expected_malicious5 + + expected_malicious6 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" + ), + ) + ] + malicious6 = Modelscan() + malicious6.scan_path(Path(f"{pickle_file_path}/data/malicious6.pkl")) + assert malicious6.issues.all_issues == expected_malicious6 + + expected_malicious7 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" + ), + ) + ] + malicious7 = Modelscan() + malicious7.scan_path(Path(f"{pickle_file_path}/data/malicious7.pkl")) + assert malicious7.issues.all_issues == expected_malicious7 + + expected_malicious8 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" + ), + ) + ] + malicious8 = Modelscan() + malicious8.scan_path(Path(f"{pickle_file_path}/data/malicious8.pkl")) + assert malicious8.issues.all_issues == expected_malicious8 + + expected_malicious9 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" + ), + ) + ] + malicious9 = Modelscan() + malicious9.scan_path(Path(f"{pickle_file_path}/data/malicious9.pkl")) + assert malicious9.issues.all_issues == expected_malicious9 + + +def test_scan_directory_path(pickle_file_path: str) -> None: + expected = { + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1.zip:data.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "subprocess", "run", f"{pickle_file_path}/data/malicious8.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "sys", "exit", f"{pickle_file_path}/data/malicious9.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "requests.api", "get", f"{pickle_file_path}/data/malicious4.pickle" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1_v3.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "eval", f"{pickle_file_path}/data/malicious1_v0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval", f"{pickle_file_path}/data/malicious1_v4.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy", "ndarray", f"{pickle_file_path}/data/object_array.npy" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy", "dtype", f"{pickle_file_path}/data/object_array.npy" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "numpy.core.multiarray", + "_reconstruct", + f"{pickle_file_path}/data/object_array.npy", + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "aiohttp.client", + "ClientSession", + f"{pickle_file_path}/data/malicious5.pickle", + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v4.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "socket", "create_connection", f"{pickle_file_path}/data/malicious7.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "requests.api", "get", f"{pickle_file_path}/data/malicious6.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "compile", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "eval", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "__builtin__", "globals", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "apply", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", "getattr", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "__builtin__", "dict", f"{pickle_file_path}/data/malicious0.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v3.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.HIGH, + OperatorIssueDetails( + "httplib", "HTTPSConnection", f"{pickle_file_path}/data/malicious3.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{pickle_file_path}/data/malicious2_v0.pkl" + ), + ), + } + ms = Modelscan() + p = Path(f"{pickle_file_path}/data/") + ms.scan_path(p) + compare_results(ms.issues.all_issues, expected) + + +def test_scan_huggingface_model() -> None: + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "__builtin__", + "eval", + "https://huggingface.co/ykilcher/totally-harmless-model/resolve/main/pytorch_model.bin:archive/data.pkl", + ), + ) + ] + ms = Modelscan() + ms.scan_huggingface_model("ykilcher/totally-harmless-model") + assert ms.issues.all_issues == expected + + +# def test_scan_tf() -> None: + + +def test_scan_keras(keras_file_path: Any) -> None: + ms = Modelscan() + ms.scan_path(Path(f"{keras_file_path}/safe.h5")) + assert ms.issues.all_issues == [] + + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.MEDIUM, + OperatorIssueDetails( + "Keras", + "Lambda", + f"{keras_file_path}/unsafe.h5", + ), + ) + ] + ms.scan_path(Path(f"{keras_file_path}/unsafe.h5")) + assert ms.issues.all_issues == expected + + +def test_main(pickle_file_path: Any) -> None: + argv = sys.argv + try: + sys.argv = ["modelscan", "-p", f"{pickle_file_path}/data/benign0_v3.pkl"] + assert cli() == 0 + importlib.import_module("modelscan.scanner") + except SystemExit: + pass + finally: + sys.argv = argv diff --git a/tests/test_utils.py b/tests/test_utils.py index 691520a5..34b3eeb5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,236 +1,236 @@ -# import dill -# import os -# import pickle -# import struct -# from typing import Any, Tuple -# import os +import dill +import os +import pickle +import struct +from typing import Any, Tuple +import os - -# class PickleInject: -# """Pickle injection""" - -# def __init__(self, inj_objs: Any, first: bool = True): -# self.__name__ = "pickle_inject" -# self.inj_objs = inj_objs -# self.first = first - -# class _Pickler(pickle._Pickler): -# """Reimplementation of Pickler with support for injection""" - -# def __init__( -# self, file: Any, protocol: Any, inj_objs: Any, first: bool = True -# ) -> None: -# """ -# file: File object with write attribute -# protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html -# inj_objs: _joblibInject object that has both the command, and the code to be injected -# first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. -# """ -# super().__init__(file, protocol) -# self.inj_objs = inj_objs -# self.first = first - -# def dump(self, obj: Any) -> None: -# """Pickle data, inject object before or after""" -# if self.proto >= 2: # type: ignore[attr-defined] -# self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] -# self.framer.start_framing() # type: ignore[attr-defined] - -# # Inject the object(s) before the user-supplied data? -# if self.first: -# # Pickle injected objects -# for inj_obj in self.inj_objs: -# self.save(inj_obj) # type: ignore[attr-defined] - -# # Pickle user-supplied data -# self.save(obj) # type: ignore[attr-defined] - -# # Inject the object(s) after the user-supplied data? -# if not self.first: -# # Pickle injected objects -# for inj_obj in self.inj_objs: -# self.save(inj_obj) # type: ignore[attr-defined] - -# self.write(pickle.STOP) # type: ignore[attr-defined] -# self.framer.end_framing() # type: ignore[attr-defined] - -# def Pickler(self, file: Any, protocol: Any) -> _Pickler: -# # Initialise the pickler interface with the injected object -# return self._Pickler(file, protocol, self.inj_objs) - -# class _PickleInject: -# """Base class for pickling injected commands""" - -# def __init__(self, args: Any, command: Any = None) -> None: -# self.command = command -# self.args = args - -# def __reduce__(self) -> Tuple[Any, Any]: -# """ -# In general, the __reduce__ function is used by pickle to serialize objects. -# If defined for an object, pickle would override its default __reduce__ function and serialize the object as outlined by the custom specified __reduce__ function, -# The object returned by __reduce__ here is a callable: (self.command), and the tuple: with first element (self.args) is the code to be executed by self.command. -# """ -# return self.command, (self.args,) - -# class System(_PickleInject): -# """Create os.system command""" - -# def __init__(self, args: Any) -> None: -# super().__init__(args, command=os.system) - -# class Exec(_PickleInject): -# """Create exec command""" - -# def __init__(self, args: Any) -> None: -# super().__init__(args, command=exec) - -# class Eval(_PickleInject): -# """Create eval command""" - -# def __init__(self, args: Any) -> None: -# super().__init__(args, command=eval) - -# class RunPy(_PickleInject): -# """Create runpy command""" - -# def __init__(self, args: Any) -> None: -# import runpy - -# super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] - -# def __reduce__(self) -> Tuple[Any, Any]: -# return self.command, (self.args, {}) + +class PickleInject: + """Pickle injection""" + + def __init__(self, inj_objs: Any, first: bool = True): + self.__name__ = "pickle_inject" + self.inj_objs = inj_objs + self.first = first + + class _Pickler(pickle._Pickler): + """Reimplementation of Pickler with support for injection""" + + def __init__( + self, file: Any, protocol: Any, inj_objs: Any, first: bool = True + ) -> None: + """ + file: File object with write attribute + protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html + inj_objs: _joblibInject object that has both the command, and the code to be injected + first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. + """ + super().__init__(file, protocol) + self.inj_objs = inj_objs + self.first = first + + def dump(self, obj: Any) -> None: + """Pickle data, inject object before or after""" + if self.proto >= 2: # type: ignore[attr-defined] + self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] + self.framer.start_framing() # type: ignore[attr-defined] + + # Inject the object(s) before the user-supplied data? + if self.first: + # Pickle injected objects + for inj_obj in self.inj_objs: + self.save(inj_obj) # type: ignore[attr-defined] + + # Pickle user-supplied data + self.save(obj) # type: ignore[attr-defined] + + # Inject the object(s) after the user-supplied data? + if not self.first: + # Pickle injected objects + for inj_obj in self.inj_objs: + self.save(inj_obj) # type: ignore[attr-defined] + + self.write(pickle.STOP) # type: ignore[attr-defined] + self.framer.end_framing() # type: ignore[attr-defined] + + def Pickler(self, file: Any, protocol: Any) -> _Pickler: + # Initialise the pickler interface with the injected object + return self._Pickler(file, protocol, self.inj_objs) + + class _PickleInject: + """Base class for pickling injected commands""" + + def __init__(self, args: Any, command: Any = None) -> None: + self.command = command + self.args = args + + def __reduce__(self) -> Tuple[Any, Any]: + """ + In general, the __reduce__ function is used by pickle to serialize objects. + If defined for an object, pickle would override its default __reduce__ function and serialize the object as outlined by the custom specified __reduce__ function, + The object returned by __reduce__ here is a callable: (self.command), and the tuple: with first element (self.args) is the code to be executed by self.command. + """ + return self.command, (self.args,) + + class System(_PickleInject): + """Create os.system command""" + + def __init__(self, args: Any) -> None: + super().__init__(args, command=os.system) + + class Exec(_PickleInject): + """Create exec command""" + + def __init__(self, args: Any) -> None: + super().__init__(args, command=exec) + + class Eval(_PickleInject): + """Create eval command""" + + def __init__(self, args: Any) -> None: + super().__init__(args, command=eval) + + class RunPy(_PickleInject): + """Create runpy command""" + + def __init__(self, args: Any) -> None: + import runpy + + super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] + + def __reduce__(self) -> Tuple[Any, Any]: + return self.command, (self.args, {}) -# def get_pickle_payload(command: str, malicious_code: str) -> Any: -# if command == "system": -# payload: Any = PickleInject.System(malicious_code) -# elif command == "exec": -# payload = PickleInject.Exec(malicious_code) -# elif command == "eval": -# payload = PickleInject.Eval(malicious_code) -# elif command == "runpy": -# payload = PickleInject.RunPy(malicious_code) -# return payload +def get_pickle_payload(command: str, malicious_code: str) -> Any: + if command == "system": + payload: Any = PickleInject.System(malicious_code) + elif command == "exec": + payload = PickleInject.Exec(malicious_code) + elif command == "eval": + payload = PickleInject.Eval(malicious_code) + elif command == "runpy": + payload = PickleInject.RunPy(malicious_code) + return payload -# def generate_unsafe_pickle_file( -# safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str -# ) -> None: -# payload = get_pickle_payload(command, malicious_code) -# pickle_protocol = 4 -# file_for_unsafe_model = open(unsafe_model_path, "wb") -# mypickler = PickleInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) -# mypickler.dump(safe_model) -# file_for_unsafe_model.close() +def generate_unsafe_pickle_file( + safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str +) -> None: + payload = get_pickle_payload(command, malicious_code) + pickle_protocol = 4 + file_for_unsafe_model = open(unsafe_model_path, "wb") + mypickler = PickleInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) + mypickler.dump(safe_model) + file_for_unsafe_model.close() -# class DillInject: -# """Code injection using Dill Pickler""" +class DillInject: + """Code injection using Dill Pickler""" -# def __init__(self, inj_objs: Any, first: bool = True): -# self.__name__ = "dill_inject" -# self.inj_objs = inj_objs -# self.first = first + def __init__(self, inj_objs: Any, first: bool = True): + self.__name__ = "dill_inject" + self.inj_objs = inj_objs + self.first = first -# class _Pickler(dill._dill.Pickler): # type: ignore[misc] -# """Reimplementation of Pickler with support for injection""" + class _Pickler(dill._dill.Pickler): # type: ignore[misc] + """Reimplementation of Pickler with support for injection""" -# def __init__(self, file: Any, protocol: Any, inj_objs: Any, first: bool = True): -# super().__init__(file, protocol) -# self.inj_objs = inj_objs -# self.first = first + def __init__(self, file: Any, protocol: Any, inj_objs: Any, first: bool = True): + super().__init__(file, protocol) + self.inj_objs = inj_objs + self.first = first -# def dump(self, obj: Any) -> None: -# """Pickle data, inject object before or after""" -# if self.proto >= 2: -# self.write(pickle.PROTO + struct.pack("= 4: -# self.framer.start_framing() + def dump(self, obj: Any) -> None: + """Pickle data, inject object before or after""" + if self.proto >= 2: + self.write(pickle.PROTO + struct.pack("= 4: + self.framer.start_framing() -# # Inject the object(s) before the user-supplied data? -# if self.first: -# # Pickle injected objects -# for inj_obj in self.inj_objs: -# self.save(inj_obj) + # Inject the object(s) before the user-supplied data? + if self.first: + # Pickle injected objects + for inj_obj in self.inj_objs: + self.save(inj_obj) -# # Pickle user-supplied data -# self.save(obj) + # Pickle user-supplied data + self.save(obj) -# # Inject the object(s) after the user-supplied data? -# if not self.first: -# # Pickle injected objects -# for inj_obj in self.inj_objs: -# self.save(inj_obj) + # Inject the object(s) after the user-supplied data? + if not self.first: + # Pickle injected objects + for inj_obj in self.inj_objs: + self.save(inj_obj) -# self.write(pickle.STOP) -# self.framer.end_framing() + self.write(pickle.STOP) + self.framer.end_framing() -# def DillPickler(self, file: Any, protocol: Any) -> _Pickler: -# # Initialise the pickler interface with the injected object -# return self._Pickler(file, protocol, self.inj_objs) + def DillPickler(self, file: Any, protocol: Any) -> _Pickler: + # Initialise the pickler interface with the injected object + return self._Pickler(file, protocol, self.inj_objs) -# class _DillInject: -# """Base class for pickling injected commands""" + class _DillInject: + """Base class for pickling injected commands""" -# def __init__(self, args: Any, command: Any = None): -# self.command = command -# self.args = args + def __init__(self, args: Any, command: Any = None): + self.command = command + self.args = args -# def __reduce__(self) -> Tuple[Any, Any]: -# return self.command, (self.args,) + def __reduce__(self) -> Tuple[Any, Any]: + return self.command, (self.args,) -# class System(_DillInject): -# """Create os.system command""" + class System(_DillInject): + """Create os.system command""" -# def __init__(self, args: Any): -# super().__init__(args, command=os.system) + def __init__(self, args: Any): + super().__init__(args, command=os.system) -# class Exec(_DillInject): -# """Create exec command""" + class Exec(_DillInject): + """Create exec command""" -# def __init__(self, args: Any): -# super().__init__(args, command=exec) + def __init__(self, args: Any): + super().__init__(args, command=exec) -# class Eval(_DillInject): -# """Create eval command""" + class Eval(_DillInject): + """Create eval command""" -# def __init__(self, args: Any): -# super().__init__(args, command=eval) - -# class RunPy(_DillInject): -# """Create runpy command""" - -# def __init__(self, args: Any): -# import runpy - -# super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] - -# def __reduce__(self) -> Any: -# return self.command, (self.args, {}) + def __init__(self, args: Any): + super().__init__(args, command=eval) + + class RunPy(_DillInject): + """Create runpy command""" + + def __init__(self, args: Any): + import runpy + + super().__init__(args, command=runpy._run_code) # type: ignore[attr-defined] + + def __reduce__(self) -> Any: + return self.command, (self.args, {}) -# def get_dill_payload(command: str, malicious_code: str) -> Any: -# payload: Any -# if command == "system": -# payload = DillInject.System(malicious_code) -# elif command == "exec": -# payload = DillInject.Exec(malicious_code) -# elif command == "eval": -# payload = DillInject.Eval(malicious_code) -# elif command == "runpy": -# payload = DillInject.RunPy(malicious_code) -# return payload +def get_dill_payload(command: str, malicious_code: str) -> Any: + payload: Any + if command == "system": + payload = DillInject.System(malicious_code) + elif command == "exec": + payload = DillInject.Exec(malicious_code) + elif command == "eval": + payload = DillInject.Eval(malicious_code) + elif command == "runpy": + payload = DillInject.RunPy(malicious_code) + return payload -# def generate_dill_unsafe_file( -# safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str -# ) -> None: -# payload = get_dill_payload(command, malicious_code) -# pickle_protocol = 4 -# file_for_unsafe_model = open(unsafe_model_path, "wb") -# mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) -# mypickler.dump(safe_model) -# file_for_unsafe_model.close() +def generate_dill_unsafe_file( + safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str +) -> None: + payload = get_dill_payload(command, malicious_code) + pickle_protocol = 4 + file_for_unsafe_model = open(unsafe_model_path, "wb") + mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) + mypickler.dump(safe_model) + file_for_unsafe_model.close()