diff --git a/src/pystatis/config.py b/src/pystatis/config.py index caa72858..d3689a33 100644 --- a/src/pystatis/config.py +++ b/src/pystatis/config.py @@ -151,9 +151,26 @@ def config_exists() -> bool: return config_file.exists() -def setup_credentials() -> None: - """Setup credentials for all supported databases.""" - for db_name in get_supported_db(): +def setup_credentials(db_names: list[str] | None = None) -> None: + """ + Setup credentials for all supported databases. + + Args: + db_names (list[str]): Names of the databases to setup + """ + if db_names is None: + db_names = get_supported_db() + + for db_name in db_names: + # despite this check, we should consider using literals as the type hint for + # db_names. + if db_name not in get_supported_db(): + raise KeyError( + f"Provided db_name '{db_name}' no regnized. " + f"Valid options are {get_supported_db()}" + ) + + for db_name in db_names: config.set(db_name, "username", _get_user_input(db_name, "username")) config.set(db_name, "password", _get_user_input(db_name, "password")) if not db.check_credentials_are_valid(db_name): diff --git a/tests/test_config.py b/tests/test_config.py index f85949c3..3b7e2f68 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -83,6 +83,25 @@ def test_setup_credentials(mocker, config_): assert config_[db_name]["password"] == "test123!" +def test_setup_credentials_with_db_names_subset(mocker, config_): + mocker.patch.object(db, "check_credentials_are_valid", return_value=True) + + os.environ["PYSTATIS_GENESIS_API_USERNAME"] = "test" + os.environ["PYSTATIS_GENESIS_API_PASSWORD"] = "test123!" + + config.setup_credentials(db_names=["genesis"]) + + # Only the specified DB should be set + assert config_["genesis"]["username"] == "test" + assert config_["genesis"]["password"] == "test123!" + + +def test_setup_credentials_invalid_credentials_raises(mocker, config_): + mocker.patch.object(db, "check_credentials_are_valid", return_value=False) + + with pytest.raises(KeyError) as _exc_info: + config.setup_credentials(db_names=["regio_s"]) + @pytest.mark.parametrize( "mock_return, check_result", [