diff --git a/_msbuild_test.py b/_msbuild_test.py index 48a1c01..7f09dca 100644 --- a/_msbuild_test.py +++ b/_msbuild_test.py @@ -48,12 +48,14 @@ CSourceFile('etwtrace/_tdhreader.cpp'), IncludeFile('etwtrace/_tdhreader.h'), ), - PydFile( - 'DiagnosticsHub.InstrumentationCollector', - *PYD_OPTS, - CSourceFile('etwtrace/_diaghubstub.c'), - TargetExt='.dll', - ), + # This package will be renamed in init_PACKAGE + Package('arch', + CProject( + 'DiagnosticsHubStub', + *PYD_OPTS, + CSourceFile('etwtrace/_diaghubstub.c'), + ), + ) ), source='src', ) @@ -65,3 +67,14 @@ def init_METADATA(): if sep and re.match(r"\d+(\.\d+)+((a|b|rc)\d+)?$", version): # Looks like a version tag METADATA["Version"] = version + + +def init_PACKAGE(tag=None): + if not tag: + return + if tag.endswith("-win32"): + PACKAGE.find('test/arch').name = 'x86' + elif tag.endswith("-win_amd64"): + PACKAGE.find('test/arch').name = 'amd64' + elif tag.endswith("-win_arm64"): + PACKAGE.find('test/arch').name = 'arm64' diff --git a/src/etwtrace/__init__.py b/src/etwtrace/__init__.py index 6d35e7c..886b2d4 100644 --- a/src/etwtrace/__init__.py +++ b/src/etwtrace/__init__.py @@ -101,16 +101,28 @@ def __init__(self): class DiagnosticsHubTracer(_TracingMixin): def __init__(self, stub=False): + self._data = None if stub: from ctypes import PyDLL, py_object from pathlib import Path + from os import environ + from sys import winver self._data = [] - dll = Path(__file__).parent / "test" / "DiagnosticsHub.InstrumentationCollector.dll" + root = Path(__file__).parent / "test" + if winver.endswith("-32"): + dll = root / "x86" + elif winver.endswith("-arm64"): + dll = root / "arm64" + else: + dll = root / "amd64" + dll = dll / "DiagnosticsHubStub.dll" if not dll.is_file(): raise RuntimeError("Diagnostics hub stub requires test files") self._stub = PyDLL(str(dll)) self._stub.OnEvent.argtypes = [py_object] self._stub.OnEvent(lambda *a: self._on_event(*a)) + environ["DIAGHUB_INSTR_COLLECTOR_ROOT"] = str(root) + environ["DIAGHUB_INSTR_RUNTIME_NAME"] = dll.name super().__init__() from . import _vsinstrument as mod self._module = mod @@ -121,7 +133,8 @@ def _on_event(self, *args): def disable(self): super().disable() - print(*self._data, sep="\n") + if self._data: + print(*self._data, sep="\n") def enable_if(enable_var, type_var): diff --git a/src/etwtrace/_diaghubstub.c b/src/etwtrace/_diaghubstub.c index 3108bc6..30abe0a 100644 --- a/src/etwtrace/_diaghubstub.c +++ b/src/etwtrace/_diaghubstub.c @@ -55,6 +55,20 @@ EXPORT void PROBE_IMPL Cap_Define_Script_Function(_In_ void* pFunction, _In_ voi } } +EXPORT BOOL PROBE_IMPL ChildAttach() +{ + if (_on_event) { + PyObject *r = PyObject_CallFunction(_on_event, "s", "ChildAttach"); + if (!r) { + PyErr_WriteUnraisable(NULL); + return FALSE; + } else { + Py_DECREF(r); + } + } + return TRUE; +} + EXPORT void PROBE_IMPL Stub_Write_Mark(_In_ int opcode, _In_z_ LPCWSTR szMark) { if (_on_event) { diff --git a/src/etwtrace/_vsinstrument.c b/src/etwtrace/_vsinstrument.c index 12e10d5..e738288 100644 --- a/src/etwtrace/_vsinstrument.c +++ b/src/etwtrace/_vsinstrument.c @@ -278,14 +278,52 @@ static int vsinstrument_exec(PyObject *m) return -1; } - // We ensure the DLL is loaded already before loading it again, but - // use LoadLibraryW to increment the reference count to ensure it does not - // get freed on us. - if (!GetModuleHandleW(L"DiagnosticsHub.InstrumentationCollector.dll")) { - PyErr_SetString(PyExc_RuntimeError, "VS tracing must be launched from Diagnostics Hub"); +#if defined(_M_IX86) + const wchar_t * const subpath = L"\\x86\\"; +#elif defined(_M_AMD64) + const wchar_t * const subpath = L"\\amd64\\"; +#elif defined(_M_ARM64) + const wchar_t * const subpath = L"\\arm64\\"; +#else + #error Unsupported architecture +#endif + + DWORD cchPath = GetEnvironmentVariableW(L"DIAGHUB_INSTR_COLLECTOR_ROOT", NULL, 0); + if (!cchPath) { + PyErr_SetFromWindowsErr(0); + return -1; + } + DWORD cchName = GetEnvironmentVariableW(L"DIAGHUB_INSTR_RUNTIME_NAME", NULL, 0); + if (!cchName) { + PyErr_SetFromWindowsErr(0); + return -1; + } + cchPath = cchPath + cchName + (DWORD)wcslen(subpath); + wchar_t *path = (wchar_t *)PyMem_Malloc(cchPath * sizeof(wchar_t)); + if (!path) { + PyErr_NoMemory(); + return -1; + } + + DWORD cch = GetEnvironmentVariable(L"DIAGHUB_INSTR_COLLECTOR_ROOT", path, cchPath); + if (!cch) { + PyErr_SetFromWindowsErr(0); return -1; } - state->hModule = LoadLibraryW(L"DiagnosticsHub.InstrumentationCollector.dll"); + while (cch > 0 && path[cch - 1] == L'\\') { + --cch; + } + wcscpy_s(&path[cch], cchPath - cch, subpath); + cch += (DWORD)wcslen(subpath); + + if (!GetEnvironmentVariable(L"DIAGHUB_INSTR_RUNTIME_NAME", &path[cch], cchPath - cch)) { + PyMem_Free(path); + PyErr_SetFromWindowsErr(0); + return -1; + } + + state->hModule = LoadLibraryW(path); + PyMem_Free(path); if (!state->hModule) { PyErr_SetFromWindowsErr(0); return -1; @@ -314,6 +352,17 @@ static int vsinstrument_exec(PyObject *m) // Allowed to be absent state->WriteMark = (Stub_Write_Mark)GetProcAddress(state->hModule, "Stub_Write_Mark"); + BOOL (*childAttach)() = (BOOL (*)())GetProcAddress(state->hModule, "ChildAttach"); + if (!childAttach) { + PyErr_SetFromWindowsErr(0); + return -1; + } + + if (!(*childAttach)()) { + PyErr_SetString(PyExc_RuntimeError, "Failed to attach to profiler"); + return -1; + } + return 0; } diff --git a/tests/test_diaghub.py b/tests/test_diaghub.py index febdd54..7d29da6 100644 --- a/tests/test_diaghub.py +++ b/tests/test_diaghub.py @@ -65,10 +65,28 @@ def test_but_do_we_diaghub(): with subprocess.Popen( [sys.executable, "-m", "etwtrace", "--diaghub", "--", SCRIPTS / "no_events.py"], cwd=SCRIPTS, + env={ + **os.environ, + "DIAGHUB_INSTR_COLLECTOR_ROOT": "", + "DIAGHUB_INSTR_RUNTIME_NAME": "", + }, ) as p: p.wait() assert p.returncode + # We should succeed, but can't tell what's happened + with subprocess.Popen( + [sys.executable, "-m", "etwtrace", "--diaghub", "--", SCRIPTS / "no_events.py"], + cwd=SCRIPTS, + env={ + **os.environ, + "DIAGHUB_INSTR_COLLECTOR_ROOT": str(Path(etwtrace.__file__).parent / "test"), + "DIAGHUB_INSTR_RUNTIME_NAME": "DiagnosticsHubStub.dll", + }, + ) as p: + p.wait() + assert not p.returncode + def test_but_do_we_diaghubtest(): subprocess.check_call(