diff --git a/CHANGELOG.md b/CHANGELOG.md index 1432da4..02b42f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ------------------------------------------------------------------- ## [Unreleased] +## Added +- Added a function to the mysql executioner to use mariadb commands if mariadb is detected. This is because the mysql commands are deprecated on newer systems. ## [2.5.0] 2024-12-27 ## Fixed diff --git a/pynonymizer/database/mysql/execution.py b/pynonymizer/database/mysql/execution.py index 96fec7a..e000174 100644 --- a/pynonymizer/database/mysql/execution.py +++ b/pynonymizer/database/mysql/execution.py @@ -2,7 +2,9 @@ import shutil import shlex import subprocess + from pynonymizer.database.exceptions import DependencyError +import pynonymizer logger = logging.getLogger(__name__) @@ -21,6 +23,15 @@ def _optional_arg_pair(arg_value_pair): return _optional_arg(arg_value_pair[1], arg_value_pair) +def _set_mysql_cmd() -> None: + if shutil.which("mariadb"): + logger.info( + "Found mariadb client, using mariadb and mariadb-dump instead of deprecated mysql commands." + ) + pynonymizer.database.mysql.execution.RESTORE_CMD = "mariadb" + pynonymizer.database.mysql.execution.DUMP_CMD = "mariadb-dump" + + class MySqlDumpRunner: def __init__( self, @@ -42,6 +53,8 @@ def __init__( if db_name is None: raise ValueError("db_name cannot be null") + _set_mysql_cmd() + if not (shutil.which(DUMP_CMD)): raise DependencyError( DUMP_CMD, f"The '{DUMP_CMD}' client must be present in the $PATH" @@ -97,6 +110,8 @@ def __init__( if db_name is None: raise ValueError("db_name cannot be null") + _set_mysql_cmd() + if not (shutil.which(RESTORE_CMD)): raise DependencyError( RESTORE_CMD, f"The '{RESTORE_CMD}' client must be present in the $PATH" diff --git a/tests/database/__init__.py b/tests/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/database/mysql/__init__.py b/tests/database/mysql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/database/mysql/test_execution.py b/tests/database/mysql/test_execution.py new file mode 100644 index 0000000..721a56d --- /dev/null +++ b/tests/database/mysql/test_execution.py @@ -0,0 +1,29 @@ +import os +from pathlib import Path + +import pytest +from unittest import mock +from pynonymizer.database.mysql import execution + + +@pytest.mark.parametrize( + argnames=("mocked_command", "expected_restore_cmd", "expected_dump_cmd"), + argvalues=[("mysql", "mysql", "mysqldump"), ("mariadb", "mariadb", "mariadb-dump")], + ids=["mysql", "mariadb"], +) +def test_set_mysql_cmd( + mocked_command: str, expected_restore_cmd: str, expected_dump_cmd: str, tmpdir +) -> None: + # create a fake binary for shutil to find + fake_binary_path: Path = tmpdir / mocked_command + fake_binary_path.write_text("#!/bin/bash\nexit 0", encoding="utf-8") + fake_binary_path.chmod(0o755) + + with mock.patch.dict(os.environ, {"PATH": str(tmpdir)}): + # Make sure the global variables are set to the default values + assert execution.RESTORE_CMD == "mysql" + assert execution.DUMP_CMD == "mysqldump" + execution._set_mysql_cmd() + # Check if the global variables are set to the expected values + assert execution.RESTORE_CMD == expected_restore_cmd + assert execution.DUMP_CMD == expected_dump_cmd