diff --git a/google/cloud/dataproc_spark_connect/environment.py b/google/cloud/dataproc_spark_connect/environment.py index 1f00a03..c74e81e 100644 --- a/google/cloud/dataproc_spark_connect/environment.py +++ b/google/cloud/dataproc_spark_connect/environment.py @@ -18,7 +18,7 @@ def is_vscode() -> bool: - """True if running inside VS Code at all.""" + """True if running inside VS Code at all.""" return os.getenv("VSCODE_PID") is not None @@ -38,16 +38,76 @@ def is_colab() -> bool: def is_workbench() -> bool: - """True if running in AI Workbench (managed Jupyter).""" + """True if running in Vertex Workbench Instance (managed Jupyter).""" return os.getenv("VERTEX_PRODUCT") == "WORKBENCH_INSTANCE" +def is_kaggle() -> bool: + """True if running in Kaggle Notebooks.""" + return os.getenv("KAGGLE_KERNEL_RUN_TYPE") is not None + + +def is_databricks() -> bool: + """True if running in Databricks.""" + return os.getenv("DATABRICKS_RUNTIME_VERSION") is not None + + +def is_sagemaker() -> bool: + """True if running in AWS SageMaker.""" + return os.getenv("SAGEMAKER_INTERNAL_IMAGE_URI") is not None + + +def is_deepnote() -> bool: + """True if running in Deepnote.""" + return os.getenv("DEEPNOTE_PROJECT_ID") is not None + + +def is_datalore() -> bool: + """True if running in JetBrains Datalore.""" + return os.getenv("DATALORE_USER") is not None + + +def is_spyder() -> bool: + """True if running inside Spyder IDE.""" + return any(k.startswith("SPYDER") for k in os.environ) + + +def is_cloud_shell() -> bool: + """True if running in Google Cloud Shell.""" + return os.getenv("CLOUD_SHELL") is not None + + +def is_codespaces() -> bool: + """True if running in GitHub Codespaces.""" + return os.getenv("CODESPACES") is not None + + def is_jetbrains_ide() -> bool: - """True if running inside any JetBrains IDE.""" - return "jetbrains" in os.getenv("TERMINAL_EMULATOR", "").lower() + """True if running inside JetBrains IDE.""" + return ( + "jetbrains" in os.getenv("TERMINAL_EMULATOR", "").lower() + or "PYCHARM_HOSTED" in os.environ + ) + + +def is_hex() -> bool: + """True if running in Hex.""" + return os.getenv("HEX_PROJECT_ID") is not None + + +def is_polynote() -> bool: + """True if running in Polynote.""" + return os.getenv("POLYNOTE_VERSION") is not None + + +def is_eclipse() -> bool: + """True if running inside Eclipse IDE.""" + return "ECLIPSE_HOME" in os.environ or any( + k.startswith("ECLIPSE") for k in os.environ + ) -def is_interactive(): +def is_interactive() -> bool: try: from IPython import get_ipython @@ -56,14 +116,14 @@ def is_interactive(): except ImportError: pass - return hasattr(sys, "ps1") or sys.flags.interactive + return hasattr(sys, "ps1") or bool(sys.flags.interactive) -def is_terminal(): +def is_terminal() -> bool: return sys.stdin.isatty() -def is_interactive_terminal(): +def is_interactive_terminal() -> bool: return is_interactive() and is_terminal() @@ -78,18 +138,40 @@ def get_client_environment_label() -> str: Priority order: 1. Colab Enterprise ("colab-enterprise") 2. Colab ("colab") - 3. Workbench ("workbench-jupyter") - 4. VS Code ("vscode") - 5. JetBrains IDE ("jetbrains") - 6. Jupyter ("jupyter") - 7. Unknown ("unknown") + 3. Vertex Workbench Instance ("workbench-jupyter") + 4. Kaggle ("kaggle") + 5. AWS SageMaker ("sagemaker") + 6. Databricks ("databricks") + 7. Deepnote ("deepnote") + 8. JetBrains Datalore ("datalore") + 9. GitHub Codespaces ("codespaces") + 10. Google Cloud Shell ("cloud-shell") + 11. Hex ("hex") + 12. Polynote ("polynote") + 13. VS Code ("vscode") + 14. JetBrains IDE ("jetbrains") + 15. Spyder ("spyder") + 16. Eclipse ("eclipse") + 17. Jupyter ("jupyter") + 18. Unknown ("unknown") """ checks: List[Tuple[Callable[[], bool], str]] = [ (is_colab_enterprise, "colab-enterprise"), (is_colab, "colab"), (is_workbench, "workbench-jupyter"), + (is_kaggle, "kaggle"), + (is_sagemaker, "sagemaker"), + (is_databricks, "databricks"), + (is_deepnote, "deepnote"), + (is_datalore, "datalore"), + (is_codespaces, "codespaces"), + (is_cloud_shell, "cloud-shell"), + (is_hex, "hex"), + (is_polynote, "polynote"), (is_vscode, "vscode"), (is_jetbrains_ide, "jetbrains"), + (is_spyder, "spyder"), + (is_eclipse, "eclipse"), (is_jupyter, "jupyter"), ] for detector, label in checks: diff --git a/tests/unit/test_environment.py b/tests/unit/test_environment.py index 836fbd3..d3bd5f0 100644 --- a/tests/unit/test_environment.py +++ b/tests/unit/test_environment.py @@ -70,16 +70,115 @@ def test_is_workbench_false(self): os.environ["VERTEX_PRODUCT"] = "OTHER" self.assertFalse(environment.is_workbench()) + def test_is_kaggle_true(self): + os.environ["KAGGLE_KERNEL_RUN_TYPE"] = "Interactive" + self.assertTrue(environment.is_kaggle()) + + def test_is_kaggle_false(self): + os.environ.pop("KAGGLE_KERNEL_RUN_TYPE", None) + self.assertFalse(environment.is_kaggle()) + + def test_is_databricks_true(self): + os.environ["DATABRICKS_RUNTIME_VERSION"] = "10.4.x-scala2.12" + self.assertTrue(environment.is_databricks()) + + def test_is_databricks_false(self): + os.environ.pop("DATABRICKS_RUNTIME_VERSION", None) + self.assertFalse(environment.is_databricks()) + + def test_is_sagemaker_true(self): + os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"] = "image" + self.assertTrue(environment.is_sagemaker()) + + def test_is_sagemaker_false(self): + os.environ.pop("SAGEMAKER_INTERNAL_IMAGE_URI", None) + self.assertFalse(environment.is_sagemaker()) + + def test_is_deepnote_true(self): + os.environ["DEEPNOTE_PROJECT_ID"] = "project-123" + self.assertTrue(environment.is_deepnote()) + + def test_is_deepnote_false(self): + os.environ.pop("DEEPNOTE_PROJECT_ID", None) + self.assertFalse(environment.is_deepnote()) + + def test_is_datalore_true(self): + os.environ["DATALORE_USER"] = "user-123" + self.assertTrue(environment.is_datalore()) + + def test_is_datalore_false(self): + os.environ.pop("DATALORE_USER", None) + self.assertFalse(environment.is_datalore()) + + def test_is_spyder_true(self): + os.environ["SPYDER_ARGS"] = "[]" + self.assertTrue(environment.is_spyder()) + + def test_is_spyder_false(self): + for k in list(os.environ.keys()): + if k.startswith("SPYDER"): + os.environ.pop(k) + self.assertFalse(environment.is_spyder()) + + def test_is_cloud_shell_true(self): + os.environ["CLOUD_SHELL"] = "true" + self.assertTrue(environment.is_cloud_shell()) + + def test_is_cloud_shell_false(self): + os.environ.pop("CLOUD_SHELL", None) + self.assertFalse(environment.is_cloud_shell()) + + def test_is_codespaces_true(self): + os.environ["CODESPACES"] = "true" + self.assertTrue(environment.is_codespaces()) + + def test_is_codespaces_false(self): + os.environ.pop("CODESPACES", None) + self.assertFalse(environment.is_codespaces()) + + def test_is_hex_true(self): + os.environ["HEX_PROJECT_ID"] = "hex-123" + self.assertTrue(environment.is_hex()) + + def test_is_hex_false(self): + os.environ.pop("HEX_PROJECT_ID", None) + self.assertFalse(environment.is_hex()) + + def test_is_polynote_true(self): + os.environ["POLYNOTE_VERSION"] = "1.0" + self.assertTrue(environment.is_polynote()) + + def test_is_polynote_false(self): + os.environ.pop("POLYNOTE_VERSION", None) + self.assertFalse(environment.is_polynote()) + + def test_is_eclipse_true(self): + os.environ["ECLIPSE_HOME"] = "/path/to/eclipse" + self.assertTrue(environment.is_eclipse()) + + def test_is_eclipse_false(self): + for k in list(os.environ.keys()): + if k.startswith("ECLIPSE"): + os.environ.pop(k) + self.assertFalse(environment.is_eclipse()) + def test_is_jetbrains_ide_true(self): os.environ["TERMINAL_EMULATOR"] = "JetBrains term" self.assertTrue(environment.is_jetbrains_ide()) + def test_is_jetbrains_ide_true_pycharm(self): + os.environ.pop("TERMINAL_EMULATOR", None) + os.environ["PYCHARM_HOSTED"] = "1" + self.assertTrue(environment.is_jetbrains_ide()) + def test_is_jetbrains_ide_false_env_var_not_set(self): os.environ.pop("TERMINAL_EMULATOR", None) + os.environ.pop("PYCHARM_HOSTED", None) self.assertFalse(environment.is_jetbrains_ide()) def test_is_jetbrains_ide_false_env_var_not_jetbrains(self): os.environ["TERMINAL_EMULATOR"] = "real term" + os.environ.pop("PYCHARM_HOSTED", None) self.assertFalse(environment.is_jetbrains_ide()) # ---- get_client_environment_label tests ---- @@ -96,6 +195,42 @@ def test_is_jetbrains_ide_false_env_var_not_jetbrains(self): "google.cloud.dataproc_spark_connect.environment.is_workbench", return_value=False, ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_kaggle", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_sagemaker", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_databricks", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_deepnote", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_datalore", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_codespaces", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_cloud_shell", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_hex", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_polynote", + return_value=False, + ) @mock.patch( "google.cloud.dataproc_spark_connect.environment.is_vscode", return_value=False, @@ -104,6 +239,14 @@ def test_is_jetbrains_ide_false_env_var_not_jetbrains(self): "google.cloud.dataproc_spark_connect.environment.is_jetbrains_ide", return_value=False, ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_spyder", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_eclipse", + return_value=False, + ) @mock.patch( "google.cloud.dataproc_spark_connect.environment.is_jupyter", return_value=False, @@ -170,6 +313,64 @@ def test_get_client_environment_label_workbench(self, *mocks): "google.cloud.dataproc_spark_connect.environment.is_workbench", return_value=False, ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_kaggle", + return_value=True, + ) + def test_get_client_environment_label_kaggle(self, *mocks): + self.assertEqual( + environment.get_client_environment_label(), + "kaggle", + ) + + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_colab_enterprise", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_colab", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_workbench", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_kaggle", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_sagemaker", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_databricks", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_deepnote", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_datalore", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_codespaces", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_cloud_shell", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_hex", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_polynote", + return_value=False, + ) @mock.patch( "google.cloud.dataproc_spark_connect.environment.is_vscode", return_value=True, @@ -192,6 +393,42 @@ def test_get_client_environment_label_vscode(self, *mocks): "google.cloud.dataproc_spark_connect.environment.is_workbench", return_value=False, ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_kaggle", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_sagemaker", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_databricks", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_deepnote", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_datalore", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_codespaces", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_cloud_shell", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_hex", + return_value=False, + ) + @mock.patch( + "google.cloud.dataproc_spark_connect.environment.is_polynote", + return_value=False, + ) @mock.patch( "google.cloud.dataproc_spark_connect.environment.is_vscode", return_value=False,