diff --git a/src/etwtrace/__init__.py b/src/etwtrace/__init__.py index 886b2d4..7f0ad5b 100644 --- a/src/etwtrace/__init__.py +++ b/src/etwtrace/__init__.py @@ -169,24 +169,41 @@ def enable_if(enable_var, type_var): tracer.enable() +def is_active(): + """Returns True if tracing is active.""" + return bool(_tracer) + + def mark(name): """Emits a mark event with the provided text.""" - if not _tracer: - raise RuntimeError("unable to mark when global tracer is not enabled") - _tracer.mark(name) + if _tracer: + _tracer.mark(name) + else: + import warnings + warnings.warn("Unable to mark when global tracer is not enabled", RuntimeWarning) + + +class _NullRange: + def __enter__(self): return self + def __exit__(self, *exc_info): pass def mark_range(name): """Context manager to emit start/stop mark events with the provided text.""" - if not _tracer: - raise RuntimeError("unable to mark when global tracer is not enabled") - return _tracer.mark_range(name) + if _tracer: + return _tracer.mark_range(name) + else: + import warnings + warnings.warn("Unable to mark when global tracer is not enabled", RuntimeWarning) + return _NullRange() def _mark_stack(mark): - if not _tracer: - raise RuntimeError("unable to mark when global tracer is not enabled") - return _tracer._mark_stack(mark) + if _tracer: + return _tracer._mark_stack(mark) + else: + import warnings + warnings.warn("Unable to mark when global tracer is not enabled", RuntimeWarning) _TEMP_PROFILE = None diff --git a/tests/test_etw.py b/tests/test_etw.py index a72e83e..3ed3a33 100644 --- a/tests/test_etw.py +++ b/tests/test_etw.py @@ -234,6 +234,22 @@ def test_but_do_we_instrument(): ) +def test_but_are_we_inactive(): + assert etwtrace.is_active() is False + + +def test_but_do_we_warn_on_mark(): + with pytest.warns(RuntimeWarning): + etwtrace.mark("Test mark without tracing") + + with pytest.warns(RuntimeWarning): + with etwtrace.mark_range("Test mark without tracing"): + pass + + with pytest.warns(RuntimeWarning): + etwtrace._mark_stack("Test mark without tracing") + + def test_basic(trace_events): funcs = set() with trace_events("basic.py", providers=['Python']) as etl: