diff --git a/.semversioner/next-release/patch-20260213160631396575.json b/.semversioner/next-release/patch-20260213160631396575.json new file mode 100644 index 0000000000..79a9221724 --- /dev/null +++ b/.semversioner/next-release/patch-20260213160631396575.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "add csv table smoke tests" +} diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py index 2561bde0d8..73b0540a53 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py @@ -70,7 +70,7 @@ async def read_dataframe(self, table_name: str) -> pd.DataFrame: # Handle empty CSV (pandas can't parse files with no columns) if not csv_data or csv_data.strip() == "": return pd.DataFrame() - return pd.read_csv(StringIO(csv_data)) + return pd.read_csv(StringIO(csv_data), keep_default_na=False) except Exception: logger.exception("error loading table from storage: %s", filename) raise diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index cc69e523ec..0ae2e06efc 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -21,9 +21,12 @@ 300 ], "max_runtime": 30, + "nan_allowed_columns": [ + "description" + ], "expected_artifacts": [ - "entities.parquet", - "relationships.parquet" + "entities.csv", + "relationships.csv" ] }, "create_communities": { @@ -32,7 +35,7 @@ 30 ], "max_runtime": 30, - "expected_artifacts": ["communities.parquet"] + "expected_artifacts": ["communities.csv"] }, "create_community_reports_text": { "row_range": [ @@ -51,7 +54,7 @@ "size" ], "max_runtime": 2000, - "expected_artifacts": ["community_reports.parquet"] + "expected_artifacts": ["community_reports.csv"] }, "create_final_text_units": { "row_range": [ @@ -64,7 +67,7 @@ "covariate_ids" ], "max_runtime": 30, - "expected_artifacts": ["text_units.parquet"] + "expected_artifacts": ["text_units.csv"] }, "create_final_documents": { "row_range": [ @@ -75,7 +78,7 @@ "raw_data" ], "max_runtime": 30, - "expected_artifacts": ["documents.parquet"] + "expected_artifacts": ["documents.csv"] }, "generate_text_embeddings": { "row_range": [ @@ -84,9 +87,9 @@ ], "max_runtime": 150, "expected_artifacts": [ - "embeddings.text_unit_text.parquet", - "embeddings.entity_description.parquet", - "embeddings.community_full_content.parquet" + "embeddings.text_unit_text.csv", + "embeddings.entity_description.csv", + "embeddings.community_full_content.csv" ] } }, diff --git a/tests/fixtures/text/settings.yml b/tests/fixtures/text/settings.yml index 6cf6f9074d..9f18f7680d 100644 --- a/tests/fixtures/text/settings.yml +++ b/tests/fixtures/text/settings.yml @@ -29,6 +29,9 @@ vector_store: api_key: ${AZURE_AI_SEARCH_API_KEY} container_name: "simple_text_ci" +table_provider: + type: csv + community_reports: prompt: "prompts/community_report.txt" max_length: 2000 diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index 7d43b0140a..8624930c14 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -178,25 +178,29 @@ def __assert_indexer_outputs( for artifact in workflow_artifacts: if artifact.endswith(".parquet"): output_df = pd.read_parquet(output_path / artifact) - - # Check number of rows between range - assert ( - config["row_range"][0] - <= len(output_df) - <= config["row_range"][1] - ), ( - f"Expected between {config['row_range'][0]} and {config['row_range'][1]}, found: {len(output_df)} for file: {artifact}" + elif artifact.endswith(".csv"): + output_df = pd.read_csv( + output_path / artifact, keep_default_na=False ) + else: + continue + + # Check number of rows between range + assert ( + config["row_range"][0] <= len(output_df) <= config["row_range"][1] + ), ( + f"Expected between {config['row_range'][0]} and {config['row_range'][1]}, found: {len(output_df)} for file: {artifact}" + ) - # Get non-nan rows - nan_df = output_df.loc[ - :, - ~output_df.columns.isin(config.get("nan_allowed_columns", [])), - ] - nan_df = nan_df[nan_df.isna().any(axis=1)] - assert len(nan_df) == 0, ( - f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}" - ) + # Get non-nan rows + nan_df = output_df.loc[ + :, + ~output_df.columns.isin(config.get("nan_allowed_columns", [])), + ] + nan_df = nan_df[nan_df.isna().any(axis=1)] + assert len(nan_df) == 0, ( + f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}" + ) def __run_query(self, root: Path, query_config: dict[str, str]): command = [