Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .kokoro/presubmit/presubmit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
# Only run a subset of all nox sessions
env_vars: {
key: "NOX_SESSION"
value: "unit-3.9 unit-3.12 cover docs docfx"
value: "unit-3.10 unit-3.12 cover docs docfx"
}
7 changes: 5 additions & 2 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Context manager for Cloud Spanner batched writes."""

import functools
from typing import List, Optional

Expand Down Expand Up @@ -242,8 +243,10 @@ def commit(
observability_options=getattr(database, "observability_options", None),
metadata=metadata,
) as span, MetricsCapture():
nth_request = getattr(database, "_next_nth_request", 0)

def wrapped_method():
attempt = AtomicCounter(0)
commit_request = CommitRequest(
session=session.name,
mutations=mutations,
Expand All @@ -256,8 +259,8 @@ def wrapped_method():
# should be increased. attempt can only be increased if
# we encounter UNAVAILABLE or INTERNAL.
call_metadata, error_augmenter = database.with_error_augmentation(
getattr(database, "_next_nth_request", 0),
1,
nth_request,
attempt.increment(),
Comment on lines -259 to +263
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the existing code was correct. There are two different things that can be retried in Spanner:

  1. Aborted transactions: When a read/write transaction is aborted, then the entire transaction is retried. This should not cause attempt to be increased, even in this case, where the entire transaction is just a single Commit call.
  2. Unavailable: A single RPC can fail due to network errors, server temporarily being down etc. This is normally retried by Gax. In this case, only a single RPC (so not the entire transaction) is retried. It is only in these cases that attempt should be increased.

metadata,
span,
)
Expand Down
62 changes: 42 additions & 20 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
* a :class:`~google.cloud.spanner_v1.instance.Instance` owns a
:class:`~google.cloud.spanner_v1.database.Database`
"""

import grpc
import os
import logging
import warnings
import threading

from google.api_core.gapic_v1 import client_info
from google.auth.credentials import AnonymousCredentials
Expand Down Expand Up @@ -99,11 +101,50 @@ def _get_spanner_optimizer_statistics_package():

log = logging.getLogger(__name__)

_metrics_monitor_initialized = False
_metrics_monitor_lock = threading.Lock()


def _get_spanner_enable_builtin_metrics_env():
return os.getenv(SPANNER_DISABLE_BUILTIN_METRICS_ENV_VAR) != "true"


def _initialize_metrics(project, credentials):
"""
Initializes the Spanner built-in metrics.

This function sets up the OpenTelemetry MeterProvider and the SpannerMetricsTracerFactory.
It uses a lock to ensure that initialization happens only once.
"""
global _metrics_monitor_initialized
if not _metrics_monitor_initialized:
with _metrics_monitor_lock:
if not _metrics_monitor_initialized:
meter_provider = metrics.NoOpMeterProvider()
try:
if not _get_spanner_emulator_host():
meter_provider = MeterProvider(
metric_readers=[
PeriodicExportingMetricReader(
CloudMonitoringMetricsExporter(
project_id=project,
credentials=credentials,
),
export_interval_millis=METRIC_EXPORT_INTERVAL_MS,
),
]
)
metrics.set_meter_provider(meter_provider)
SpannerMetricsTracerFactory()
_metrics_monitor_initialized = True
except Exception as e:
# log is already defined at module level
log.warning(
"Failed to initialize Spanner built-in metrics. Error: %s",
e,
)


class Client(ClientWithProject):
"""Client for interacting with Cloud Spanner API.

Expand Down Expand Up @@ -251,31 +292,12 @@ def __init__(
"http://" in self._emulator_host or "https://" in self._emulator_host
):
warnings.warn(_EMULATOR_HOST_HTTP_SCHEME)
# Check flag to enable Spanner builtin metrics
if (
_get_spanner_enable_builtin_metrics_env()
and not disable_builtin_metrics
and HAS_GOOGLE_CLOUD_MONITORING_INSTALLED
):
meter_provider = metrics.NoOpMeterProvider()
try:
if not _get_spanner_emulator_host():
meter_provider = MeterProvider(
metric_readers=[
PeriodicExportingMetricReader(
CloudMonitoringMetricsExporter(
project_id=project, credentials=credentials
),
export_interval_millis=METRIC_EXPORT_INTERVAL_MS,
),
]
)
metrics.set_meter_provider(meter_provider)
SpannerMetricsTracerFactory()
except Exception as e:
log.warning(
"Failed to initialize Spanner built-in metrics. Error: %s", e
)
_initialize_metrics(project, credentials)
else:
SpannerMetricsTracerFactory(enabled=False)

Expand Down
24 changes: 17 additions & 7 deletions google/cloud/spanner_v1/metrics/metrics_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
performance monitoring.
"""

from contextvars import Token
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: group this with the other import


from .spanner_metrics_tracer_factory import SpannerMetricsTracerFactory


Expand All @@ -30,6 +32,9 @@ class MetricsCapture:
the start and completion of metrics tracing for a given operation.
"""

_token: Token
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this? Could we add a small comment for it?

"""Token to reset the context variable after the operation completes."""

def __enter__(self):
"""Enter the runtime context related to this object.

Expand All @@ -45,11 +50,11 @@ def __enter__(self):
return self

# Define a new metrics tracer for the new operation
SpannerMetricsTracerFactory.current_metrics_tracer = (
factory.create_metrics_tracer()
)
if SpannerMetricsTracerFactory.current_metrics_tracer:
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_start()
# Set the context var and keep the token for reset
tracer = factory.create_metrics_tracer()
self._token = SpannerMetricsTracerFactory.set_current_tracer(tracer)
if tracer:
tracer.record_operation_start()
return self

def __exit__(self, exc_type, exc_value, traceback):
Expand All @@ -70,6 +75,11 @@ def __exit__(self, exc_type, exc_value, traceback):
if not SpannerMetricsTracerFactory().enabled:
return False

if SpannerMetricsTracerFactory.current_metrics_tracer:
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_completion()
tracer = SpannerMetricsTracerFactory.get_current_tracer()
if tracer:
tracer.record_operation_completion()

# Reset the context var using the token
if getattr(self, "_token", None):
SpannerMetricsTracerFactory.reset_current_tracer(self._token)
return False # Propagate the exception if any
44 changes: 20 additions & 24 deletions google/cloud/spanner_v1/metrics/metrics_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,17 @@ def _set_metrics_tracer_attributes(self, resources: Dict[str, str]) -> None:
Args:
resources (Dict[str, str]): A dictionary containing project, instance, and database information.
"""
if SpannerMetricsTracerFactory.current_metrics_tracer is None:
tracer = SpannerMetricsTracerFactory.get_current_tracer()
if tracer is None:
return

if resources:
if "project" in resources:
SpannerMetricsTracerFactory.current_metrics_tracer.set_project(
resources["project"]
)
tracer.set_project(resources["project"])
if "instance" in resources:
SpannerMetricsTracerFactory.current_metrics_tracer.set_instance(
resources["instance"]
)
tracer.set_instance(resources["instance"])
if "database" in resources:
SpannerMetricsTracerFactory.current_metrics_tracer.set_database(
resources["database"]
)
tracer.set_database(resources["database"])

def intercept(self, invoked_method, request_or_iterator, call_details):
"""Intercept gRPC calls to collect metrics.
Expand All @@ -126,31 +121,32 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
The RPC response
"""
factory = SpannerMetricsTracerFactory()
if (
SpannerMetricsTracerFactory.current_metrics_tracer is None
or not factory.enabled
):
tracer = SpannerMetricsTracerFactory.get_current_tracer()
if tracer is None or not factory.enabled:
return invoked_method(request_or_iterator, call_details)

# Setup Metric Tracer attributes from call details
## Extract Project / Instance / Databse from header information
resources = self._extract_resource_from_path(call_details.metadata)
self._set_metrics_tracer_attributes(resources)
## Extract Project / Instance / Database from header information if not already set
if not (
tracer.client_attributes.get("project_id")
and tracer.client_attributes.get("instance_id")
and tracer.client_attributes.get("database")
):
resources = self._extract_resource_from_path(call_details.metadata)
self._set_metrics_tracer_attributes(resources)

## Format method to be be spanner.<method name>
method_name = self._remove_prefix(
call_details.method, SPANNER_METHOD_PREFIX
).replace("/", ".")

SpannerMetricsTracerFactory.current_metrics_tracer.set_method(method_name)
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_start()
tracer.set_method(method_name)
tracer.record_attempt_start()
response = invoked_method(request_or_iterator, call_details)
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_completion()
tracer.record_attempt_completion()

# Process and send GFE metrics if enabled
if SpannerMetricsTracerFactory.current_metrics_tracer.gfe_enabled:
if tracer.gfe_enabled:
metadata = response.initial_metadata()
SpannerMetricsTracerFactory.current_metrics_trace.record_gfe_metrics(
metadata
)
tracer.record_gfe_metrics(metadata)
return response
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import logging
from .constants import SPANNER_SERVICE_NAME
import contextvars

try:
import mmh3
Expand All @@ -43,7 +44,9 @@ class SpannerMetricsTracerFactory(MetricsTracerFactory):
"""A factory for creating SpannerMetricsTracer instances."""

_metrics_tracer_factory: "SpannerMetricsTracerFactory" = None
current_metrics_tracer: MetricsTracer = None
_current_metrics_tracer_ctx = contextvars.ContextVar(
"current_metrics_tracer", default=None
)

def __new__(
cls, enabled: bool = True, gfe_enabled: bool = False
Expand Down Expand Up @@ -80,10 +83,26 @@ def __new__(
cls._metrics_tracer_factory.gfe_enabled = gfe_enabled

if cls._metrics_tracer_factory.enabled != enabled:
cls._metrics_tracer_factory.enabeld = enabled
cls._metrics_tracer_factory.enabled = enabled

return cls._metrics_tracer_factory

@staticmethod
def get_current_tracer() -> MetricsTracer:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe add set_current_tracer and reset_current_tracer methods?

return SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()

@staticmethod
def set_current_tracer(tracer: MetricsTracer) -> contextvars.Token:
return SpannerMetricsTracerFactory._current_metrics_tracer_ctx.set(tracer)

@staticmethod
def reset_current_tracer(token: contextvars.Token):
SpannerMetricsTracerFactory._current_metrics_tracer_ctx.reset(token)

@property
def current_metrics_tracer(self) -> MetricsTracer:
return SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()

Comment on lines +102 to +105

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

You've introduced both a static method get_current_tracer() and a property current_metrics_tracer that do the same thing: retrieve the tracer from the context variable.

The property current_metrics_tracer is problematic because it replaces a class attribute with an instance property. Any code that previously accessed SpannerMetricsTracerFactory.current_metrics_tracer will now get a property object instead of the tracer, which is a breaking change and could lead to subtle bugs.

Since all new code in this PR uses the clear and unambiguous static method get_current_tracer(), I recommend removing the redundant and potentially confusing current_metrics_tracer property. This will make the API cleaner and prevent accidental misuse.

@staticmethod
def _generate_client_uid() -> str:
"""Generate a client UID in the form of uuidv4@pid@hostname.
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add copyright header

from unittest.mock import patch


@pytest.fixture(autouse=True)
def mock_periodic_exporting_metric_reader():
"""Globally mock PeriodicExportingMetricReader to prevent real network calls."""
with patch(
"google.cloud.spanner_v1.client.PeriodicExportingMetricReader"
) as mock_client_reader, patch(
"opentelemetry.sdk.metrics.export.PeriodicExportingMetricReader"
):
yield mock_client_reader
Loading
Loading