From 19695aa2f110253d1c8641fef122ccb7d8c85fc4 Mon Sep 17 00:00:00 2001
From: Gaudy Blanco Meneses
Date: Thu, 12 Feb 2026 14:23:22 -0600
Subject: [PATCH 1/6] work in progress

---
 .../graphrag_storage/azure_cosmos_storage.py  | 152 ++++++++++--------
 .../graphrag/index/run/run_pipeline.py        |   2 +
 .../index/workflows/create_communities.py     |   5 +-
 3 files changed, 95 insertions(+), 64 deletions(-)

diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
index ff3ec6decb..a57b656aae 100644
--- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
+++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
@@ -37,7 +37,7 @@ class AzureCosmosStorage(Storage):
     _database_name: str
     _container_name: str
     _encoding: str
-    _no_id_prefixes: list[str]
+    _no_id_prefixes: set[str] = set()
 
     def __init__(
         self,
@@ -81,7 +81,7 @@ def __init__(
         self._cosmosdb_account_name = (
             account_url.split("//")[1].split(".")[0] if account_url else None
         )
-        self._no_id_prefixes = []
+        self._no_id_prefixes = set()
         logger.debug(
             "Creating cosmosdb storage with account [%s] and database [%s] and container [%s]",
             self._cosmosdb_account_name,
@@ -185,79 +185,105 @@ def find(
                 "An error occurred while searching for documents in Cosmos DB."
             )
 
-    async def get(
-        self, key: str, as_bytes: bool | None = None, encoding: str | None = None
-    ) -> Any:
-        """Fetch all items in a container that match the given key."""
-        try:
-            if not self._database_client or not self._container_client:
+    async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None) -> Any:
+        """Fetch all items in a container that match the given key."""
+        try:
+            if not self._database_client or not self._container_client:
                 return None
-            if as_bytes:
-                prefix = self._get_prefix(key)
-                query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')"  # noqa: S608
-                queried_items = self._container_client.query_items(
-                    query=query, enable_cross_partition_query=True
-                )
-                items_list = list(queried_items)
-                for item in items_list:
-                    item["id"] = item["id"].split(":")[1]
-
-                items_json_str = json.dumps(items_list)
-
-                items_df = pd.read_json(
-                    StringIO(items_json_str), orient="records", lines=False
-                )
-                # Drop the "id" column if the original dataframe does not include it
-                # TODO: Figure out optimal way to handle missing id keys in input dataframes
-                if prefix in self._no_id_prefixes:
-                    items_df.drop(columns=["id"], axis=1, inplace=True)
-
-                return items_df.to_parquet()
-            item = self._container_client.read_item(item=key, partition_key=key)
-            item_body = item.get("body")
-            return json.dumps(item_body)
-        except Exception:  # noqa: BLE001
-            logger.warning("Error reading item %s", key)
+            if as_bytes:
+                prefix = self._get_prefix(key)
+                logger.info(f"Test Prefix: {prefix}")
+                query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')"  # noqa: S608
+
+                queried_items = self._container_client.query_items( query=query, enable_cross_partition_query=True )
+                items_list = list(queried_items)
+
+                logger.info("Cosmos load prefix=%s count=%d", prefix, len(items_list))
+
+                if not items_list:
+                    logger.warning("No items found for prefix %s (key=%s)", prefix, key)
+                    return None
+
+                for item in items_list:
+                    item["id"] = item["id"].split(":",1)[1]
+
+                items_json_str = json.dumps(items_list)
+                items_df = pd.read_json( StringIO(items_json_str), orient="records", lines=False)
+
+                if prefix == "entities":
+                    # Always preserve the Cosmos suffix for debugging/migrations
+                    items_df["cosmos_id"] = items_df["id"]
+                    items_df["id"] = items_df["id"].astype(int)  # Only restore pipeline UUID id if we actually have it
+
+                    if "human_readable_id" in items_df.columns:
+                        items_df["human_readable_id"] = items_df["human_readable_id"].astype(int)
+                    else:
+                        # Fresh run case: extract_graph entities may not have entity_id yet
+                        # Keep id as the suffix (stable_key/index) for now.
+                        logger.info("Entities loaded without entity_id; leaving id as cosmos suffix.")
+
+                if items_df.empty:
+                    logger.warning("No rows returned for prefix %s (key=%s)", prefix, key)
+                    return None
+                return items_df.to_parquet()
+            item = self._container_client.read_item(item=key, partition_key=key)
+            item_body = item.get("body")
+            return json.dumps(item_body)
+        except Exception:  # noqa: BLE001
+            logger.warning("Error reading item %s", key)
             return None
 
-    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
-        """Insert the contents of a file into a cosmosdb container for the given filename key.
-
-        For better optimization, the file is destructured such that each row is a unique cosmosdb item.
-        """
+    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         try:
             if not self._database_client or not self._container_client:
-                msg = "Database or container not initialized"
-                raise ValueError(msg)  # noqa: TRY301
-            # value represents a parquet file
+                raise ValueError("Database or container not initialized")
+
             if isinstance(value, bytes):
                 prefix = self._get_prefix(key)
                 value_df = pd.read_parquet(BytesIO(value))
-                value_json = value_df.to_json(
-                    orient="records", lines=False, force_ascii=False
-                )
-                if value_json is None:
-                    logger.error("Error converting output %s to json", key)
+
+                # Decide once per dataframe
+                df_has_id = "id" in value_df.columns
+
+                # IMPORTANT: if we now have ids, undo the earlier "no id" marking
+                if df_has_id:
+                    self._no_id_prefixes.discard(prefix)
                 else:
-                    cosmosdb_item_list = json.loads(value_json)
-                    for index, cosmosdb_item in enumerate(cosmosdb_item_list):
-                        # If the id key does not exist in the input dataframe json, create a unique id using the prefix and item index
-                        # TODO: Figure out optimal way to handle missing id keys in input dataframes
-                        if "id" not in cosmosdb_item:
-                            prefixed_id = f"{prefix}:{index}"
-                            self._no_id_prefixes.append(prefix)
+                    self._no_id_prefixes.add(prefix)
+
+                cosmosdb_item_list = json.loads(
+                    value_df.to_json(orient="records", lines=False, force_ascii=False)
+                )
+
+                for index, cosmosdb_item in enumerate(cosmosdb_item_list):
+                    if prefix == "entities":
+                        # Stable key for Cosmos identity
+                        stable_key = cosmosdb_item.get("human_readable_id", index)
+                        cosmos_id = f"{prefix}:{stable_key}"
+
+                        # If the pipeline provided a final UUID, store it separately
+                        if "id" in cosmosdb_item:
+                            cosmosdb_item["entity_id"] = cosmosdb_item["id"]
+
+                        # Cosmos identity must be stable and NEVER change
+                        cosmosdb_item["id"] = cosmos_id
+                        logger.info("Print ids")
+                        logger.info(f"{cosmos_id}")
+
+                    else:
+                        # Original behavior for non-entity prefixes
+                        if df_has_id:
+                            cosmosdb_item["id"] = f"{prefix}:{cosmosdb_item['id']}"
                         else:
-                            prefixed_id = f"{prefix}:{cosmosdb_item['id']}"
-                        cosmosdb_item["id"] = prefixed_id
-                        self._container_client.upsert_item(body=cosmosdb_item)
-            # value represents a cache output or stats.json
+                            cosmosdb_item["id"] = f"{prefix}:{index}"
+
+                    self._container_client.upsert_item(body=cosmosdb_item)
             else:
-                cosmosdb_item = {
-                    "id": key,
-                    "body": json.loads(value),
-                }
+                cosmosdb_item = {"id": key, "body": json.loads(value)}
                 self._container_client.upsert_item(body=cosmosdb_item)
+
         except Exception:
             logger.exception("Error writing item %s", key)
 
@@ -267,7 +293,7 @@ async def has(self, key: str) -> bool:
             return False
         if ".parquet" in key:
             prefix = self._get_prefix(key)
-            query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')"  # noqa: S608
+            query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')"  # noqa: S608
             queried_items = self._container_client.query_items(
                 query=query, enable_cross_partition_query=True
             )
@@ -285,7 +311,7 @@ async def delete(self, key: str) -> None:
         try:
             if ".parquet" in key:
                 prefix = self._get_prefix(key)
-                query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')"  # noqa: S608
+                query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')"  # noqa: S608
                 queried_items = self._container_client.query_items(
                     query=query, enable_cross_partition_query=True
                 )
diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py
index a4ce17582c..c9c368da15 100644
--- a/packages/graphrag/graphrag/index/run/run_pipeline.py
+++ b/packages/graphrag/graphrag/index/run/run_pipeline.py
@@ -41,6 +41,8 @@ async def run_pipeline(
 
     # load existing state in case any workflows are stateful
     state_json = await output_storage.get("context.json")
+    logger.info("Printing state json")
+    logger.info(state_json)
     state = json.loads(state_json) if state_json else {}
 
     if additional_context:
diff --git a/packages/graphrag/graphrag/index/workflows/create_communities.py b/packages/graphrag/graphrag/index/workflows/create_communities.py
index 4394593e99..6c93e31c09 100644
--- a/packages/graphrag/graphrag/index/workflows/create_communities.py
+++ b/packages/graphrag/graphrag/index/workflows/create_communities.py
@@ -28,10 +28,13 @@ async def run_workflow(
 ) -> WorkflowFunctionOutput:
     """All the steps to transform final communities."""
     logger.info("Workflow started: create_communities")
-    entities = await load_table_from_storage("entities", context.output_storage)
+    logger.info("Amount of relationships:")
     relationships = await load_table_from_storage(
         "relationships", context.output_storage
     )
+    logger.info(len(relationships))
+    entities = await load_table_from_storage("entities", context.output_storage)
+    logger.info(entities)
 
     max_cluster_size = config.cluster_graph.max_cluster_size
     use_lcc = config.cluster_graph.use_lcc

From b8e515b139808cab75de63ae10fd5bcea46484ad Mon Sep 17 00:00:00 2001
From: Gaudy Blanco Meneses
Date: Sat, 14 Feb 2026 21:43:38 -0600
Subject: [PATCH 2/6] cosmosdb output error fix

---
 .../graphrag_storage/azure_cosmos_storage.py  | 212 ++++++++++++------
 .../tables/cosmosdb_table_provider.py         | 110 +++++++++
 .../tables/table_provider_factory.py          |   7 +
 .../graphrag_storage/tables/table_type.py     |   1 +
 packages/graphrag/graphrag/data_model/dfs.py  |  15 +-
 .../graphrag/index/run/run_pipeline.py        |   7 +
 6 files changed, 274 insertions(+), 78 deletions(-)
 create mode 100644 packages/graphrag-storage/graphrag_storage/tables/cosmosdb_table_provider.py

diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
index a57b656aae..9828d81486 100644
--- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
+++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
@@ -25,6 +25,8 @@
 
 logger = logging.getLogger(__name__)
 
+_DEFAULT_PAGE_SIZE = 100
+
 
 class AzureCosmosStorage(Storage):
     """The CosmosDB-Storage Implementation."""
@@ -37,7 +39,7 @@ class AzureCosmosStorage(Storage):
     _database_name: str
     _container_name: str
     _encoding: str
-    _no_id_prefixes: set[str] = set()
+    _no_id_prefixes: set[str]
 
     def __init__(
         self,
@@ -51,7 +53,7 @@ def __init__(
         """Create a CosmosDB storage instance."""
         logger.info("Creating cosmosdb storage")
         database_name = database_name
-        if database_name is None:
+        if not database_name:
             msg = "CosmosDB Storage requires a base_dir to be specified. This is used as the database name."
             logger.error(msg)
             raise ValueError(msg)
@@ -150,12 +152,10 @@ def find(
             {"name": "@pattern", "value": file_pattern.pattern}
         ]
 
-        items = list(
-            self._container_client.query_items(
-                query=query,
-                parameters=parameters,
-                enable_cross_partition_query=True,
-            )
+        items = self._query_all_items(
+            self._container_client,
+            query=query,
+            parameters=parameters,
         )
         logger.debug("All items: %s", [item["id"] for item in items])
         num_loaded = 0
@@ -185,92 +185,107 @@ def find(
-    async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None) -> Any:
-        """Fetch all items in a container that match the given key."""
-        try:
-            if not self._database_client or not self._container_client:
+    async def get(
+        self, key: str, as_bytes: bool | None = None, encoding: str | None = None
+    ) -> Any:
+        """Fetch all items in a container that match the given key."""
+        try:
+            if not self._database_client or not self._container_client:
                 return None
-            if as_bytes:
-                prefix = self._get_prefix(key)
-                logger.info(f"Test Prefix: {prefix}")
-                query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')"  # noqa: S608
+            if as_bytes:
+                prefix = self._get_prefix(key)
+                query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')"  # noqa: S608
+                items_list = self._query_all_items(
+                    self._container_client,
+                    query=query,
+                )
 
-                queried_items = self._container_client.query_items( query=query, enable_cross_partition_query=True )
-                items_list = list(queried_items)
-
-                logger.info("Cosmos load prefix=%s count=%d", prefix, len(items_list))
-
-                if not items_list:
-                    logger.warning("No items found for prefix %s (key=%s)", prefix, key)
+                logger.info("Cosmos load prefix=%s count=%d", prefix, len(items_list))
+
+                if not items_list:
+                    logger.warning("No items found for prefix %s (key=%s)", prefix, key)
                     return None
 
-                for item in items_list:
-                    item["id"] = item["id"].split(":",1)[1]
-
-                items_json_str = json.dumps(items_list)
-                items_df = pd.read_json( StringIO(items_json_str), orient="records", lines=False)
-
-                if prefix == "entities":
-                    # Always preserve the Cosmos suffix for debugging/migrations
- logger.info("Entities loaded without entity_id; leaving id as cosmos suffix.") - - if items_df.empty: - logger.warning("No rows returned for prefix %s (key=%s)", prefix, key) - return None - return items_df.to_parquet() - item = self._container_client.read_item(item=key, partition_key=key) - item_body = item.get("body") - return json.dumps(item_body) - except Exception: # noqa: BLE001 - logger.warning("Error reading item %s", key) - return None + for item in items_list: + item["id"] = item["id"].split(":", 1)[1] + items_json_str = json.dumps(items_list) + items_df = pd.read_json( + StringIO(items_json_str), orient="records", lines=False + ) + + if prefix == "entities": + # Always preserve the Cosmos suffix for debugging/migrations + items_df["cosmos_id"] = items_df["id"] + items_df["id"] = items_df["id"].astype( + str + ) # Only restore pipeline UUID id if we actually have it + + if "human_readable_id" in items_df.columns: + # Fill any NaN values before converting to int + items_df["human_readable_id"] = ( + items_df["human_readable_id"] + .fillna(items_df["id"]) + .astype(int) + ) + else: + # Fresh run case: extract_graph entities may not have entity_id yet + # Keep id as the suffix (stable_key/index) for now. + logger.info( + "Entities loaded without entity_id; leaving id as cosmos suffix." + ) + + if items_df.empty: + logger.warning( + "No rows returned for prefix %s (key=%s)", prefix, key + ) + return None + return items_df.to_parquet() + item = self._container_client.read_item(item=key, partition_key=key) + item_body = item.get("body") + return json.dumps(item_body) + except Exception: + logger.exception("Error reading item %s", key) + return None async def set(self, key: str, value: Any, encoding: str | None = None) -> None: + """Write an item to Cosmos DB. If the value is bytes, we assume it's a parquet file and we write each row as a separate item with id formatted as {prefix}:{stable_key_or_index}.""" + if not self._database_client or not self._container_client: + error_msg = "Database or container not initialized. Cannot write item." 
+ raise ValueError(error_msg) try: - if not self._database_client or not self._container_client: - raise ValueError("Database or container not initialized") - if isinstance(value, bytes): prefix = self._get_prefix(key) value_df = pd.read_parquet(BytesIO(value)) - + # Decide once per dataframe df_has_id = "id" in value_df.columns - + # IMPORTANT: if we now have ids, undo the earlier "no id" marking if df_has_id: self._no_id_prefixes.discard(prefix) else: self._no_id_prefixes.add(prefix) - + cosmosdb_item_list = json.loads( value_df.to_json(orient="records", lines=False, force_ascii=False) ) - + for index, cosmosdb_item in enumerate(cosmosdb_item_list): if prefix == "entities": # Stable key for Cosmos identity stable_key = cosmosdb_item.get("human_readable_id", index) cosmos_id = f"{prefix}:{stable_key}" - + # If the pipeline provided a final UUID, store it separately if "id" in cosmosdb_item: cosmosdb_item["entity_id"] = cosmosdb_item["id"] - + # Cosmos identity must be stable and NEVER change cosmosdb_item["id"] = cosmos_id logger.info("Print ids") - logger.info(f"{cosmos_id}") + logger.info("%s", cosmos_id) else: # Original behavior for non-entity prefixes @@ -278,12 +293,12 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: cosmosdb_item["id"] = f"{prefix}:{cosmosdb_item['id']}" else: cosmosdb_item["id"] = f"{prefix}:{index}" - + self._container_client.upsert_item(body=cosmosdb_item) else: cosmosdb_item = {"id": key, "body": json.loads(value)} self._container_client.upsert_item(body=cosmosdb_item) - + except Exception: logger.exception("Error writing item %s", key) @@ -293,16 +308,66 @@ async def has(self, key: str) -> bool: return False if ".parquet" in key: prefix = self._get_prefix(key) - query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')" # noqa: S608 - queried_items = self._container_client.query_items( - query=query, enable_cross_partition_query=True + count = self._query_count( + self._container_client, + query_filter=f"STARTSWITH(c.id, '{prefix}:')", ) - return len(list(queried_items)) > 0 - query = f"SELECT * FROM c WHERE c.id = '{key}'" # noqa: S608 - queried_items = self._container_client.query_items( - query=query, enable_cross_partition_query=True + return count > 0 + count = self._query_count( + self._container_client, + query_filter=f"c.id = '{key}'", ) - return len(list(queried_items)) == 1 + return count >= 1 + + def _query_all_items( + self, + container_client: ContainerProxy, + query: str, + parameters: list[dict[str, Any]] | None = None, + page_size: int = _DEFAULT_PAGE_SIZE, + ) -> list[dict[str, Any]]: + """Fetch all items from a Cosmos DB query using pagination. + + This avoids the pitfalls of calling list() on the full pager, which can + time out or return incomplete results for large result sets. + """ + results: list[dict[str, Any]] = [] + query_kwargs: dict[str, Any] = { + "query": query, + "enable_cross_partition_query": True, + "max_item_count": page_size, + } + if parameters: + query_kwargs["parameters"] = parameters + + pager = container_client.query_items(**query_kwargs).by_page() + for page in pager: + results.extend(page) + return results + + def _query_count( + self, + container_client: ContainerProxy, + query_filter: str, + parameters: list[dict[str, Any]] | None = None, + ) -> int: + """Return the count of items matching a filter, without fetching them all. + + Parameters + ---------- + query_filter: + The WHERE clause (without 'WHERE'), e.g. "STARTSWITH(c.id, 'entities:')". 
+ """ + count_query = f"SELECT VALUE COUNT(1) FROM c WHERE {query_filter}" # noqa: S608 + query_kwargs: dict[str, Any] = { + "query": count_query, + "enable_cross_partition_query": True, + } + if parameters: + query_kwargs["parameters"] = parameters + + results = list(container_client.query_items(**query_kwargs)) + return int(results[0]) if results else 0 # type: ignore[arg-type] async def delete(self, key: str) -> None: """Delete all cosmosdb items belonging to the given filename key.""" @@ -311,11 +376,12 @@ async def delete(self, key: str) -> None: try: if ".parquet" in key: prefix = self._get_prefix(key) - query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')" # noqa: S608 - queried_items = self._container_client.query_items( - query=query, enable_cross_partition_query=True + query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')" # noqa: S608 + items = self._query_all_items( + self._container_client, + query=query, ) - for item in queried_items: + for item in items: self._container_client.delete_item( item=item["id"], partition_key=item["id"] ) diff --git a/packages/graphrag-storage/graphrag_storage/tables/cosmosdb_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/cosmosdb_table_provider.py new file mode 100644 index 0000000000..5ce3474d5b --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/cosmosdb_table_provider.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parquet-based table provider implementation.""" + +import logging +import re +from io import BytesIO + +import pandas as pd + +from graphrag_storage.storage import Storage +from graphrag_storage.tables.table_provider import TableProvider + +logger = logging.getLogger(__name__) + + +class CosmosDBTableProvider(TableProvider): + """Table provider that stores tables as Parquet files using an underlying Storage instance. + + This provider converts between pandas DataFrames and Parquet format, + storing the data through a Storage backend (file, blob, cosmos, etc.). + """ + + def __init__(self, storage: Storage, **kwargs) -> None: + """Initialize the Parquet table provider with an underlying storage instance. + + Args + ---- + storage: Storage + The storage instance to use for reading and writing Parquet files. + **kwargs: Any + Additional keyword arguments (currently unused). + """ + self._storage = storage + + async def read_dataframe(self, table_name: str) -> pd.DataFrame: + """Read a table from storage as a pandas DataFrame. + + Args + ---- + table_name: str + The name of the table to read. The file will be accessed as '{table_name}.parquet'. + + Returns + ------- + pd.DataFrame: + The table data loaded from the Parquet file. + + Raises + ------ + ValueError: + If the table file does not exist in storage. + Exception: + If there is an error reading or parsing the Parquet file. + """ + filename = f"{table_name}.parquet" + file_exists = await self._storage.has(filename) + + if not file_exists: + msg = f"Could not find {filename} in storage!" + raise ValueError(msg) + try: + logger.info("reading table from storage: %s", filename) + return pd.read_parquet( + BytesIO(await self._storage.get(filename, as_bytes=True)) + ) + except Exception: + logger.exception("error loading table from storage: %s", filename) + raise + + async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: + """Write a pandas DataFrame to storage as a Parquet file. + + Args + ---- + table_name: str + The name of the table to write. 
The file will be saved as '{table_name}.parquet'. + df: pd.DataFrame + The DataFrame to write to storage. + """ + await self._storage.set(f"{table_name}.parquet", df.to_parquet()) + + async def has(self, table_name: str) -> bool: + """Check if a table exists in storage. + + Args + ---- + table_name: str + The name of the table to check. + + Returns + ------- + bool: + True if the table exists, False otherwise. + """ + return await self._storage.has(f"{table_name}.parquet") + + def list(self) -> list[str]: + """List all table names in storage. + + Returns + ------- + list[str]: + List of table names (without .parquet extension). + """ + return [ + file.replace(".parquet", "") + for file in self._storage.find(re.compile(r"\.parquet$")) + ] diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py index d79c01e07b..b9583f55f4 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py +++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py @@ -60,6 +60,13 @@ def create_table_provider( if table_type not in table_provider_factory: match table_type: + case TableType.CosmosDB: + from graphrag_storage.tables.cosmosdb_table_provider import ( + CosmosDBTableProvider, + ) + + register_table_provider(TableType.CosmosDB, CosmosDBTableProvider) + case TableType.Parquet: from graphrag_storage.tables.parquet_table_provider import ( ParquetTableProvider, diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_type.py b/packages/graphrag-storage/graphrag_storage/tables/table_type.py index 3397390b77..bc24905f47 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/table_type.py +++ b/packages/graphrag-storage/graphrag_storage/tables/table_type.py @@ -12,3 +12,4 @@ class TableType(StrEnum): Parquet = "parquet" CSV = "csv" + CosmosDB = "cosmosdb" diff --git a/packages/graphrag/graphrag/data_model/dfs.py b/packages/graphrag/graphrag/data_model/dfs.py index d6d7e729fc..7d8ae1f5ab 100644 --- a/packages/graphrag/graphrag/data_model/dfs.py +++ b/packages/graphrag/graphrag/data_model/dfs.py @@ -28,6 +28,11 @@ ) +def _safe_int(series: pd.Series, fill: int = -1) -> pd.Series: + """Convert a series to int, filling NaN values first.""" + return series.fillna(fill).astype(int) + + def _split_list_column(value: Any) -> list[Any]: """Split a column containing a list string into an actual list.""" if isinstance(value, str): @@ -38,13 +43,13 @@ def _split_list_column(value: Any) -> list[Any]: def entities_typed(df: pd.DataFrame) -> pd.DataFrame: """Return the entities dataframe with correct types, in case it was stored in a weakly-typed format.""" if SHORT_ID in df.columns: - df[SHORT_ID] = df[SHORT_ID].astype(int) + df[SHORT_ID] = _safe_int(df[SHORT_ID]) if TEXT_UNIT_IDS in df.columns: df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column) if NODE_FREQUENCY in df.columns: - df[NODE_FREQUENCY] = df[NODE_FREQUENCY].astype(int) + df[NODE_FREQUENCY] = _safe_int(df[NODE_FREQUENCY], 0) if NODE_DEGREE in df.columns: - df[NODE_DEGREE] = df[NODE_DEGREE].astype(int) + df[NODE_DEGREE] = _safe_int(df[NODE_DEGREE], 0) return df @@ -52,11 +57,11 @@ def entities_typed(df: pd.DataFrame) -> pd.DataFrame: def relationships_typed(df: pd.DataFrame) -> pd.DataFrame: """Return the relationships dataframe with correct types, in case it was stored in a weakly-typed format.""" if SHORT_ID in df.columns: - df[SHORT_ID] = df[SHORT_ID].astype(int) + 
df[SHORT_ID] = _safe_int(df[SHORT_ID]) if EDGE_WEIGHT in df.columns: df[EDGE_WEIGHT] = df[EDGE_WEIGHT].astype(float) if EDGE_DEGREE in df.columns: - df[EDGE_DEGREE] = df[EDGE_DEGREE].astype(int) + df[EDGE_DEGREE] = _safe_int(df[EDGE_DEGREE], 0) if TEXT_UNIT_IDS in df.columns: df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column) diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index 68b757d60c..7a10cedb56 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -13,8 +13,10 @@ import pandas as pd from graphrag_cache import create_cache from graphrag_storage import create_storage +from graphrag_storage.storage_type import StorageType from graphrag_storage.tables.table_provider import TableProvider from graphrag_storage.tables.table_provider_factory import create_table_provider +from graphrag_storage.tables.table_type import TableType from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig @@ -39,6 +41,11 @@ async def run_pipeline( input_storage = create_storage(config.input_storage) output_storage = create_storage(config.output_storage) + + # Workaround to fix cosmosdb incompatibilities using a new provider + if config.output_storage.type == StorageType.AzureCosmos: + config.table_provider.type = TableType.CosmosDB + output_table_provider = create_table_provider(config.table_provider, output_storage) cache = create_cache(config.cache) From 3719b384df435cfbe24e6c183b61eefe3335808f Mon Sep 17 00:00:00 2001 From: Gaudy Blanco Meneses Date: Sat, 14 Feb 2026 21:49:18 -0600 Subject: [PATCH 3/6] semserver update --- .semversioner/next-release/patch-20260215034903124458.json | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .semversioner/next-release/patch-20260215034903124458.json diff --git a/.semversioner/next-release/patch-20260215034903124458.json b/.semversioner/next-release/patch-20260215034903124458.json new file mode 100644 index 0000000000..15c4511f45 --- /dev/null +++ b/.semversioner/next-release/patch-20260215034903124458.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "add support for cosmosdb output" +} From e0778ffe2ab83564ee5bffeeab4ef22c08800b3c Mon Sep 17 00:00:00 2001 From: Gaudy Blanco Meneses Date: Sun, 15 Feb 2026 00:06:10 -0600 Subject: [PATCH 4/6] remove unnecessary code --- .../tables/cosmosdb_table_provider.py | 110 ------------------ .../tables/table_provider_factory.py | 7 -- .../graphrag_storage/tables/table_type.py | 1 - .../graphrag/index/run/run_pipeline.py | 6 - 4 files changed, 124 deletions(-) delete mode 100644 packages/graphrag-storage/graphrag_storage/tables/cosmosdb_table_provider.py diff --git a/packages/graphrag-storage/graphrag_storage/tables/cosmosdb_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/cosmosdb_table_provider.py deleted file mode 100644 index 5ce3474d5b..0000000000 --- a/packages/graphrag-storage/graphrag_storage/tables/cosmosdb_table_provider.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License
-
-"""Parquet-based table provider implementation."""
-
-import logging
-import re
-from io import BytesIO
-
-import pandas as pd
-
-from graphrag_storage.storage import Storage
-from graphrag_storage.tables.table_provider import TableProvider
-
-logger = logging.getLogger(__name__)
-
-
-class CosmosDBTableProvider(TableProvider):
-    """Table provider that stores tables as Parquet files using an underlying Storage instance.
-
-    This provider converts between pandas DataFrames and Parquet format,
-    storing the data through a Storage backend (file, blob, cosmos, etc.).
-    """
-
-    def __init__(self, storage: Storage, **kwargs) -> None:
-        """Initialize the Parquet table provider with an underlying storage instance.
-
-        Args
-        ----
-        storage: Storage
-            The storage instance to use for reading and writing Parquet files.
-        **kwargs: Any
-            Additional keyword arguments (currently unused).
-        """
-        self._storage = storage
-
-    async def read_dataframe(self, table_name: str) -> pd.DataFrame:
-        """Read a table from storage as a pandas DataFrame.
-
-        Args
-        ----
-        table_name: str
-            The name of the table to read. The file will be accessed as '{table_name}.parquet'.
-
-        Returns
-        -------
-        pd.DataFrame:
-            The table data loaded from the Parquet file.
-
-        Raises
-        ------
-        ValueError:
-            If the table file does not exist in storage.
-        Exception:
-            If there is an error reading or parsing the Parquet file.
-        """
-        filename = f"{table_name}.parquet"
-        file_exists = await self._storage.has(filename)
-
-        if not file_exists:
-            msg = f"Could not find {filename} in storage!"
-            raise ValueError(msg)
-        try:
-            logger.info("reading table from storage: %s", filename)
-            return pd.read_parquet(
-                BytesIO(await self._storage.get(filename, as_bytes=True))
-            )
-        except Exception:
-            logger.exception("error loading table from storage: %s", filename)
-            raise
-
-    async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None:
-        """Write a pandas DataFrame to storage as a Parquet file.
-
-        Args
-        ----
-        table_name: str
-            The name of the table to write. The file will be saved as '{table_name}.parquet'.
-        df: pd.DataFrame
-            The DataFrame to write to storage.
-        """
-        await self._storage.set(f"{table_name}.parquet", df.to_parquet())
-
-    async def has(self, table_name: str) -> bool:
-        """Check if a table exists in storage.
-
-        Args
-        ----
-        table_name: str
-            The name of the table to check.
-
-        Returns
-        -------
-        bool:
-            True if the table exists, False otherwise.
-        """
-        return await self._storage.has(f"{table_name}.parquet")
-
-    def list(self) -> list[str]:
-        """List all table names in storage.
-
-        Returns
-        -------
-        list[str]:
-            List of table names (without .parquet extension).
-        """
-        return [
-            file.replace(".parquet", "")
-            for file in self._storage.find(re.compile(r"\.parquet$"))
-        ]
diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py
index b9583f55f4..d79c01e07b 100644
--- a/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py
+++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py
@@ -60,13 +60,6 @@ def create_table_provider(
 
     if table_type not in table_provider_factory:
         match table_type:
-            case TableType.CosmosDB:
-                from graphrag_storage.tables.cosmosdb_table_provider import (
-                    CosmosDBTableProvider,
-                )
-
-                register_table_provider(TableType.CosmosDB, CosmosDBTableProvider)
-
             case TableType.Parquet:
                 from graphrag_storage.tables.parquet_table_provider import (
                     ParquetTableProvider,
diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_type.py b/packages/graphrag-storage/graphrag_storage/tables/table_type.py
index bc24905f47..3397390b77 100644
--- a/packages/graphrag-storage/graphrag_storage/tables/table_type.py
+++ b/packages/graphrag-storage/graphrag_storage/tables/table_type.py
@@ -12,4 +12,3 @@ class TableType(StrEnum):
 
     Parquet = "parquet"
     CSV = "csv"
-    CosmosDB = "cosmosdb"
diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py
index 7a10cedb56..5f5620e547 100644
--- a/packages/graphrag/graphrag/index/run/run_pipeline.py
+++ b/packages/graphrag/graphrag/index/run/run_pipeline.py
@@ -13,10 +13,8 @@
 import pandas as pd
 from graphrag_cache import create_cache
 from graphrag_storage import create_storage
-from graphrag_storage.storage_type import StorageType
 from graphrag_storage.tables.table_provider import TableProvider
 from graphrag_storage.tables.table_provider_factory import create_table_provider
-from graphrag_storage.tables.table_type import TableType
 
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -42,10 +40,6 @@ async def run_pipeline(
 
     output_storage = create_storage(config.output_storage)
 
-    # Workaround to fix cosmosdb incompatibilities using a new provider
-    if config.output_storage.type == StorageType.AzureCosmos:
-        config.table_provider.type = TableType.CosmosDB
-
     output_table_provider = create_table_provider(config.table_provider, output_storage)
 
     cache = create_cache(config.cache)

From 8f05bc37180e5857937827c05458bcf1040a4226 Mon Sep 17 00:00:00 2001
From: Gaudy Blanco Meneses
Date: Sun, 15 Feb 2026 00:16:53 -0600
Subject: [PATCH 5/6] clean code

---
 packages/graphrag/graphrag/index/run/run_pipeline.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py
index 5f5620e547..55a4b00171 100644
--- a/packages/graphrag/graphrag/index/run/run_pipeline.py
+++ b/packages/graphrag/graphrag/index/run/run_pipeline.py
@@ -46,8 +46,6 @@ async def run_pipeline(
 
     # load existing state in case any workflows are stateful
     state_json = await output_storage.get("context.json")
-    logger.info("Printing state json")
-    logger.info(state_json)
     state = json.loads(state_json) if state_json else {}
 
     if additional_context:

From f7c7a13eb1a4d85efb13a1a3fc66b127073ecc07 Mon Sep 17 00:00:00 2001
From: Gaudy Blanco Meneses
Date: Tue, 17 Feb 2026 12:16:02 -0600
Subject: [PATCH 6/6] remove unnecessary prints

---
 .../graphrag-storage/graphrag_storage/azure_cosmos_storage.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
index 9828d81486..5423c216ed 100644
--- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
+++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
@@ -284,8 +284,6 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
 
                         # Cosmos identity must be stable and NEVER change
                         cosmosdb_item["id"] = cosmos_id
-                        logger.info("Print ids")
-                        logger.info("%s", cosmos_id)
 
                     else:
                         # Original behavior for non-entity prefixes
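
A note on the id scheme the series converges on, with a minimal sketch of the
round-trip (the helper names to_cosmos_id/from_cosmos_id are illustrative only,
not part of the patches; rows are assumed to be the dicts produced by
value_df.to_json(orient="records")):

    def to_cosmos_id(prefix: str, row: dict, index: int) -> str:
        # Entities key off human_readable_id so the Cosmos identity stays
        # stable across runs; any pipeline-assigned UUID is kept separately
        # under "entity_id" rather than overwriting the item id.
        if prefix == "entities":
            return f"{prefix}:{row.get('human_readable_id', index)}"
        # Other tables keep their own id when present, else use the row index.
        return f"{prefix}:{row['id']}" if "id" in row else f"{prefix}:{index}"

    def from_cosmos_id(cosmos_id: str) -> str:
        # Split on the first ':' only, since the suffix may itself contain colons.
        return cosmos_id.split(":", 1)[1]

The STARTSWITH filters match on '{prefix}:' with the trailing colon so one prefix
cannot match a longer prefix that shares its leading characters, and bulk reads go
through query_items(...).by_page() rather than materializing the pager with list(),
which is what _query_all_items implements above.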