Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 166 additions & 3 deletions pyiceberg/catalog/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Union,
)

from pydantic import Field, field_validator
from pydantic import ConfigDict, Field, field_validator
from requests import HTTPError, Session
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt

Expand Down Expand Up @@ -76,6 +76,43 @@
import pyarrow as pa


class HttpMethod(str, Enum):
    """HTTP verbs that an Endpoint can be declared with.

    Members also subclass ``str``, and each member's value is its own
    upper-case name.
    """

    GET = "GET"
    HEAD = "HEAD"
    POST = "POST"
    DELETE = "DELETE"


class Endpoint(IcebergBaseModel):
    """A single REST route, identified by its HTTP method and path template.

    The model is configured as frozen; instances are used as members of sets
    elsewhere in this module.
    """

    model_config = ConfigDict(frozen=True)

    http_method: HttpMethod = Field()
    path: str = Field()

    @field_validator("path", mode="before")
    @classmethod
    def _validate_path(cls, raw_path: str) -> str:
        """Trim surrounding whitespace from the path and reject empty values.

        Args:
            raw_path: The path value as supplied to the model, before validation.

        Returns:
            The path with leading and trailing whitespace removed.

        Raises:
            ValueError: If the path is missing, empty, or whitespace-only.
        """
        # Strip only truthy input so that None / "" reach the single emptiness
        # check below instead of failing on `.strip()`.
        path = raw_path.strip() if raw_path else raw_path
        if not path:
            raise ValueError("Invalid path: empty")
        return path

    def __str__(self) -> str:
        """Return the string representation of the Endpoint class."""
        return f"{self.http_method.value} {self.path}"

    @classmethod
    def from_string(cls, endpoint: str | None) -> "Endpoint":
        """Parse an endpoint written as ``"METHOD /path"``.

        The method is matched case-insensitively against HttpMethod.

        Args:
            endpoint: Text holding the HTTP method and the path, separated by whitespace.

        Returns:
            The Endpoint described by the text.

        Raises:
            ValueError: If ``endpoint`` is None, does not split into exactly two
                parts, or names a method that is not a member of HttpMethod.
        """
        # NOTE(review): a reviewer asked whether `endpoint` could be narrowed to
        # `str`; the None guard is kept so the signature stays unchanged.
        if endpoint is None:
            raise ValueError("Invalid endpoint (must consist of 'METHOD /path'): None")
        # Drop surrounding whitespace, then split once on the first whitespace
        # run so the remainder is taken as the path.
        elements = endpoint.strip().split(None, 1)
        if len(elements) != 2:
            raise ValueError(f"Invalid endpoint (must consist of two elements separated by a single space): {endpoint}")
        return cls(http_method=HttpMethod(elements[0].upper()), path=elements[1])


class Endpoints:
get_config: str = "config"
list_namespaces: str = "namespaces"
Expand All @@ -86,7 +123,7 @@ class Endpoints:
namespace_exists: str = "namespaces/{namespace}"
list_tables: str = "namespaces/{namespace}/tables"
create_table: str = "namespaces/{namespace}/tables"
register_table = "namespaces/{namespace}/register"
register_table: str = "namespaces/{namespace}/register"
load_table: str = "namespaces/{namespace}/tables/{table}"
update_table: str = "namespaces/{namespace}/tables/{table}"
drop_table: str = "namespaces/{namespace}/tables/{table}"
Expand All @@ -100,6 +137,66 @@ class Endpoints:
fetch_scan_tasks: str = "namespaces/{namespace}/tables/{table}/tasks"


class Capability:
    """Named Endpoint constants for the v1 REST routes under ``/v1/{prefix}``.

    The constants are grouped into namespace, table, view and table-scan
    operations.
    """

    # NOTE(review): a reviewer suggested refactoring the Endpoints class and
    # consolidating it with this one.

    V1_LIST_NAMESPACES = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces")
    V1_LOAD_NAMESPACE = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}")
    V1_NAMESPACE_EXISTS = Endpoint(http_method=HttpMethod.HEAD, path="/v1/{prefix}/namespaces/{namespace}")
    V1_UPDATE_NAMESPACE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/properties")
    V1_CREATE_NAMESPACE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces")
    V1_DELETE_NAMESPACE = Endpoint(http_method=HttpMethod.DELETE, path="/v1/{prefix}/namespaces/{namespace}")

    V1_LIST_TABLES = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}/tables")
    V1_LOAD_TABLE = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}")
    V1_TABLE_EXISTS = Endpoint(http_method=HttpMethod.HEAD, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}")
    V1_CREATE_TABLE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/tables")
    V1_UPDATE_TABLE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}")
    V1_DELETE_TABLE = Endpoint(http_method=HttpMethod.DELETE, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}")
    V1_RENAME_TABLE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/tables/rename")
    V1_REGISTER_TABLE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/register")

    V1_LIST_VIEWS = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}/views")
    V1_LOAD_VIEW = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}/views/{view}")
    V1_VIEW_EXISTS = Endpoint(http_method=HttpMethod.HEAD, path="/v1/{prefix}/namespaces/{namespace}/views/{view}")
    V1_CREATE_VIEW = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/views")
    V1_UPDATE_VIEW = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/views/{view}")
    V1_DELETE_VIEW = Endpoint(http_method=HttpMethod.DELETE, path="/v1/{prefix}/namespaces/{namespace}/views/{view}")
    V1_RENAME_VIEW = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/views/rename")
    V1_SUBMIT_TABLE_SCAN_PLAN = Endpoint(
        http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}/plan"
    )
    V1_TABLE_SCAN_PLAN_TASKS = Endpoint(
        http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}/tasks"
    )


# Fallback capability set for legacy servers whose ConfigResponse carries no
# endpoint list. Covers namespace and table endpoints only.
DEFAULT_ENDPOINTS: frozenset[Endpoint] = frozenset(
    {
        Capability.V1_LIST_NAMESPACES,
        Capability.V1_LOAD_NAMESPACE,
        Capability.V1_CREATE_NAMESPACE,
        Capability.V1_UPDATE_NAMESPACE,
        Capability.V1_DELETE_NAMESPACE,
        Capability.V1_LIST_TABLES,
        Capability.V1_LOAD_TABLE,
        Capability.V1_CREATE_TABLE,
        Capability.V1_UPDATE_TABLE,
        Capability.V1_DELETE_TABLE,
        Capability.V1_RENAME_TABLE,
        Capability.V1_REGISTER_TABLE,
    }
)

# View endpoints, added conditionally based on the VIEW_ENDPOINTS_SUPPORTED property.
VIEW_ENDPOINTS: frozenset[Endpoint] = frozenset({Capability.V1_LIST_VIEWS, Capability.V1_DELETE_VIEW})


class IdentifierKind(Enum):
TABLE = "table"
VIEW = "view"
Expand Down Expand Up @@ -134,6 +231,8 @@ class IdentifierKind(Enum):
CUSTOM = "custom"
REST_SCAN_PLANNING_ENABLED = "rest-scan-planning-enabled"
REST_SCAN_PLANNING_ENABLED_DEFAULT = False
VIEW_ENDPOINTS_SUPPORTED = "view-endpoints-supported"
VIEW_ENDPOINTS_SUPPORTED_DEFAULT = False

NAMESPACE_SEPARATOR = b"\x1f".decode(UTF8)

Expand Down Expand Up @@ -180,6 +279,14 @@ class RegisterTableRequest(IcebergBaseModel):
class ConfigResponse(IcebergBaseModel):
    """Catalog configuration payload: default and override properties plus an optional endpoint list."""

    defaults: Properties | None = Field(default_factory=dict)
    overrides: Properties | None = Field(default_factory=dict)
    endpoints: set[Endpoint] | None = Field(default=None)

    @field_validator("endpoints", mode="before")
    @classmethod
    def _parse_endpoints(cls, v: list[str] | None) -> set[Endpoint] | None:
        """Turn each raw endpoint string into an Endpoint; a None input is passed through unchanged."""
        return None if v is None else set(map(Endpoint.from_string, v))


class ListNamespaceResponse(IcebergBaseModel):
Expand Down Expand Up @@ -218,6 +325,7 @@ class ListViewsResponse(IcebergBaseModel):
class RestCatalog(Catalog):
uri: str
_session: Session
_supported_endpoints: set[Endpoint]

def __init__(self, name: str, **properties: str):
"""Rest Catalog.
Expand Down Expand Up @@ -279,7 +387,9 @@ def is_rest_scan_planning_enabled(self) -> bool:
Returns:
True if enabled, False otherwise.
"""
return property_as_bool(self.properties, REST_SCAN_PLANNING_ENABLED, REST_SCAN_PLANNING_ENABLED_DEFAULT)
return Capability.V1_SUBMIT_TABLE_SCAN_PLAN in self._supported_endpoints and property_as_bool(
self.properties, REST_SCAN_PLANNING_ENABLED, REST_SCAN_PLANNING_ENABLED_DEFAULT
)

def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager:
"""Create the LegacyOAuth2AuthManager by fetching required properties.
Expand Down Expand Up @@ -327,6 +437,18 @@ def url(self, endpoint: str, prefixed: bool = True, **kwargs: Any) -> str:

return url + endpoint.format(**kwargs)

def _check_endpoint(self, endpoint: Endpoint) -> None:
    """Verify that the given endpoint is in this catalog's set of supported endpoints.

    Args:
        endpoint: The endpoint to look up in the supported set.

    Raises:
        NotImplementedError: If the endpoint is not in the supported set.
    """
    # Guard clause: a supported endpoint needs no further action.
    if endpoint in self._supported_endpoints:
        return
    raise NotImplementedError(f"Server does not support endpoint: {endpoint}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Java throws UnsupportedOperationException here


@property
def auth_url(self) -> str:
self._warn_oauth_tokens_deprecation()
Expand Down Expand Up @@ -384,6 +506,17 @@ def _fetch_config(self) -> None:
# Update URI based on overrides
self.uri = config[URI]

# Determine supported endpoints
endpoints = config_response.endpoints
if endpoints:
self._supported_endpoints = set(endpoints)
else:
# Use default endpoints for legacy servers that don't return endpoints
self._supported_endpoints = set(DEFAULT_ENDPOINTS)
# Conditionally add view endpoints based on config
if property_as_bool(self.properties, VIEW_ENDPOINTS_SUPPORTED, VIEW_ENDPOINTS_SUPPORTED_DEFAULT):
self._supported_endpoints.update(VIEW_ENDPOINTS)

def _identifier_to_validated_tuple(self, identifier: str | Identifier) -> Identifier:
identifier_tuple = self.identifier_to_tuple(identifier)
if len(identifier_tuple) <= 1:
Expand Down Expand Up @@ -503,6 +636,7 @@ def _create_table(
properties: Properties = EMPTY_DICT,
stage_create: bool = False,
) -> TableResponse:
self._check_endpoint(Capability.V1_CREATE_TABLE)
iceberg_schema = self._convert_schema_if_needed(
schema,
int(properties.get(TableProperties.FORMAT_VERSION, TableProperties.DEFAULT_FORMAT_VERSION)), # type: ignore
Expand Down Expand Up @@ -591,6 +725,7 @@ def register_table(self, identifier: str | Identifier, metadata_location: str) -
Raises:
TableAlreadyExistsError: If the table already exists
"""
self._check_endpoint(Capability.V1_REGISTER_TABLE)
namespace_and_table = self._split_identifier_for_path(identifier)
request = RegisterTableRequest(
name=namespace_and_table["table"],
Expand All @@ -611,6 +746,7 @@ def register_table(self, identifier: str | Identifier, metadata_location: str) -

@retry(**_RETRY_ARGS)
def list_tables(self, namespace: str | Identifier) -> list[Identifier]:
self._check_endpoint(Capability.V1_LIST_TABLES)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
response = self._session.get(self.url(Endpoints.list_tables, namespace=namespace_concat))
Expand All @@ -622,6 +758,7 @@ def list_tables(self, namespace: str | Identifier) -> list[Identifier]:

@retry(**_RETRY_ARGS)
def load_table(self, identifier: str | Identifier) -> Table:
self._check_endpoint(Capability.V1_LOAD_TABLE)
params = {}
if mode := self.properties.get(SNAPSHOT_LOADING_MODE):
if mode in {"all", "refs"}:
Expand All @@ -642,6 +779,7 @@ def load_table(self, identifier: str | Identifier) -> Table:

@retry(**_RETRY_ARGS)
def drop_table(self, identifier: str | Identifier, purge_requested: bool = False) -> None:
self._check_endpoint(Capability.V1_DELETE_TABLE)
response = self._session.delete(
self.url(Endpoints.drop_table, prefixed=True, **self._split_identifier_for_path(identifier)),
params={"purgeRequested": purge_requested},
Expand All @@ -657,6 +795,7 @@ def purge_table(self, identifier: str | Identifier) -> None:

@retry(**_RETRY_ARGS)
def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table:
self._check_endpoint(Capability.V1_RENAME_TABLE)
payload = {
"source": self._split_identifier_for_json(from_identifier),
"destination": self._split_identifier_for_json(to_identifier),
Expand Down Expand Up @@ -692,6 +831,8 @@ def _remove_catalog_name_from_table_request_identifier(self, table_request: Comm

@retry(**_RETRY_ARGS)
def list_views(self, namespace: str | Identifier) -> list[Identifier]:
if Capability.V1_LIST_VIEWS not in self._supported_endpoints:
return []
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
response = self._session.get(self.url(Endpoints.list_views, namespace=namespace_concat))
Expand Down Expand Up @@ -720,6 +861,7 @@ def commit_table(
CommitFailedException: Requirement not met, or a conflict with a concurrent commit.
CommitStateUnknownException: Failed due to an internal exception on the side of the catalog.
"""
self._check_endpoint(Capability.V1_UPDATE_TABLE)
identifier = table.name()
table_identifier = TableIdentifier(namespace=identifier[:-1], name=identifier[-1])
table_request = CommitTableRequest(identifier=table_identifier, requirements=requirements, updates=updates)
Expand Down Expand Up @@ -749,6 +891,7 @@ def commit_table(

@retry(**_RETRY_ARGS)
def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None:
self._check_endpoint(Capability.V1_CREATE_NAMESPACE)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
payload = {"namespace": namespace_tuple, "properties": properties}
response = self._session.post(self.url(Endpoints.create_namespace), json=payload)
Expand All @@ -759,6 +902,7 @@ def create_namespace(self, namespace: str | Identifier, properties: Properties =

@retry(**_RETRY_ARGS)
def drop_namespace(self, namespace: str | Identifier) -> None:
self._check_endpoint(Capability.V1_DELETE_NAMESPACE)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
response = self._session.delete(self.url(Endpoints.drop_namespace, namespace=namespace))
Expand All @@ -769,6 +913,7 @@ def drop_namespace(self, namespace: str | Identifier) -> None:

@retry(**_RETRY_ARGS)
def list_namespaces(self, namespace: str | Identifier = ()) -> list[Identifier]:
self._check_endpoint(Capability.V1_LIST_NAMESPACES)
namespace_tuple = self.identifier_to_tuple(namespace)
response = self._session.get(
self.url(
Expand All @@ -786,6 +931,7 @@ def list_namespaces(self, namespace: str | Identifier = ()) -> list[Identifier]:

@retry(**_RETRY_ARGS)
def load_namespace_properties(self, namespace: str | Identifier) -> Properties:
self._check_endpoint(Capability.V1_LOAD_NAMESPACE)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
response = self._session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace))
Expand All @@ -800,6 +946,7 @@ def load_namespace_properties(self, namespace: str | Identifier) -> Properties:
def update_namespace_properties(
self, namespace: str | Identifier, removals: set[str] | None = None, updates: Properties = EMPTY_DICT
) -> PropertiesUpdateSummary:
self._check_endpoint(Capability.V1_UPDATE_NAMESPACE)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
payload = {"removals": list(removals or []), "updates": updates}
Expand All @@ -819,6 +966,14 @@ def update_namespace_properties(
def namespace_exists(self, namespace: str | Identifier) -> bool:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)

if Capability.V1_NAMESPACE_EXISTS not in self._supported_endpoints:
try:
self.load_namespace_properties(namespace_tuple)
return True
except NoSuchNamespaceError:
return False

Comment on lines +969 to +976
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

response = self._session.head(self.url(Endpoints.namespace_exists, namespace=namespace))

if response.status_code == 404:
Expand All @@ -843,6 +998,13 @@ def table_exists(self, identifier: str | Identifier) -> bool:
Returns:
bool: True if the table exists, False otherwise.
"""
if Capability.V1_TABLE_EXISTS not in self._supported_endpoints:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try:
self.load_table(identifier)
return True
except NoSuchTableError:
return False

response = self._session.head(
self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier))
)
Expand Down Expand Up @@ -886,6 +1048,7 @@ def view_exists(self, identifier: str | Identifier) -> bool:

@retry(**_RETRY_ARGS)
def drop_view(self, identifier: str) -> None:
self._check_endpoint(Capability.V1_DELETE_VIEW)
response = self._session.delete(
self.url(Endpoints.drop_view, prefixed=True, **self._split_identifier_for_path(identifier, IdentifierKind.VIEW)),
)
Expand Down
Loading