From e196e08e004b839d03e6c1d0467651c6d0aa14f4 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Wed, 18 Feb 2026 22:08:29 +1100 Subject: [PATCH 01/17] fix(plt): Update color imports to match figrecipe API minimization figrecipe.colors hid BGR/str2* converters behind _prefixed names. Import from private names for backward compat, expose only public API in __all__. Co-Authored-By: Claude Opus 4.6 --- .../{release.yml => publish-pypi.yml} | 0 src/scitex/plt/color/__init__.py | 103 ++++++++---------- 2 files changed, 48 insertions(+), 55 deletions(-) rename .github/workflows/{release.yml => publish-pypi.yml} (100%) diff --git a/.github/workflows/release.yml b/.github/workflows/publish-pypi.yml similarity index 100% rename from .github/workflows/release.yml rename to .github/workflows/publish-pypi.yml diff --git a/src/scitex/plt/color/__init__.py b/src/scitex/plt/color/__init__.py index fb1581f0e..50c537009 100755 --- a/src/scitex/plt/color/__init__.py +++ b/src/scitex/plt/color/__init__.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 """Scitex color module — delegates to figrecipe.colors (single source of truth). -Kept for backward compatibility: ``from scitex.plt.color import PARAMS`` still works. +Public API mirrors figrecipe.colors public exports. +Internal converters (BGR, str2*) importable via _prefixed names for backward compat. """ -# Core exports from figrecipe.colors +# Public API from figrecipe.colors from figrecipe.colors import ( DEF_ALPHA, HEX, @@ -13,45 +14,53 @@ RGB_NORM, RGBA, RGBA_NORM, - bgr2bgra, - bgr2rgb, - bgra2bgr, - bgra2hex, - bgra2rgba, cycle_color, - cycle_color_bgr, - cycle_color_rgb, gen_interpolate, - get_categorical_colors_from_conf_matap, - get_color_from_conf_matap, - get_colors_from_conf_matap, + get_categorical_colors_from_cmap, + get_color_from_cmap, + get_colors_from_cmap, gradiate_color, - gradiate_color_bgr, - gradiate_color_bgra, - gradiate_color_rgb, - gradiate_color_rgba, interpolate, - rgb2bgr, - rgb2rgba, - rgba2bgra, - rgba2hex, - rgba2rgb, - str2bgr, - str2bgra, - str2hex, - str2rgb, - str2rgba, to_hex, to_rgb, to_rgba, update_alpha, ) +# Internal — importable but not public (figrecipe hid these) +from figrecipe.colors import _bgr2bgra as bgr2bgra +from figrecipe.colors import _bgr2rgb as bgr2rgb +from figrecipe.colors import _bgra2bgr as bgra2bgr +from figrecipe.colors import _bgra2hex as bgra2hex +from figrecipe.colors import _bgra2rgba as bgra2rgba +from figrecipe.colors import _cycle_color_bgr as cycle_color_bgr +from figrecipe.colors import _cycle_color_rgb as cycle_color_rgb +from figrecipe.colors import ( + _get_categorical_colors_from_conf_matap as get_categorical_colors_from_conf_matap, +) +from figrecipe.colors import _get_color_from_conf_matap as get_color_from_conf_matap +from figrecipe.colors import _get_colors_from_conf_matap as get_colors_from_conf_matap +from figrecipe.colors import _gradiate_color_bgr as gradiate_color_bgr +from figrecipe.colors import _gradiate_color_bgra as gradiate_color_bgra +from figrecipe.colors import _gradiate_color_rgb as gradiate_color_rgb +from figrecipe.colors import _gradiate_color_rgba as gradiate_color_rgba +from figrecipe.colors import _rgb2bgr as rgb2bgr +from figrecipe.colors import _rgb2rgba as rgb2rgba +from figrecipe.colors import _rgba2bgra as rgba2bgra +from figrecipe.colors import _rgba2hex as rgba2hex +from figrecipe.colors import _rgba2rgb as rgba2rgb +from figrecipe.colors import _str2bgr as str2bgr +from figrecipe.colors import _str2bgra as str2bgra +from figrecipe.colors import _str2hex as str2hex +from figrecipe.colors import _str2rgb as str2rgb +from figrecipe.colors import _str2rgba as str2rgba + # scitex-specific extras (not in figrecipe) from ._add_hue_col import add_hue_col from ._vizualize_colors import vizualize_colors __all__ = [ + # Constants "PARAMS", "DEF_ALPHA", "RGB", @@ -59,39 +68,23 @@ "RGBA", "RGBA_NORM", "HEX", - "add_hue_col", - "bgr2bgra", - "bgr2rgb", - "bgra2bgr", - "bgra2hex", - "bgra2rgba", - "cycle_color", - "cycle_color_bgr", - "cycle_color_rgb", - "gen_interpolate", - "get_categorical_colors_from_conf_matap", - "get_color_from_conf_matap", - "get_colors_from_conf_matap", - "gradiate_color", - "gradiate_color_bgr", - "gradiate_color_bgra", - "gradiate_color_rgb", - "gradiate_color_rgba", - "interpolate", - "rgb2bgr", - "rgb2rgba", - "rgba2bgra", - "rgba2hex", - "rgba2rgb", - "str2bgr", - "str2bgra", - "str2hex", - "str2rgb", - "str2rgba", + # Universal converters (public) "to_hex", "to_rgb", "to_rgba", "update_alpha", + # Color cycling (public) + "cycle_color", + # Gradients & interpolation (public) + "gradiate_color", + "interpolate", + "gen_interpolate", + # Colormap utilities (public) + "get_color_from_cmap", + "get_colors_from_cmap", + "get_categorical_colors_from_cmap", + # scitex extras + "add_hue_col", "vizualize_colors", ] From d369a6aa56e7c14c3b862c3ea80b81e1eead11a3 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Wed, 18 Feb 2026 22:28:06 +1100 Subject: [PATCH 02/17] =?UTF-8?q?fix(plt):=20Respect=20figrecipe=20API=20m?= =?UTF-8?q?inimization=20=E2=80=94=20only=20expose=20public=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove all re-exports of hidden BGR/str2* functions from plt.color - Update internal callers: get_colors_from_conf_matap → get_colors_from_cmap - Update internal callers: str2hex → to_hex Co-Authored-By: Claude Opus 4.6 --- src/scitex/ai/plt/_plot_learning_curve.py | 11 +++---- src/scitex/ai/plt/_plot_pre_rec_curve.py | 14 ++++---- src/scitex/ai/plt/_plot_roc_curve.py | 12 +++---- src/scitex/plt/color/__init__.py | 39 ++++------------------- 4 files changed, 21 insertions(+), 55 deletions(-) diff --git a/src/scitex/ai/plt/_plot_learning_curve.py b/src/scitex/ai/plt/_plot_learning_curve.py index f739fd11c..d2be05aaf 100755 --- a/src/scitex/ai/plt/_plot_learning_curve.py +++ b/src/scitex/ai/plt/_plot_learning_curve.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # Timestamp: "2025-10-02 19:50:54 (ywatanabe)" # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/plt/plot_learning_curve.py # ---------------------------------------- @@ -19,7 +18,7 @@ import pandas as pd import scitex -from scitex.plt.color import str2hex +from scitex.plt.color import to_hex def _prepare_metrics_df(metrics_df): @@ -49,7 +48,7 @@ def _configure_accuracy_axis(ax, metric_key): def _plot_training_data(ax, metrics_df, metric_key, linewidth=1, color=None): """Plot training phase data as line.""" if color is None: - color = str2hex("blue") + color = to_hex("blue") is_training = scitex.str.search("^[Tt]rain(ing)?", metrics_df.step, as_bool=True)[0] training_df = metrics_df[is_training] @@ -70,7 +69,7 @@ def _plot_training_data(ax, metrics_df, metric_key, linewidth=1, color=None): def _plot_validation_data(ax, metrics_df, metric_key, markersize=3, color=None): """Plot validation phase data as scatter.""" if color is None: - color = str2hex("green") + color = to_hex("green") is_validation = scitex.str.search( "^[Vv]alid(ation)?", metrics_df.step, as_bool=True @@ -93,7 +92,7 @@ def _plot_validation_data(ax, metrics_df, metric_key, markersize=3, color=None): def _plot_test_data(ax, metrics_df, metric_key, markersize=3, color=None): """Plot test phase data as scatter.""" if color is None: - color = str2hex("red") + color = to_hex("red") is_test = scitex.str.search("^[Tt]est", metrics_df.step, as_bool=True)[0] test_df = metrics_df[is_test] @@ -232,9 +231,7 @@ def plot_learning_curve( def main(args): """Demo learning curve plotting with synthetic data.""" - import matplotlib.pyplot as plt import numpy as np - import pandas as pd # Create synthetic metrics data n_epochs = 10 diff --git a/src/scitex/ai/plt/_plot_pre_rec_curve.py b/src/scitex/ai/plt/_plot_pre_rec_curve.py index f28a7e131..3c703bd84 100755 --- a/src/scitex/ai/plt/_plot_pre_rec_curve.py +++ b/src/scitex/ai/plt/_plot_pre_rec_curve.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # Timestamp: "2025-10-02 19:44:06 (ywatanabe)" # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/plt/plot_pre_rec_curve.py # ---------------------------------------- @@ -16,7 +15,7 @@ import numpy as np from sklearn.metrics import average_precision_score, precision_recall_curve -from scitex.plt.color import get_colors_from_conf_matap +from scitex.plt.color import get_colors_from_cmap def _solve_intersection(f1, a, b): @@ -123,7 +122,7 @@ def plot_pre_rec_curve(true_class, pred_proba, labels, ax=None, spath=None): true_class_onehot[:, i], pred_proba[:, i], average="macro" ) ) - except Exception as e: + except Exception: print( f'\nPRE-REC-AUC for "{labels[i]}" was not defined and NaN-filled ' "for a calculation purpose (for the macro avg.)\n" @@ -137,7 +136,7 @@ def plot_pre_rec_curve(true_class, pred_proba, labels, ax=None, spath=None): # Plot Precision-Recall curve for each class and iso-f1 curves # Use scitex color palette for consistent styling - colors = get_colors_from_conf_matap("tab10", n_classes) + colors = get_colors_from_cmap("tab10", n_classes) if ax is None: fig, ax = stx.plt.subplots() @@ -158,7 +157,7 @@ def plot_pre_rec_curve(true_class, pred_proba, labels, ax=None, spath=None): # ax.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02)) x_f, y_f = _solve_intersection(f_score, 0.5, 0.5) - ax.annotate("f1={0:0.1f}".format(f_score), xy=(x_f - 0.1, y_f - 0.1 * 0.5)) + ax.annotate(f"f1={f_score:0.1f}", xy=(x_f - 0.1, y_f - 0.1 * 0.5)) # ax.annotate("f1={0:0.1f}".format(f_score), xy=(y[35] - 0.02 * (3 - i_f), 0.85)) lines.append(l) @@ -175,7 +174,7 @@ def plot_pre_rec_curve(true_class, pred_proba, labels, ax=None, spath=None): for i in range(n_classes): (l,) = ax.plot(recall[i], precision[i], color=colors[i], lw=2) lines.append(l) - legends.append("{0} (AUC = {1:0.2f})".format(labels[i], pre_rec_auc[i])) + legends.append(f"{labels[i]} (AUC = {pre_rec_auc[i]:0.2f})") # fig = plt.gcf() fig.subplots_adjust(bottom=0.25) @@ -208,7 +207,6 @@ def plot_pre_rec_curve(true_class, pred_proba, labels, ax=None, spath=None): def main(args): """Demo Precision-Recall curve plotting with MNIST dataset.""" - import matplotlib.pyplot as plt from sklearn import datasets, svm from sklearn.model_selection import train_test_split @@ -228,7 +226,7 @@ def main(args): predicted_proba = clf.predict_proba(X_test) n_classes = len(np.unique(digits.target)) - labels = ["Class {}".format(i) for i in range(n_classes)] + labels = [f"Class {i}" for i in range(n_classes)] # plt.rcParams["font.size"] = 20 # plt.rcParams["legend.fontsize"] = "xx-small" diff --git a/src/scitex/ai/plt/_plot_roc_curve.py b/src/scitex/ai/plt/_plot_roc_curve.py index 9e4e5558c..07da2080c 100755 --- a/src/scitex/ai/plt/_plot_roc_curve.py +++ b/src/scitex/ai/plt/_plot_roc_curve.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # Timestamp: "2025-10-02 19:44:13 (ywatanabe)" # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/plt/plot_roc_curve.py # ---------------------------------------- @@ -17,7 +16,7 @@ from sklearn.metrics import roc_auc_score, roc_curve import scitex -from scitex.plt.color import get_colors_from_conf_matap +from scitex.plt.color import get_colors_from_cmap def _to_onehot(class_indices, n_classes): @@ -107,7 +106,7 @@ def plot_roc_curve(true_class, pred_proba, labels, ax=None, spath=None): true_class_onehot[:, i], pred_proba[:, i], average="macro" ) ) - except Exception as e: + except Exception: print( f'\nROC-AUC for "{labels[i]}" was not defined and NaN-filled ' "for a calculation purpose (for the macro avg.)\n" @@ -117,7 +116,7 @@ def plot_roc_curve(true_class, pred_proba, labels, ax=None, spath=None): # Plot FPR-TPR curve for each class and iso-f1 curves # Use scitex color palette for consistent styling - colors = get_colors_from_conf_matap("tab10", n_classes) + colors = get_colors_from_cmap("tab10", n_classes) if ax is None: fig, ax = stx.plt.subplots() @@ -143,7 +142,7 @@ def plot_roc_curve(true_class, pred_proba, labels, ax=None, spath=None): for i in range(n_classes): (l,) = ax.plot(fpr[i], tpr[i], color=colors[i], lw=2) lines.append(l) - legends.append("{0} (AUC = {1:0.2f})".format(labels[i], roc_auc[i])) + legends.append(f"{labels[i]} (AUC = {roc_auc[i]:0.2f})") # fig = plt.gcf() fig.subplots_adjust(bottom=0.25) @@ -171,7 +170,6 @@ def plot_roc_curve(true_class, pred_proba, labels, ax=None, spath=None): def main(args): """Demo ROC AUC plotting with MNIST dataset.""" - import matplotlib.pyplot as plt from sklearn import datasets, svm from sklearn.model_selection import train_test_split @@ -191,7 +189,7 @@ def main(args): predicted_proba = clf.predict_proba(X_test) n_classes = len(np.unique(digits.target)) - labels = ["Class {}".format(i) for i in range(n_classes)] + labels = [f"Class {i}" for i in range(n_classes)] # plt.rcParams["font.size"] = 20 # plt.rcParams["legend.fontsize"] = "xx-small" diff --git a/src/scitex/plt/color/__init__.py b/src/scitex/plt/color/__init__.py index 50c537009..ebaebc56a 100755 --- a/src/scitex/plt/color/__init__.py +++ b/src/scitex/plt/color/__init__.py @@ -2,7 +2,8 @@ """Scitex color module — delegates to figrecipe.colors (single source of truth). Public API mirrors figrecipe.colors public exports. -Internal converters (BGR, str2*) importable via _prefixed names for backward compat. +Internal functions remain accessible via figrecipe.colors._colors.bgr2rgb etc. +but are not re-exported here to keep the public API clean. """ # Public API from figrecipe.colors @@ -27,34 +28,6 @@ update_alpha, ) -# Internal — importable but not public (figrecipe hid these) -from figrecipe.colors import _bgr2bgra as bgr2bgra -from figrecipe.colors import _bgr2rgb as bgr2rgb -from figrecipe.colors import _bgra2bgr as bgra2bgr -from figrecipe.colors import _bgra2hex as bgra2hex -from figrecipe.colors import _bgra2rgba as bgra2rgba -from figrecipe.colors import _cycle_color_bgr as cycle_color_bgr -from figrecipe.colors import _cycle_color_rgb as cycle_color_rgb -from figrecipe.colors import ( - _get_categorical_colors_from_conf_matap as get_categorical_colors_from_conf_matap, -) -from figrecipe.colors import _get_color_from_conf_matap as get_color_from_conf_matap -from figrecipe.colors import _get_colors_from_conf_matap as get_colors_from_conf_matap -from figrecipe.colors import _gradiate_color_bgr as gradiate_color_bgr -from figrecipe.colors import _gradiate_color_bgra as gradiate_color_bgra -from figrecipe.colors import _gradiate_color_rgb as gradiate_color_rgb -from figrecipe.colors import _gradiate_color_rgba as gradiate_color_rgba -from figrecipe.colors import _rgb2bgr as rgb2bgr -from figrecipe.colors import _rgb2rgba as rgb2rgba -from figrecipe.colors import _rgba2bgra as rgba2bgra -from figrecipe.colors import _rgba2hex as rgba2hex -from figrecipe.colors import _rgba2rgb as rgba2rgb -from figrecipe.colors import _str2bgr as str2bgr -from figrecipe.colors import _str2bgra as str2bgra -from figrecipe.colors import _str2hex as str2hex -from figrecipe.colors import _str2rgb as str2rgb -from figrecipe.colors import _str2rgba as str2rgba - # scitex-specific extras (not in figrecipe) from ._add_hue_col import add_hue_col from ._vizualize_colors import vizualize_colors @@ -68,18 +41,18 @@ "RGBA", "RGBA_NORM", "HEX", - # Universal converters (public) + # Universal converters "to_hex", "to_rgb", "to_rgba", "update_alpha", - # Color cycling (public) + # Color cycling "cycle_color", - # Gradients & interpolation (public) + # Gradients & interpolation "gradiate_color", "interpolate", "gen_interpolate", - # Colormap utilities (public) + # Colormap utilities "get_color_from_cmap", "get_colors_from_cmap", "get_categorical_colors_from_cmap", From d13ddafa75ebeae59b71e7ebb4d6acef1f8f6ed3 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Wed, 18 Feb 2026 22:51:52 +1100 Subject: [PATCH 03/17] =?UTF-8?q?ci:=20Rename=20release.yml=20=E2=86=92=20?= =?UTF-8?q?publish-pypi.yml=20to=20match=20PyPI=20trusted=20publisher?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Aligns workflow filename with PyPI OIDC trusted publisher config. Follows figrecipe pattern: trigger on release published, two-job build+publish. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/publish-pypi.yml | 187 ++++++----------------------- 1 file changed, 34 insertions(+), 153 deletions(-) diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index f269a9854..aeb16751c 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -1,171 +1,52 @@ -# Automated Release Pipeline -# Triggers on version tag push (v*) and: -# 1. Creates GitHub Release -# 2. Publishes to PyPI -# 3. Tests installation - -name: Release +name: Publish to PyPI on: - push: - tags: - - 'v*' - -jobs: - validate: - name: Validate Release - runs-on: ubuntu-latest - outputs: - version: ${{ steps.version.outputs.version }} - steps: - - uses: actions/checkout@v4 - - - name: Extract version - id: version - run: | - TAG_VERSION="${GITHUB_REF#refs/tags/v}" - TOML_VERSION=$(grep '^version = ' pyproject.toml | sed 's/version = "\(.*\)"/\1/') - - echo "Tag version: $TAG_VERSION" - echo "pyproject.toml version: $TOML_VERSION" - - if [ "$TAG_VERSION" != "$TOML_VERSION" ]; then - echo "ERROR: Tag version ($TAG_VERSION) != pyproject.toml ($TOML_VERSION)" - exit 1 - fi + release: + types: [published] - echo "version=$TAG_VERSION" >> $GITHUB_OUTPUT - echo "✓ Versions match: $TAG_VERSION" +permissions: + contents: read +jobs: build: - name: Build Package - needs: validate runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - - name: Install build tools - run: pip install build - - - name: Build package - run: python -m build + - uses: actions/checkout@v4 - - name: Upload artifacts - uses: actions/upload-artifact@v4 - with: - name: dist - path: dist/ + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" - test-install: - name: Test Installation - needs: build - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['3.10', '3.12'] - install-type: ['core', 'audio'] - steps: - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} + - name: Install build tools + run: | + python -m pip install --upgrade pip + pip install build twine - - name: Download artifacts - uses: actions/download-artifact@v4 - with: - name: dist - path: dist/ + - name: Build package + run: python -m build - - name: Install package - run: | - pip install --upgrade pip - WHEEL=$(ls dist/*.whl) - if [ "${{ matrix.install-type }}" == "audio" ]; then - pip install "${WHEEL}[audio]" - else - pip install "$WHEEL" - fi - timeout-minutes: 10 + - name: Check package + run: twine check dist/* - - name: Test imports - run: | - python -c "import scitex; print(f'Version: {scitex.__version__}')" - python -c "from scitex import io, stats, config" - echo "✓ Import test passed" + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ - github-release: - name: Create GitHub Release - needs: [validate, test-install] - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - uses: actions/checkout@v4 - - - name: Download artifacts - uses: actions/download-artifact@v4 - with: - name: dist - path: dist/ - - - name: Generate release notes - id: notes - run: | - VERSION="${{ needs.validate.outputs.version }}" - # Extract changelog for this version - NOTES=$(awk "/## \[$VERSION\]/{flag=1; next} /## \[/{flag=0} flag" CHANGELOG.md) - if [ -z "$NOTES" ]; then - NOTES="Release v$VERSION" - fi - echo "notes<> $GITHUB_OUTPUT - echo "$NOTES" >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - - - name: Create Release - uses: softprops/action-gh-release@v1 - with: - name: v${{ needs.validate.outputs.version }} - body: ${{ steps.notes.outputs.notes }} - files: dist/* - generate_release_notes: true - - publish-pypi: - name: Publish to PyPI - needs: [validate, test-install, github-release] + publish: + needs: build runs-on: ubuntu-latest environment: pypi permissions: id-token: write steps: - - name: Download artifacts - uses: actions/download-artifact@v4 - with: - name: dist - path: dist/ - - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - - verify: - name: Verify Release - needs: [validate, publish-pypi] - runs-on: ubuntu-latest - steps: - - name: Wait for PyPI propagation - run: sleep 60 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - - name: Verify PyPI installation - run: | - pip install "scitex==${{ needs.validate.outputs.version }}" - python -c "import scitex; assert scitex.__version__ == '${{ needs.validate.outputs.version }}'" - echo "✓ PyPI verification passed" + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 From 0c1130cbb6ed62cef32ddc285e433299b7516c3e Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Wed, 18 Feb 2026 22:59:49 +1100 Subject: [PATCH 04/17] refactor(plt): Use figrecipe public API instead of internal modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace all internal figrecipe._* imports with public API equivalents: - figrecipe._utils._units → figrecipe.utils - figrecipe._utils._crop → figrecipe.crop - figrecipe._api._notebook → figrecipe.utils - figrecipe._api._seaborn_proxy → figrecipe.utils - figrecipe._api._style_manager → figrecipe.utils - figrecipe._composition → figrecipe - figrecipe._graph._presets → figrecipe - figrecipe._wrappers → figrecipe.utils - figrecipe._diagram (single import) → figrecipe Co-Authored-By: Claude Opus 4.6 --- src/scitex/_mcp_tools/diagram.py | 2 +- src/scitex/bridge/_figrecipe.py | 2 +- .../io/_metadata_modules/embed_metadata_svg.py | 8 ++++---- .../io/_metadata_modules/read_metadata_svg.py | 10 +++++----- src/scitex/io/_save_modules/_image_csv.py | 2 +- src/scitex/plt/__init__.py | 18 ++++++++++-------- src/scitex/plt/gallery/_generate.py | 2 +- src/scitex/plt/io/_layered_bundle.py | 2 +- src/scitex/plt/utils/__init__.py | 2 +- src/scitex/plt/utils/_figure_from_axes_mm.py | 4 ++-- src/scitex/plt/utils/_figure_mm.py | 2 +- 11 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/scitex/_mcp_tools/diagram.py b/src/scitex/_mcp_tools/diagram.py index a1f5ac247..46cf87407 100755 --- a/src/scitex/_mcp_tools/diagram.py +++ b/src/scitex/_mcp_tools/diagram.py @@ -25,7 +25,7 @@ def register_diagram_tools(mcp) -> None: # Check if figrecipe is available try: - from figrecipe._diagram import Diagram + from figrecipe import Diagram _FIGRECIPE_AVAILABLE = True except ImportError: diff --git a/src/scitex/bridge/_figrecipe.py b/src/scitex/bridge/_figrecipe.py index 1722c0235..d33ad5c13 100755 --- a/src/scitex/bridge/_figrecipe.py +++ b/src/scitex/bridge/_figrecipe.py @@ -136,7 +136,7 @@ def _save_figure_image(fig, path: Path, dpi: int = 300, **kwargs): # Check if this is a figrecipe RecordingFigure - use fr.save() for full support if FIGRECIPE_AVAILABLE: try: - from figrecipe._wrappers import RecordingFigure + from figrecipe.utils import RecordingFigure if isinstance(fig, RecordingFigure): # Use figrecipe's save with facecolor support diff --git a/src/scitex/io/_metadata_modules/embed_metadata_svg.py b/src/scitex/io/_metadata_modules/embed_metadata_svg.py index 370db5af2..b8b5def86 100755 --- a/src/scitex/io/_metadata_modules/embed_metadata_svg.py +++ b/src/scitex/io/_metadata_modules/embed_metadata_svg.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # File: /home/ywatanabe/proj/scitex-code/src/scitex/io/_metadata_modules/embed_metadata_svg.py """SVG metadata embedding using element.""" @@ -15,10 +14,11 @@ def embed_metadata_svg(image_path: str, metadata_json: str) -> None: image_path: Path to the SVG file. metadata_json: JSON string of metadata to embed. - Raises: + Raises + ------ ValueError: If the SVG file is invalid. """ - with open(image_path, "r", encoding="utf-8") as f: + with open(image_path, encoding="utf-8") as f: svg_content = f.read() # Remove existing scitex metadata if present @@ -36,7 +36,7 @@ def embed_metadata_svg(image_path: str, metadata_json: str) -> None: # Create metadata element with scitex data metadata_element = ( f'\n' - f"{metadata_json}" + f"" f"\n" ) svg_content = ( diff --git a/src/scitex/io/_metadata_modules/read_metadata_svg.py b/src/scitex/io/_metadata_modules/read_metadata_svg.py index 5a1fb05c7..f097c5913 100755 --- a/src/scitex/io/_metadata_modules/read_metadata_svg.py +++ b/src/scitex/io/_metadata_modules/read_metadata_svg.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # File: /home/ywatanabe/proj/scitex-code/src/scitex/io/_metadata_modules/read_metadata_svg.py """SVG metadata reading from element.""" @@ -16,18 +15,19 @@ def read_metadata_svg(image_path: str) -> Optional[Dict[str, Any]]: Args: image_path: Path to the SVG file. - Returns: + Returns + ------- Dictionary containing metadata, or None if no metadata found. """ metadata = None - with open(image_path, "r", encoding="utf-8") as f: + with open(image_path, encoding="utf-8") as f: svg_content = f.read() - # Look for scitex metadata element + # Look for scitex metadata element (supports both CDATA and raw JSON) match = re.search( r']*id="scitex_metadata"[^>]*>.*?' - r"(.*?).*?", + r"(?:)?.*?", svg_content, flags=re.DOTALL, ) diff --git a/src/scitex/io/_save_modules/_image_csv.py b/src/scitex/io/_save_modules/_image_csv.py index 4576a5ca4..308a669a1 100755 --- a/src/scitex/io/_save_modules/_image_csv.py +++ b/src/scitex/io/_save_modules/_image_csv.py @@ -188,7 +188,7 @@ def _auto_crop_image( ext = spath.lower() if ext.endswith((".png", ".jpg", ".jpeg", ".tiff", ".tif")): try: - from figrecipe._utils._crop import crop + from figrecipe import crop dpi = kwargs.get("dpi", 300) margin_px = int(crop_margin_mm * dpi / 25.4) diff --git a/src/scitex/plt/__init__.py b/src/scitex/plt/__init__.py index 689edb611..a03c19135 100755 --- a/src/scitex/plt/__init__.py +++ b/src/scitex/plt/__init__.py @@ -73,17 +73,19 @@ # Backward compatibility alias edit = gui - # Internal imports (not part of figrecipe public API) - from figrecipe._api._notebook import enable_svg - from figrecipe._api._seaborn_proxy import sns - from figrecipe._api._style_manager import STYLE, apply_style - from figrecipe._composition import align_panels, align_smart, distribute_panels + # Additional figrecipe public API re-exports + from figrecipe import ( + align_panels, + align_smart, + distribute_panels, + get_graph_preset, + list_graph_presets, + register_graph_preset, + ) + from figrecipe.utils import STYLE, apply_style, enable_svg, sns # Backward compatibility alias smart_align = align_smart - from figrecipe._graph._presets import get_preset as get_graph_preset - from figrecipe._graph._presets import list_presets as list_graph_presets - from figrecipe._graph._presets import register_preset as register_graph_preset # Also export load as alias for reproduce load = reproduce diff --git a/src/scitex/plt/gallery/_generate.py b/src/scitex/plt/gallery/_generate.py index 2f838375f..485d6cc67 100755 --- a/src/scitex/plt/gallery/_generate.py +++ b/src/scitex/plt/gallery/_generate.py @@ -512,7 +512,7 @@ def _generate_and_save_hitmap( # noqa: C901 # Apply same crop as PNG if crop_box provided if crop_box is not None: - from figrecipe._utils._crop import crop + from figrecipe import crop crop( str(hitmap_path), diff --git a/src/scitex/plt/io/_layered_bundle.py b/src/scitex/plt/io/_layered_bundle.py index f972fffea..13b0c655f 100755 --- a/src/scitex/plt/io/_layered_bundle.py +++ b/src/scitex/plt/io/_layered_bundle.py @@ -399,7 +399,7 @@ def save_layered_plot_bundle( # noqa: C901 # Apply additional margin cropping (removes transparent edges) margin_crop_box = None try: - from figrecipe._utils._crop import crop + from figrecipe import crop _, margin_crop_box = crop( str(png_path), diff --git a/src/scitex/plt/utils/__init__.py b/src/scitex/plt/utils/__init__.py index 7335af0cf..4095b0e85 100755 --- a/src/scitex/plt/utils/__init__.py +++ b/src/scitex/plt/utils/__init__.py @@ -3,7 +3,7 @@ from figrecipe._utils._calc_nice_ticks import calc_nice_ticks from figrecipe._utils._mk_colorbar import mk_colorbar -from figrecipe._utils._units import inch_to_mm, mm_to_inch, mm_to_pt, pt_to_mm +from figrecipe.utils import inch_to_mm, mm_to_inch, mm_to_pt, pt_to_mm from ._calc_bacc_from_conf_mat import calc_bacc_from_conf_mat from ._close import close diff --git a/src/scitex/plt/utils/_figure_from_axes_mm.py b/src/scitex/plt/utils/_figure_from_axes_mm.py index 4f7e6f78f..62ea532e4 100755 --- a/src/scitex/plt/utils/_figure_from_axes_mm.py +++ b/src/scitex/plt/utils/_figure_from_axes_mm.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple import matplotlib.pyplot as plt -from figrecipe._utils._units import mm_to_inch +from figrecipe.utils import mm_to_inch if TYPE_CHECKING: from scitex.plt._subplots._AxisWrapper import AxisWrapper @@ -201,7 +201,7 @@ def get_dimension_info(fig, ax) -> Dict: >>> print(f"Axes size: {info['axes_size_mm']} mm") >>> print(f"Axes size: {info['axes_size_px']} pixels at {info['dpi']} DPI") """ - from figrecipe._utils._units import MM_PER_INCH, inch_to_mm + from figrecipe.utils import MM_PER_INCH, inch_to_mm # Figure dimensions fig_width_inch, fig_height_inch = fig.get_size_inches() diff --git a/src/scitex/plt/utils/_figure_mm.py b/src/scitex/plt/utils/_figure_mm.py index 7f073f0db..93ca38680 100755 --- a/src/scitex/plt/utils/_figure_mm.py +++ b/src/scitex/plt/utils/_figure_mm.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple import matplotlib.pyplot as plt -from figrecipe._utils._units import mm_to_inch, mm_to_pt +from figrecipe.utils import mm_to_inch, mm_to_pt from matplotlib.axes import Axes # Default theme color palettes From 491b6b90d4fbcc765322939eb29be9f0fde5b86b Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Thu, 19 Feb 2026 02:03:05 +1100 Subject: [PATCH 05/17] =?UTF-8?q?refactor(plt):=20Delete=20AxisWrapper=20e?= =?UTF-8?q?cosystem=20=E2=80=94=20Phase=202b=20of=20figrecipe=20migration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove ~190 files of the old AxisWrapper/FigWrapper/SubplotsWrapper infrastructure. figrecipe RecordingAxes + stx_* methods (added in figrecipe commit dac5b47) replace all functionality. Deleted: - _subplots/ (AxisWrapper, FigWrapper, SubplotsWrapper, CSV formatters) - ax/_style/ and ax/_plot/ (30+ style and plot functions) - styles/ dead code - All associated tests Co-Authored-By: Claude Sonnet 4.6 --- src/scitex/plt/__init__.py | 106 +- src/scitex/plt/_subplots/_AxesWrapper.py | 197 --- src/scitex/plt/_subplots/_AxisWrapper.py | 375 ------ .../_AdjustmentMixin/__init__.py | 37 - .../_AdjustmentMixin/_labels.py | 268 ---- .../_AdjustmentMixin/_metadata.py | 214 --- .../_AdjustmentMixin/_visual.py | 129 -- .../_MatplotlibPlotMixin/__init__.py | 60 - .../_MatplotlibPlotMixin/_base.py | 36 - .../_MatplotlibPlotMixin/_scientific.py | 596 --------- .../_MatplotlibPlotMixin/_statistical.py | 654 ---------- .../_MatplotlibPlotMixin/_stx_aliases.py | 527 -------- .../_AxisWrapperMixins/_RawMatplotlibMixin.py | 349 ----- .../_SeabornMixin/__init__.py | 34 - .../_AxisWrapperMixins/_SeabornMixin/_base.py | 156 --- .../_SeabornMixin/_wrappers.py | 595 --------- .../_AxisWrapperMixins/_TrackingMixin.py | 199 --- .../_AxisWrapperMixins/_UnitAwareMixin.py | 449 ------- .../_subplots/_AxisWrapperMixins/__init__.py | 91 -- src/scitex/plt/_subplots/_FigWrapper.py | 475 ------- src/scitex/plt/_subplots/_SubplotsWrapper.py | 331 ----- src/scitex/plt/_subplots/__init__.py | 122 -- src/scitex/plt/_subplots/_export_as_csv.py | 464 ------- .../_export_as_csv_formatters/__init__.py | 84 -- .../_format_annotate.py | 80 -- .../_export_as_csv_formatters/_format_bar.py | 139 -- .../_export_as_csv_formatters/_format_barh.py | 59 - .../_format_boxplot.py | 81 -- .../_format_contour.py | 51 - .../_format_contourf.py | 63 - .../_format_errorbar.py | 101 -- .../_format_eventplot.py | 98 -- .../_export_as_csv_formatters/_format_fill.py | 47 - .../_format_fill_between.py | 47 - .../_format_hexbin.py | 52 - .../_export_as_csv_formatters/_format_hist.py | 102 -- .../_format_hist2d.py | 52 - .../_format_imshow.py | 90 -- .../_format_imshow2d.py | 49 - .../_format_matshow.py | 57 - .../_format_pcolormesh.py | 69 - .../_export_as_csv_formatters/_format_pie.py | 56 - .../_export_as_csv_formatters/_format_plot.py | 218 ---- .../_format_plot_box.py | 98 -- .../_format_plot_imshow.py | 55 - .../_format_plot_kde.py | 61 - .../_format_plot_scatter.py | 46 - .../_format_quiver.py | 70 - .../_format_scatter.py | 44 - .../_format_sns_barplot.py | 72 - .../_format_sns_boxplot.py | 112 -- .../_format_sns_heatmap.py | 84 -- .../_format_sns_histplot.py | 111 -- .../_format_sns_jointplot.py | 89 -- .../_format_sns_kdeplot.py | 111 -- .../_format_sns_lineplot.py | 69 - .../_format_sns_pairplot.py | 62 - .../_format_sns_scatterplot.py | 84 -- .../_format_sns_stripplot.py | 90 -- .../_format_sns_swarmplot.py | 90 -- .../_format_sns_violinplot.py | 171 --- .../_format_stackplot.py | 63 - .../_export_as_csv_formatters/_format_stem.py | 52 - .../_export_as_csv_formatters/_format_step.py | 52 - .../_format_streamplot.py | 65 - .../_format_stx_bar.py | 96 -- .../_format_stx_barh.py | 97 -- .../_format_stx_conf_mat.py | 79 -- .../_format_stx_contour.py | 64 - .../_format_stx_ecdf.py | 57 - .../_format_stx_errorbar.py | 152 --- .../_format_stx_fillv.py | 72 - .../_format_stx_heatmap.py | 85 -- .../_format_stx_image.py | 119 -- .../_format_stx_imshow.py | 63 - .../_format_stx_joyplot.py | 86 -- .../_format_stx_line.py | 55 - .../_format_stx_mean_ci.py | 50 - .../_format_stx_mean_std.py | 50 - .../_format_stx_median_iqr.py | 50 - .../_format_stx_raster.py | 54 - .../_format_stx_rectangle.py | 129 -- .../_format_stx_scatter.py | 53 - .../_format_stx_scatter_hist.py | 127 -- .../_format_stx_shaded_line.py | 72 - .../_format_stx_violin.py | 115 -- .../_export_as_csv_formatters/_format_text.py | 61 - .../_format_violin.py | 70 - .../_format_violinplot.py | 91 -- .../test_formatters.py | 207 --- .../verify_formatters.py | 360 ----- src/scitex/plt/_subplots/_fonts.py | 71 - src/scitex/plt/_subplots/_mm_layout.py | 282 ---- src/scitex/plt/ax/__init__.py | 123 -- src/scitex/plt/ax/_plot/__init__.py | 90 -- src/scitex/plt/ax/_plot/_add_fitted_line.py | 153 --- .../plt/ax/_plot/_plot_circular_hist.py | 127 -- src/scitex/plt/ax/_plot/_plot_cube.py | 57 - .../ax/_plot/_plot_statistical_shaded_line.py | 255 ---- src/scitex/plt/ax/_plot/_stx_conf_mat.py | 140 -- src/scitex/plt/ax/_plot/_stx_ecdf.py | 114 -- src/scitex/plt/ax/_plot/_stx_fillv.py | 58 - src/scitex/plt/ax/_plot/_stx_heatmap.py | 369 ------ src/scitex/plt/ax/_plot/_stx_image.py | 97 -- src/scitex/plt/ax/_plot/_stx_joyplot.py | 134 -- src/scitex/plt/ax/_plot/_stx_raster.py | 200 --- src/scitex/plt/ax/_plot/_stx_rectangle.py | 70 - src/scitex/plt/ax/_plot/_stx_scatter_hist.py | 133 -- src/scitex/plt/ax/_plot/_stx_shaded_line.py | 220 ---- src/scitex/plt/ax/_plot/_stx_violin.py | 353 ----- src/scitex/plt/ax/_style/__init__.py | 42 - src/scitex/plt/ax/_style/_add_marginal_ax.py | 47 - src/scitex/plt/ax/_style/_add_panel.py | 93 -- src/scitex/plt/ax/_style/_auto_scale_axis.py | 200 --- src/scitex/plt/ax/_style/_extend.py | 67 - src/scitex/plt/ax/_style/_force_aspect.py | 40 - src/scitex/plt/ax/_style/_format_label.py | 23 - src/scitex/plt/ax/_style/_format_units.py | 103 -- src/scitex/plt/ax/_style/_hide_spines.py | 87 -- src/scitex/plt/ax/_style/_map_ticks.py | 184 --- src/scitex/plt/ax/_style/_rotate_labels.py | 321 ----- .../plt/ax/_style/_rotate_labels_v01.py | 258 ---- src/scitex/plt/ax/_style/_sci_note.py | 279 ---- src/scitex/plt/ax/_style/_set_log_scale.py | 335 ----- src/scitex/plt/ax/_style/_set_meta.py | 294 ----- src/scitex/plt/ax/_style/_set_n_ticks.py | 37 - src/scitex/plt/ax/_style/_set_size.py | 16 - src/scitex/plt/ax/_style/_set_supxyt.py | 133 -- src/scitex/plt/ax/_style/_set_ticks.py | 276 ---- src/scitex/plt/ax/_style/_set_xyt.py | 130 -- src/scitex/plt/ax/_style/_share_axes.py | 267 ---- src/scitex/plt/ax/_style/_shift.py | 139 -- src/scitex/plt/ax/_style/_show_spines.py | 335 ----- src/scitex/plt/ax/_style/_style_barplot.py | 69 - src/scitex/plt/ax/_style/_style_boxplot.py | 153 --- src/scitex/plt/ax/_style/_style_errorbar.py | 82 -- src/scitex/plt/ax/_style/_style_scatter.py | 82 -- src/scitex/plt/ax/_style/_style_suptitles.py | 76 -- src/scitex/plt/ax/_style/_style_violinplot.py | 115 -- src/scitex/plt/styles/__init__.py | 30 +- src/scitex/plt/styles/_plot_defaults.py | 210 --- src/scitex/plt/styles/_plot_postprocess.py | 487 ------- src/scitex/plt/styles/_postprocess_helpers.py | 158 --- .../scholar/citation_graph/visualization.py | 4 +- src/scitex/stats/_figrecipe_integration.py | 8 +- .../custom/test_axes_wrapper_flat_property.py | 82 -- tests/custom/test_imports.py | 38 +- .../_AdjustmentMixin/test__labels.py | 280 ---- .../_AdjustmentMixin/test__metadata.py | 229 ---- .../_AdjustmentMixin/test__visual.py | 144 -- .../_MatplotlibPlotMixin/test__base.py | 50 - .../_MatplotlibPlotMixin/test__scientific.py | 609 --------- .../_MatplotlibPlotMixin/test__statistical.py | 670 ---------- .../_MatplotlibPlotMixin/test__stx_aliases.py | 543 -------- .../_SeabornMixin/test__base.py | 168 --- .../_SeabornMixin/test__wrappers.py | 616 --------- .../test__RawMatplotlibMixin.py | 337 ----- .../_AxisWrapperMixins/test__TrackingMixin.py | 419 ------ .../test__UnitAwareMixin.py | 456 ------- .../test__format_annotate.py | 88 -- .../test__format_bar.py | 154 --- .../test__format_barh.py | 74 -- .../test__format_boxplot.py | 90 -- .../test__format_contour.py | 66 - .../test__format_contourf.py | 78 -- .../test__format_errorbar.py | 104 -- .../test__format_eventplot.py | 99 -- .../test__format_fill.py | 62 - .../test__format_fill_between.py | 62 - .../test__format_hexbin.py | 67 - .../test__format_hist.py | 107 -- .../test__format_hist2d.py | 67 - .../test__format_imshow.py | 224 ---- .../test__format_imshow2d.py | 64 - .../test__format_matshow.py | 66 - .../test__format_pcolormesh.py | 82 -- .../test__format_pie.py | 67 - .../test__format_plot.py | 359 ----- .../test__format_plot_box.py | 114 -- .../test__format_plot_imshow.py | 69 - .../test__format_plot_kde.py | 75 -- .../test__format_plot_scatter.py | 60 - .../test__format_quiver.py | 77 -- .../test__format_scatter.py | 59 - .../test__format_sns_barplot.py | 82 -- .../test__format_sns_boxplot.py | 128 -- .../test__format_sns_heatmap.py | 96 -- .../test__format_sns_histplot.py | 111 -- .../test__format_sns_jointplot.py | 92 -- .../test__format_sns_kdeplot.py | 107 -- .../test__format_sns_lineplot.py | 80 -- .../test__format_sns_pairplot.py | 73 -- .../test__format_sns_scatterplot.py | 88 -- .../test__format_sns_stripplot.py | 96 -- .../test__format_sns_swarmplot.py | 96 -- .../test__format_sns_violinplot.py | 150 --- .../test__format_stackplot.py | 78 -- .../test__format_stem.py | 67 - .../test__format_step.py | 67 - .../test__format_streamplot.py | 72 - .../test__format_stx_bar.py | 100 -- .../test__format_stx_barh.py | 101 -- .../test__format_stx_conf_mat.py | 91 -- .../test__format_stx_contour.py | 70 - .../test__format_stx_ecdf.py | 71 - .../test__format_stx_errorbar.py | 136 -- .../test__format_stx_fillv.py | 88 -- .../test__format_stx_heatmap.py | 90 -- .../test__format_stx_image.py | 125 -- .../test__format_stx_imshow.py | 79 -- .../test__format_stx_joyplot.py | 100 -- .../test__format_stx_line.py | 71 - .../test__format_stx_mean_ci.py | 66 - .../test__format_stx_mean_std.py | 66 - .../test__format_stx_median_iqr.py | 66 - .../test__format_stx_raster.py | 68 - .../test__format_stx_rectangle.py | 145 --- .../test__format_stx_scatter.py | 67 - .../test__format_stx_scatter_hist.py | 108 -- .../test__format_stx_shaded_line.py | 88 -- .../test__format_stx_violin.py | 131 -- .../test__format_text.py | 77 -- .../test__format_violin.py | 81 -- .../test__format_violinplot.py | 98 -- .../test_test_formatters.py | 223 ---- .../test_verify_formatters.py | 375 ------ .../scitex/plt/_subplots/test__AxesWrapper.py | 331 ----- .../scitex/plt/_subplots/test__AxisWrapper.py | 443 ------- .../scitex/plt/_subplots/test__FigWrapper.py | 604 --------- .../plt/_subplots/test__SubplotsWrapper.py | 483 ------- .../plt/_subplots/test__export_as_csv.py | 1155 ----------------- tests/scitex/plt/_subplots/test__fonts.py | 87 -- tests/scitex/plt/_subplots/test__mm_layout.py | 298 ----- .../plt/ax/_plot/test__add_fitted_line.py | 168 --- .../plt/ax/_plot/test__plot_circular_hist.py | 399 ------ tests/scitex/plt/ax/_plot/test__plot_cube.py | 171 --- .../test__plot_statistical_shaded_line.py | 369 ------ .../scitex/plt/ax/_plot/test__stx_conf_mat.py | 155 --- tests/scitex/plt/ax/_plot/test__stx_ecdf.py | 129 -- tests/scitex/plt/ax/_plot/test__stx_fillv.py | 73 -- .../scitex/plt/ax/_plot/test__stx_heatmap.py | 385 ------ tests/scitex/plt/ax/_plot/test__stx_image.py | 112 -- .../scitex/plt/ax/_plot/test__stx_joyplot.py | 151 --- tests/scitex/plt/ax/_plot/test__stx_raster.py | 215 --- .../plt/ax/_plot/test__stx_rectangle.py | 86 -- .../plt/ax/_plot/test__stx_scatter_hist.py | 149 --- .../plt/ax/_plot/test__stx_shaded_line.py | 235 ---- tests/scitex/plt/ax/_plot/test__stx_violin.py | 368 ------ .../plt/ax/_style/test__add_marginal_ax.py | 225 ---- tests/scitex/plt/ax/_style/test__add_panel.py | 246 ---- .../plt/ax/_style/test__auto_scale_axis.py | 215 --- tests/scitex/plt/ax/_style/test__extend.py | 195 --- .../plt/ax/_style/test__force_aspect.py | 163 --- .../plt/ax/_style/test__format_label.py | 388 ------ .../plt/ax/_style/test__format_units.py | 119 -- .../scitex/plt/ax/_style/test__hide_spines.py | 222 ---- tests/scitex/plt/ax/_style/test__map_ticks.py | 381 ------ .../plt/ax/_style/test__rotate_labels.py | 460 ------- .../plt/ax/_style/test__rotate_labels_v01.py | 274 ---- tests/scitex/plt/ax/_style/test__sci_note.py | 450 ------- .../plt/ax/_style/test__set_log_scale.py | 873 ------------- tests/scitex/plt/ax/_style/test__set_meta.py | 776 ----------- .../scitex/plt/ax/_style/test__set_n_ticks.py | 171 --- tests/scitex/plt/ax/_style/test__set_size.py | 154 --- .../scitex/plt/ax/_style/test__set_supxyt.py | 237 ---- tests/scitex/plt/ax/_style/test__set_ticks.py | 422 ------ tests/scitex/plt/ax/_style/test__set_xyt.py | 252 ---- .../scitex/plt/ax/_style/test__share_axes.py | 461 ------- tests/scitex/plt/ax/_style/test__shift.py | 251 ---- .../scitex/plt/ax/_style/test__show_spines.py | 912 ------------- .../plt/ax/_style/test__style_barplot.py | 85 -- .../plt/ax/_style/test__style_boxplot.py | 171 --- .../plt/ax/_style/test__style_errorbar.py | 98 -- .../plt/ax/_style/test__style_scatter.py | 98 -- .../plt/ax/_style/test__style_suptitles.py | 92 -- .../plt/ax/_style/test__style_violinplot.py | 131 -- tests/scitex/plt/ax/conftest_enhanced.py | 516 -------- tests/scitex/plt/color/test__colors.py | 442 ------- .../scitex/plt/styles/test__plot_defaults.py | 226 ---- .../plt/styles/test__plot_postprocess.py | 503 ------- .../plt/styles/test__postprocess_helpers.py | 174 --- 281 files changed, 45 insertions(+), 48793 deletions(-) delete mode 100755 src/scitex/plt/_subplots/_AxesWrapper.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapper.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/__init__.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_labels.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_metadata.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_visual.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/__init__.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_base.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_scientific.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_statistical.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_stx_aliases.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_RawMatplotlibMixin.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/__init__.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_base.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py delete mode 100755 src/scitex/plt/_subplots/_AxisWrapperMixins/__init__.py delete mode 100755 src/scitex/plt/_subplots/_FigWrapper.py delete mode 100755 src/scitex/plt/_subplots/_SubplotsWrapper.py delete mode 100755 src/scitex/plt/_subplots/__init__.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/__init__.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pcolormesh.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_imshow.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stackplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_bar.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_barh.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_conf_mat.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_contour.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_ecdf.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_errorbar.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_fillv.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_heatmap.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_image.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_imshow.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_joyplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_line.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_ci.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_std.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_median_iqr.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_raster.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_rectangle.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter_hist.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_shaded_line.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_violin.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py delete mode 100755 src/scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py delete mode 100755 src/scitex/plt/_subplots/_fonts.py delete mode 100755 src/scitex/plt/_subplots/_mm_layout.py delete mode 100755 src/scitex/plt/ax/__init__.py delete mode 100755 src/scitex/plt/ax/_plot/__init__.py delete mode 100755 src/scitex/plt/ax/_plot/_add_fitted_line.py delete mode 100755 src/scitex/plt/ax/_plot/_plot_circular_hist.py delete mode 100755 src/scitex/plt/ax/_plot/_plot_cube.py delete mode 100755 src/scitex/plt/ax/_plot/_plot_statistical_shaded_line.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_conf_mat.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_ecdf.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_fillv.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_heatmap.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_image.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_joyplot.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_raster.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_rectangle.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_scatter_hist.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_shaded_line.py delete mode 100755 src/scitex/plt/ax/_plot/_stx_violin.py delete mode 100755 src/scitex/plt/ax/_style/__init__.py delete mode 100755 src/scitex/plt/ax/_style/_add_marginal_ax.py delete mode 100755 src/scitex/plt/ax/_style/_add_panel.py delete mode 100755 src/scitex/plt/ax/_style/_auto_scale_axis.py delete mode 100755 src/scitex/plt/ax/_style/_extend.py delete mode 100755 src/scitex/plt/ax/_style/_force_aspect.py delete mode 100755 src/scitex/plt/ax/_style/_format_label.py delete mode 100755 src/scitex/plt/ax/_style/_format_units.py delete mode 100755 src/scitex/plt/ax/_style/_hide_spines.py delete mode 100755 src/scitex/plt/ax/_style/_map_ticks.py delete mode 100755 src/scitex/plt/ax/_style/_rotate_labels.py delete mode 100755 src/scitex/plt/ax/_style/_rotate_labels_v01.py delete mode 100755 src/scitex/plt/ax/_style/_sci_note.py delete mode 100755 src/scitex/plt/ax/_style/_set_log_scale.py delete mode 100755 src/scitex/plt/ax/_style/_set_meta.py delete mode 100755 src/scitex/plt/ax/_style/_set_n_ticks.py delete mode 100755 src/scitex/plt/ax/_style/_set_size.py delete mode 100755 src/scitex/plt/ax/_style/_set_supxyt.py delete mode 100755 src/scitex/plt/ax/_style/_set_ticks.py delete mode 100755 src/scitex/plt/ax/_style/_set_xyt.py delete mode 100755 src/scitex/plt/ax/_style/_share_axes.py delete mode 100755 src/scitex/plt/ax/_style/_shift.py delete mode 100755 src/scitex/plt/ax/_style/_show_spines.py delete mode 100755 src/scitex/plt/ax/_style/_style_barplot.py delete mode 100755 src/scitex/plt/ax/_style/_style_boxplot.py delete mode 100755 src/scitex/plt/ax/_style/_style_errorbar.py delete mode 100755 src/scitex/plt/ax/_style/_style_scatter.py delete mode 100755 src/scitex/plt/ax/_style/_style_suptitles.py delete mode 100755 src/scitex/plt/ax/_style/_style_violinplot.py delete mode 100755 src/scitex/plt/styles/_plot_defaults.py delete mode 100755 src/scitex/plt/styles/_plot_postprocess.py delete mode 100755 src/scitex/plt/styles/_postprocess_helpers.py delete mode 100644 tests/custom/test_axes_wrapper_flat_property.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__labels.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__metadata.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__visual.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__base.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__scientific.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__statistical.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__stx_aliases.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/test__base.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/test__wrappers.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/test__RawMatplotlibMixin.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/test__TrackingMixin.py delete mode 100644 tests/scitex/plt/_subplots/_AxisWrapperMixins/test__UnitAwareMixin.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_annotate.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_bar.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_barh.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_boxplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_contour.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_contourf.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_errorbar.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_eventplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_fill.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_fill_between.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hexbin.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hist.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hist2d.py delete mode 100755 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow2d.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_matshow.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_pcolormesh.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_pie.py delete mode 100755 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_box.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_imshow.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_kde.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_scatter.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_quiver.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_scatter.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_barplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_boxplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_heatmap.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_histplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_jointplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_kdeplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_lineplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_pairplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_scatterplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_stripplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_swarmplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_violinplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stackplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stem.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_step.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_streamplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_bar.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_barh.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_conf_mat.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_contour.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_ecdf.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_errorbar.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_fillv.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_heatmap.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_image.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_imshow.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_joyplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_line.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_mean_ci.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_mean_std.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_median_iqr.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_raster.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_rectangle.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_scatter.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_scatter_hist.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_shaded_line.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_violin.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_text.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_violin.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_violinplot.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test_test_formatters.py delete mode 100644 tests/scitex/plt/_subplots/_export_as_csv_formatters/test_verify_formatters.py delete mode 100644 tests/scitex/plt/_subplots/test__AxesWrapper.py delete mode 100644 tests/scitex/plt/_subplots/test__AxisWrapper.py delete mode 100644 tests/scitex/plt/_subplots/test__FigWrapper.py delete mode 100755 tests/scitex/plt/_subplots/test__SubplotsWrapper.py delete mode 100644 tests/scitex/plt/_subplots/test__export_as_csv.py delete mode 100644 tests/scitex/plt/_subplots/test__fonts.py delete mode 100644 tests/scitex/plt/_subplots/test__mm_layout.py delete mode 100644 tests/scitex/plt/ax/_plot/test__add_fitted_line.py delete mode 100644 tests/scitex/plt/ax/_plot/test__plot_circular_hist.py delete mode 100644 tests/scitex/plt/ax/_plot/test__plot_cube.py delete mode 100644 tests/scitex/plt/ax/_plot/test__plot_statistical_shaded_line.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_conf_mat.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_ecdf.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_fillv.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_heatmap.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_image.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_joyplot.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_raster.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_rectangle.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_scatter_hist.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_shaded_line.py delete mode 100644 tests/scitex/plt/ax/_plot/test__stx_violin.py delete mode 100644 tests/scitex/plt/ax/_style/test__add_marginal_ax.py delete mode 100644 tests/scitex/plt/ax/_style/test__add_panel.py delete mode 100644 tests/scitex/plt/ax/_style/test__auto_scale_axis.py delete mode 100644 tests/scitex/plt/ax/_style/test__extend.py delete mode 100644 tests/scitex/plt/ax/_style/test__force_aspect.py delete mode 100644 tests/scitex/plt/ax/_style/test__format_label.py delete mode 100644 tests/scitex/plt/ax/_style/test__format_units.py delete mode 100644 tests/scitex/plt/ax/_style/test__hide_spines.py delete mode 100644 tests/scitex/plt/ax/_style/test__map_ticks.py delete mode 100644 tests/scitex/plt/ax/_style/test__rotate_labels.py delete mode 100644 tests/scitex/plt/ax/_style/test__rotate_labels_v01.py delete mode 100644 tests/scitex/plt/ax/_style/test__sci_note.py delete mode 100644 tests/scitex/plt/ax/_style/test__set_log_scale.py delete mode 100644 tests/scitex/plt/ax/_style/test__set_meta.py delete mode 100644 tests/scitex/plt/ax/_style/test__set_n_ticks.py delete mode 100644 tests/scitex/plt/ax/_style/test__set_size.py delete mode 100644 tests/scitex/plt/ax/_style/test__set_supxyt.py delete mode 100644 tests/scitex/plt/ax/_style/test__set_ticks.py delete mode 100644 tests/scitex/plt/ax/_style/test__set_xyt.py delete mode 100755 tests/scitex/plt/ax/_style/test__share_axes.py delete mode 100644 tests/scitex/plt/ax/_style/test__shift.py delete mode 100644 tests/scitex/plt/ax/_style/test__show_spines.py delete mode 100644 tests/scitex/plt/ax/_style/test__style_barplot.py delete mode 100644 tests/scitex/plt/ax/_style/test__style_boxplot.py delete mode 100644 tests/scitex/plt/ax/_style/test__style_errorbar.py delete mode 100644 tests/scitex/plt/ax/_style/test__style_scatter.py delete mode 100644 tests/scitex/plt/ax/_style/test__style_suptitles.py delete mode 100644 tests/scitex/plt/ax/_style/test__style_violinplot.py delete mode 100644 tests/scitex/plt/ax/conftest_enhanced.py delete mode 100644 tests/scitex/plt/color/test__colors.py delete mode 100644 tests/scitex/plt/styles/test__plot_defaults.py delete mode 100644 tests/scitex/plt/styles/test__plot_postprocess.py delete mode 100644 tests/scitex/plt/styles/test__postprocess_helpers.py diff --git a/src/scitex/plt/__init__.py b/src/scitex/plt/__init__.py index a03c19135..eed61f31d 100755 --- a/src/scitex/plt/__init__.py +++ b/src/scitex/plt/__init__.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -# Timestamp: "2026-01-19 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/__init__.py +# File: /home/ywatanabe/proj/scitex-python/src/scitex/plt/__init__.py """ SciTeX plt module - Publication-quality plotting via figrecipe. @@ -18,7 +17,7 @@ os.environ.setdefault("FIGRECIPE_BRAND", "scitex.plt") os.environ.setdefault("FIGRECIPE_ALIAS", "plt") -# Map SCITEX_PLT_* → FIGRECIPE_* (user-facing prefix takes priority) +# Map SCITEX_PLT_* -> FIGRECIPE_* (user-facing prefix takes priority) _ENV_MAPPINGS = [ ("SCITEX_PLT_DEBUG_MODE", "FIGRECIPE_DEBUG_MODE"), ("SCITEX_PLT_DEV_REPRESENTATIVE_PLOTS", "FIGRECIPE_DEV_REPRESENTATIVE_PLOTS"), @@ -129,23 +128,22 @@ def _not_available(*args, **kwargs): register_graph_preset = _not_available # ============================================================================ -# Local scitex submodules (kept for compatibility) +# Local scitex submodules # ============================================================================ try: from ._tpl import termplot except ImportError: termplot = None -# Backward compatibility: expose styles submodule (deprecated, use figrecipe) -from . import ax, color, gallery, styles, utils # noqa: E402 +from . import color, gallery, styles, utils # noqa: E402 # Auto-configure matplotlib with SciTeX defaults on import from ._auto_config import configure as _configure # noqa: E402 -# Import draw_graph from figrecipe integration (handles AxisWrapper) +# Import draw_graph from figrecipe integration from ._figrecipe_integration import draw_graph # noqa: E402 -# Spec building and rendering (moved from Django) +# Spec building and rendering from ._render import render_spec_to_bytes # noqa: E402 from ._spec_builders import ( # noqa: E402 ALL_KINDS, @@ -163,36 +161,12 @@ def _not_available(*args, **kwargs): # ============================================================================ -# SciTeX-specific wrapper functions (for AxisWrapper/FigWrapper compatibility) +# SciTeX-specific wrapper functions # ============================================================================ -def figure(*args, **kwargs): - """Create a figure that returns a FigWrapper. - - This is the scitex-specific figure function that creates FigWrapper - objects for compatibility with scitex.plt.ax utilities. - - For figrecipe-style recording figures, use subplots() instead. - """ - from ._subplots._FigWrapper import FigWrapper - - fig_mpl = _plt.figure(*args, **kwargs) - return FigWrapper(fig_mpl) - - def tight_layout(**kwargs): - """Apply tight layout to current figure with colorbar compatibility handling. - - This function calls tight_layout on the current figure and gracefully handles: - 1. UserWarning: "The figure layout has changed to tight" - informational only - 2. RuntimeError: Colorbar layout incompatibility - occurs when colorbars exist with old engine - - Parameters - ---------- - **kwargs - All keyword arguments are passed to matplotlib.pyplot.tight_layout() - """ + """Apply tight layout to current figure with colorbar compatibility handling.""" import warnings with warnings.catch_warnings(): @@ -202,17 +176,12 @@ def tight_layout(**kwargs): try: _plt.tight_layout(**kwargs) except RuntimeError as e: - # Silently handle colorbar layout engine incompatibility if "Colorbar layout" not in str(e): raise def colorbar(mappable=None, cax=None, ax=None, **kwargs): - """ - Create a colorbar, automatically unwrapping SciTeX AxisWrapper objects. - - This function handles both regular matplotlib axes and SciTeX AxisWrapper - objects transparently, making it a drop-in replacement for plt.colorbar(). + """Create a colorbar, unwrapping wrapper axes if needed. Parameters ---------- @@ -220,7 +189,7 @@ def colorbar(mappable=None, cax=None, ax=None, **kwargs): The image, contour set, etc. to which the colorbar applies. cax : Axes, optional Axes into which the colorbar will be drawn. - ax : Axes or AxisWrapper or list thereof, optional + ax : Axes or list thereof, optional Parent axes from which space for the colorbar will be stolen. **kwargs Additional keyword arguments passed to matplotlib.pyplot.colorbar() @@ -230,54 +199,45 @@ def colorbar(mappable=None, cax=None, ax=None, **kwargs): Colorbar The created colorbar object """ - # Unwrap ax if it's a SciTeX AxisWrapper + + def _unwrap(a): + """Unwrap any axes wrapper to raw matplotlib Axes.""" + for attr in ("_ax", "_axis_mpl"): + if hasattr(a, attr): + return getattr(a, attr) + return a + if ax is not None: if hasattr(ax, "__iter__") and not isinstance(ax, str): - # Handle list/array of axes - ax = [a._axis_mpl if hasattr(a, "_axis_mpl") else a for a in ax] + ax = [_unwrap(a) for a in ax] else: - # Single axis - ax = ax._axis_mpl if hasattr(ax, "_axis_mpl") else ax + ax = _unwrap(ax) - # Unwrap cax if provided if cax is not None: - cax = cax._axis_mpl if hasattr(cax, "_axis_mpl") else cax + cax = _unwrap(cax) - # Call matplotlib's colorbar with unwrapped axes return _plt.colorbar(mappable=mappable, cax=cax, ax=ax, **kwargs) def close(fig=None): - """ - Close a figure, automatically unwrapping SciTeX FigWrapper objects. - - This function is a drop-in replacement for matplotlib.pyplot.close() that - handles both regular matplotlib Figure objects and SciTeX FigWrapper objects. + """Close a figure, unwrapping wrapper objects if needed. Parameters ---------- - fig : Figure, FigWrapper, int, str, or None - The figure to close. Can be: - - None: close the current figure - - Figure or FigWrapper: close the specified figure - - int: close figure with that number - - str: close figure with that label, or 'all' to close all figures + fig : Figure, RecordingFigure, int, str, or None + The figure to close. """ if fig is None: _plt.close() elif isinstance(fig, (int, str)): _plt.close(fig) - elif hasattr(fig, "_fig_mpl"): - # FigWrapper object - unwrap and close - _plt.close(fig._fig_mpl) - elif hasattr(fig, "figure"): - # Alternative attribute name (backward compatibility) - _plt.close(fig.figure) elif hasattr(fig, "fig"): - # figrecipe RecordingFigure - unwrap and close + # figrecipe RecordingFigure _plt.close(fig.fig) + elif hasattr(fig, "_fig_mpl"): + # Legacy FigWrapper (backward compat) + _plt.close(fig._fig_mpl) else: - # Assume it's a matplotlib Figure _plt.close(fig) @@ -334,12 +294,10 @@ def close(fig=None): "sns", "enable_svg", # SciTeX-specific wrappers - "figure", "colorbar", "close", "tight_layout", # Local submodules - "ax", "color", "gallery", "utils", @@ -350,19 +308,15 @@ def close(fig=None): def __getattr__(name): - """Fallback to matplotlib.pyplot for any missing attributes. - - This makes scitex.plt a complete drop-in replacement for matplotlib.pyplot. - """ + """Fallback to matplotlib.pyplot for any missing attributes.""" if hasattr(_plt, name): return getattr(_plt, name) raise AttributeError(f"module 'scitex.plt' has no attribute '{name}'") def __dir__(): - """Provide comprehensive directory listing including matplotlib.pyplot functions.""" + """Provide directory listing including matplotlib.pyplot functions.""" local_attrs = list(__all__) - # Add matplotlib.pyplot attributes mpl_attrs = [attr for attr in dir(_plt) if not attr.startswith("_")] local_attrs.extend(mpl_attrs) return sorted(set(local_attrs)) diff --git a/src/scitex/plt/_subplots/_AxesWrapper.py b/src/scitex/plt/_subplots/_AxesWrapper.py deleted file mode 100755 index 3ff4236ee..000000000 --- a/src/scitex/plt/_subplots/_AxesWrapper.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-19 15:36:54 (ywatanabe)" -# File: /ssh:ywatanabe@sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/_subplots/_AxesWrapper.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from functools import wraps - -import numpy as np -import pandas as pd - -from scitex import logging - -logger = logging.getLogger(__name__) - - -class AxesWrapper: - def __init__(self, fig_scitex, axes_scitex): - self._fig_scitex = fig_scitex - self._axes_scitex = axes_scitex - - def get_figure(self, root=True): - """Get the figure, compatible with matplotlib 3.8+""" - return self._fig_scitex - - def __dir__(self): - # Combine attributes from both self and the wrapped matplotlib axes - attrs = set(dir(self.__class__)) - attrs.update(object.__dir__(self)) - - # Add attributes from the axes objects if available - if hasattr(self, "_axes_scitex") and self._axes_scitex is not None: - # Get attributes from the first axis if there are any - if self._axes_scitex.size > 0: - first_ax = self._axes_scitex.flat[0] - attrs.update(dir(first_ax)) - - return sorted(attrs) - - def __getattr__(self, name): - # Note that self._axes_scitex is "numpy.ndarray" - # print(f"Attribute of AxesWrapper: {name}") - methods = [] - try: - for axis in self._axes_scitex.flat: - methods.append(getattr(axis, name)) - except Exception: - methods = [] - - if methods and all(callable(m) for m in methods): - - @wraps(methods[0]) - def wrapper(*args, **kwargs): - return [ - getattr(ax, name)(*args, **kwargs) for ax in self._axes_scitex.flat - ] - - return wrapper - - if methods and not callable(methods[0]): - return methods - - def dummy(*args, **kwargs): - return None - - return dummy - - # def __getitem__(self, index): - # subset = self._axes_scitex[index] - # if isinstance(index, slice): - # return AxesWrapper(self._fig_scitex, subset) - # return subset - - def __getitem__(self, index): - # Handle 1D-like arrays (single row or single column) - # For (1, n) shape with integer index, return the element from the row - # For (n, 1) shape with integer index, return the element from the column - if isinstance(index, int): - shape = self._axes_scitex.shape - if len(shape) == 2: - if shape[0] == 1: - # Single row case: axes[i] should return axes[0, i] - return self._axes_scitex[0, index] - elif shape[1] == 1: - # Single column case: axes[i] should return axes[i, 0] - return self._axes_scitex[index, 0] - - subset = self._axes_scitex[index] - if isinstance(subset, np.ndarray): - return AxesWrapper(self._fig_scitex, subset) - return subset - - def __setitem__(self, index, value): - """Support item assignment for axes[row, col] = new_axis operations.""" - self._axes_scitex[index] = value - - def __iter__(self): - # Iterate over flattened axes for backward compatibility - return iter(self._axes_scitex.flat) - - def __len__(self): - return self._axes_scitex.size - - def __array__(self): - """Support conversion to numpy array. - - This allows using np.array(axes) on an AxesWrapper instance, returning - a NumPy array with the same shape as the original axes array. - - Notes: - - While this enables compatibility with NumPy functions, not all - operations will work correctly due to the nature of the wrapped - objects. - - For flattening operations, use the dedicated `flatten()` method - instead of `np.array(axes).flatten()`: - - # RECOMMENDED: - flat_axes = list(axes.flatten()) - - # AVOID (may cause "invalid __array_struct__" error): - flat_axes = np.array(axes).flatten() - - Returns: - np.ndarray: Array of wrapped axes with the same shape - """ - # Show a warning to help users avoid common mistakes - logger.warning( - "Converting AxesWrapper to numpy array. If you're trying to flatten " - "the axes, use 'list(axes.flatten())' instead of 'np.array(axes).flatten()'." - ) - - # Convert the underlying axes to a compatible numpy array representation - flat_axes = [ax for ax in self._axes_scitex.flat] - array_compatible = np.empty(len(flat_axes), dtype=object) - for idx, ax in enumerate(flat_axes): - array_compatible[idx] = ax - return array_compatible.reshape(self._axes_scitex.shape) - - def legend(self, loc="best"): - """Add legend to all axes with 'best' automatic placement by default.""" - return [ax.legend(loc=loc) for ax in self._axes_scitex.flat] - - @property - def history(self): - return [ax.history for ax in self._axes_scitex.flat] - - @property - def shape(self): - return self._axes_scitex.shape - - @property - def flat(self): - """Return a flat iterator over all axes. - - This property provides direct access to the flattened axes array, - matching numpy array behavior. - - Returns: - Iterator over all axes in row-major (C-style) order - """ - return self._axes_scitex.flat - - def flatten(self): - """Return a flattened array of all axes in the AxesWrapper. - - This method collects all axes from the flat iterator and returns them - as a NumPy array. This ensures compatibility with code that expects - a flat collection of axes. - - Returns: - np.ndarray: A flattened array containing all axes - - Example: - # Preferred way to get a list of all axes: - axes_list = list(axes.flatten()) - - # Alternatively, if you need a NumPy array: - axes_array = axes.flatten() - """ - return np.array([ax for ax in self._axes_scitex.flat]) - - def export_as_csv(self): - dfs = [] - for ii, ax in enumerate(self._axes_scitex.flat): - df = ax.export_as_csv() - # Column names already include axis position via get_csv_column_name - # No need to add extra prefix - dfs.append(df) - return pd.concat(dfs, axis=1) if dfs else pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapper.py b/src/scitex/plt/_subplots/_AxisWrapper.py deleted file mode 100755 index 07ec03e8f..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapper.py +++ /dev/null @@ -1,375 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 10:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapper.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import warnings -from functools import wraps - -import matplotlib - -from scitex import logging - -logger = logging.getLogger(__name__) - -from scitex.plt.styles import apply_plot_defaults, apply_plot_postprocess - -from ._AxisWrapperMixins import ( - AdjustmentMixin, - MatplotlibPlotMixin, - RawMatplotlibMixin, - SeabornMixin, - TrackingMixin, - UnitAwareMixin, -) - - -class AxisWrapper( - MatplotlibPlotMixin, - SeabornMixin, - RawMatplotlibMixin, - AdjustmentMixin, - TrackingMixin, - UnitAwareMixin, -): - def __init__(self, fig_scitex, axis_mpl, track): - """Initialize the AxisWrapper. - - Args: - fig_scitex: Parent figure wrapper - axis_mpl: Matplotlib axis to wrap - track: Whether to track plotting operations - """ - self._fig_mpl = fig_scitex._fig_mpl - # Axis Properties - # self.axis = axis_mpl - # self._axis = axis_mpl - # self._axis_scitex = self - self._axis_mpl = axis_mpl - - # Axes Properties - # self.axes = axis_mpl - # self._axes = axis_mpl - self._axes_mpl = axis_mpl - # self._axes_scitex = self - - # Tracking properties - self._ax_history = {} - self._method_counters = {} # Track method counts for auto-generated IDs - self.track = track - self.id = 0 - self._counter_part = matplotlib.axes.Axes - self._tracking_depth = 0 # Depth counter to prevent tracking internal calls - - # Initialize unit awareness - UnitAwareMixin.__init__(self) - - def get_figure(self, root=True): - """Get the figure, compatible with matplotlib 3.8+""" - return self._fig_mpl - - def twinx(self): - """Create a twin y-axis and wrap it with AxisWrapper.""" - twin_ax = self._axes_mpl.twinx() - - # Create a mock figure wrapper for the twin axis - class MockFigWrapper: - def __init__(self, fig_mpl): - self._fig_mpl = fig_mpl - - mock_fig = MockFigWrapper(self._fig_mpl) - return AxisWrapper(fig_scitex=mock_fig, axis_mpl=twin_ax, track=self.track) - - def twiny(self): - """Create a twin x-axis and wrap it with AxisWrapper.""" - twin_ax = self._axes_mpl.twiny() - - # Create a mock figure wrapper for the twin axis - class MockFigWrapper: - def __init__(self, fig_mpl): - self._fig_mpl = fig_mpl - - mock_fig = MockFigWrapper(self._fig_mpl) - return AxisWrapper(fig_scitex=mock_fig, axis_mpl=twin_ax, track=self.track) - - def __getattr__(self, name): - # 0. Check if the attribute is explicitly defined in AxisWrapper or its Mixins - # This check happens implicitly before __getattr__ is called. - # If a method like `plot` is defined in BasicPlotMixin, it will be found first. - - # print(f"Attribute of AxisWrapper: {name}") - - # 1. Try to get the attribute from the wrapped axes instance - if hasattr(self._axes_mpl, name): - orig_attr = getattr(self._axes_mpl, name) - - if callable(orig_attr): - - @wraps(orig_attr) - def wrapper(*args, __method_name__=name, **kwargs): - id_value = kwargs.pop("id", None) - track_override = kwargs.pop("track", None) - - # Increment tracking depth to detect internal calls - # Internal calls (depth > 1) won't be tracked - self._tracking_depth += 1 - is_top_level_call = self._tracking_depth == 1 - - try: - # Apply pre-processing defaults from styles module - apply_plot_defaults( - __method_name__, kwargs, id_value, self._axes_mpl - ) - - # Pop scitex-specific kwargs before calling matplotlib - # These are handled in post-processing - scitex_kwargs = {} - if __method_name__ == "violinplot": - scitex_kwargs["boxplot"] = kwargs.pop("boxplot", True) - - # Call the original matplotlib method - result = orig_attr(*args, **kwargs) - - # Store the scitex id on the result for later retrieval - # This is used by _collect_figure_metadata to map traces to CSV columns - if id_value is not None: - if isinstance(result, list): - # plot() returns list of Line2D objects - for item in result: - item._scitex_id = id_value - elif hasattr(result, "__iter__") and not isinstance( - result, str - ): - # Other containers (e.g., bar containers) - try: - for item in result: - item._scitex_id = id_value - except (TypeError, AttributeError): - pass - else: - # Single object - try: - result._scitex_id = id_value - except AttributeError: - pass - - # Restore scitex kwargs for post-processing - kwargs.update(scitex_kwargs) - - # Apply post-processing styling from styles module - apply_plot_postprocess( - __method_name__, result, self._axes_mpl, kwargs, args - ) - - # Determine if tracking should occur - # Only track top-level calls (depth == 1), not internal matplotlib calls - should_track = ( - track_override if track_override is not None else self.track - ) and is_top_level_call - - # Track the method call if tracking enabled - # Expanded list of matplotlib plotting methods to track - tracking_methods = { - # Basic plots - "plot", - "scatter", - "bar", - "barh", - "hist", - "boxplot", - "violinplot", - # Line plots - "fill_between", - "fill_betweenx", - "errorbar", - "step", - "stem", - # Fill and area plots - "fill", - "stackplot", - # Statistical plots - "hist2d", - "hexbin", - "pie", - "eventplot", - # Contour plots - "contour", - "contourf", - "tricontour", - "tricontourf", - # Image plots - "imshow", - "matshow", - "spy", - "pcolormesh", - "pcolor", - # Quiver plots - "quiver", - "streamplot", - # 3D-related (if axes3d) - "plot3D", - "scatter3D", - "bar3d", - "plot_surface", - "plot_wireframe", - # Text and annotations (data-containing) - "annotate", - "text", - } - if should_track and __method_name__ in tracking_methods: - # Use the _track method from TrackingMixin - # If no id provided, it will auto-generate one - try: - # Convert args to tracked_dict for consistency with other tracking - tracked_dict = {"args": args} - self._track( - should_track, - id_value, - __method_name__, - tracked_dict, - kwargs, - ) - except AttributeError: - logger.warning( - f"Tracking setup incomplete for AxisWrapper ({__method_name__})." - ) - except Exception as e: - # Silently continue if tracking fails to not break plotting - pass - return result # Return the result of the original call - finally: - # Always decrement depth, even if exception occurs - self._tracking_depth -= 1 - - return wrapper - else: - # If it's a non-callable attribute (property, etc.), return it directly - return orig_attr - - # 2. If not found on instance, try the counterpart type (fallback) - if hasattr(self._counter_part, name): - counterpart_attr = getattr(self._counter_part, name) - logger.warning( - f"SciTeX Axis_MplWrapper: '{name}' not directly handled. " - f"Falling back to underlying '{self._counter_part.__name__}' attribute." - ) - # If the counterpart attribute is callable (likely a method descriptor) - if callable(counterpart_attr): - # Return a new function that calls the counterpart method on self._axes_mpl - @wraps(counterpart_attr) - def fallback_method(*args, **kwargs): - # Note: No id/track handling for fallback methods - return counterpart_attr(self._axes_mpl, *args, **kwargs) - - return fallback_method - else: - # Non-callable class attribute. Attempt to get from instance again, - # otherwise return the class attribute/descriptor. - try: - return getattr(self._axes_mpl, name) - except AttributeError: - return counterpart_attr - - # 3. If not found anywhere, raise AttributeError - raise AttributeError( - f"'{type(self).__name__}' object and its underlying '{self._counter_part.__name__}' " - f"have no attribute '{name}'" - ) - - def __dir__(self): - # Start with attributes from the class and all parent classes (mixins) - attrs = set() - - # Get attributes from all parent classes including mixins - for cls in self.__class__.__mro__: - attrs.update(cls.__dict__.keys()) - - # Add instance attributes - attrs.update(self.__dict__.keys()) - - # Safely get matplotlib axes attributes - try: - # Get attributes from the wrapped matplotlib axes - if hasattr(self._axes_mpl, "__class__"): - # Get class methods from matplotlib.axes.Axes - for cls in self._axes_mpl.__class__.__mro__: - attrs.update( - name for name in cls.__dict__.keys() if not name.startswith("_") - ) - - # Add instance attributes of the matplotlib axes - if hasattr(self._axes_mpl, "__dict__"): - attrs.update( - name - for name in self._axes_mpl.__dict__.keys() - if not name.startswith("_") - ) - - except Exception: - # If any error occurs, add common matplotlib methods manually - attrs.update( - [ - "plot", - "scatter", - "bar", - "barh", - "hist", - "boxplot", - "set_xlabel", - "set_ylabel", - "set_title", - "legend", - "set_xlim", - "set_ylim", - "grid", - "annotate", - "text", - ] - ) - - # Remove private attributes - attrs = {attr for attr in attrs if not attr.startswith("_")} - - return sorted(attrs) - - def flatten(self): - """Return a list containing just this axis. - - This method makes AxisWrapper compatible with code that calls flatten() - on an axes collection. It returns a list containing just this single axis - to maintain consistency with AxesWrapper.flatten(). - - Returns: - list: A list containing this axis wrapper - - Example: - # When working with either AxesWrapper or AxisWrapper, this works: - axes_list = list(axes.flatten()) - """ - return [self] - - -""" -import matplotlib.pyplot as plt -import scitex.plt as mplt - -fig_scitex, axes = plt.subplots(ncols=2) -mfig_scitex, maxes = mplt.subplots(ncols=2) - -print(set(dir(mfig_scitex)) - set(dir(fig_scitex))) -print(set(dir(maxes)) - set(dir(axes))) - -is_compatible = np.all([kk in set(dir(msubplots)) for kk in set(dir(counter_part))]) -if is_compatible: - print(f"{msubplots.__name__} is compatible with {counter_part.__name__}") -else: - print(f"{msubplots.__name__} is incompatible with {counter_part.__name__}") -""" - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/__init__.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/__init__.py deleted file mode 100755 index f459598e9..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: __init__.py - AdjustmentMixin package - -""" -AdjustmentMixin - Modular axis adjustment mixin for AxisWrapper. - -This package provides axis adjustment functionality split into logical submodules: -- _labels: Label rotation and legend positioning -- _metadata: Axis labels, titles, and scientific metadata -- _visual: Visual adjustments (ticks, spines, extend, shift) -""" - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - -from ._labels import LabelsMixin -from ._metadata import MetadataMixin -from ._visual import VisualAdjustmentMixin - - -class AdjustmentMixin(LabelsMixin, MetadataMixin, VisualAdjustmentMixin): - """Mixin class for matplotlib axis adjustments. - - Combines multiple specialized mixins: - - LabelsMixin: Label rotation and legend positioning - - MetadataMixin: Axis labels, titles, and scientific metadata - - VisualAdjustmentMixin: Ticks, spines, extend, shift - """ - - pass - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_labels.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_labels.py deleted file mode 100755 index a470f9e36..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_labels.py +++ /dev/null @@ -1,268 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _labels.py - Label rotation and legend handling - -"""Mixin for label rotation and legend positioning.""" - -import os - -from scitex import logging - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - -logger = logging.getLogger(__name__) - - -class LabelsMixin: - """Mixin for label rotation and legend positioning.""" - - def _get_ax_module(self): - """Lazy import ax module to avoid circular imports.""" - from .....plt import ax as ax_module - - return ax_module - - def rotate_labels( - self, - x: float = None, - y: float = None, - x_ha: str = None, - y_ha: str = None, - x_va: str = None, - y_va: str = None, - auto_adjust: bool = True, - scientific_convention: bool = True, - tight_layout: bool = False, - ) -> None: - """Rotate x and y axis labels with automatic positioning. - - Parameters - ---------- - x : float or None, optional - Rotation angle for x-axis labels in degrees. - y : float or None, optional - Rotation angle for y-axis labels in degrees. - x_ha, y_ha : str or None, optional - Horizontal alignment for x/y-axis labels. - x_va, y_va : str or None, optional - Vertical alignment for x/y-axis labels. - auto_adjust : bool, optional - Whether to automatically adjust alignment. Default is True. - scientific_convention : bool, optional - Whether to follow scientific conventions. Default is True. - tight_layout : bool, optional - Whether to apply tight_layout. Default is False. - """ - self._axis_mpl = self._get_ax_module().rotate_labels( - self._axis_mpl, - x=x, - y=y, - x_ha=x_ha, - y_ha=y_ha, - x_va=x_va, - y_va=y_va, - auto_adjust=auto_adjust, - scientific_convention=scientific_convention, - tight_layout=tight_layout, - ) - - def legend( - self, *args, loc: str = "best", check_overlap: bool = False, **kwargs - ) -> None: - """Places legend at specified location, with support for outside positions. - - Parameters - ---------- - *args : tuple - Positional arguments (handles, labels) as in matplotlib - loc : str - Legend position. Default is "best" (matplotlib auto-placement). - Special positions: - - "best": Matplotlib automatic placement - - "outer": Place outside plot area (right side) - - "separate": Save legend as a separate figure file - - upper/lower/center variants: e.g. "upper right out" - check_overlap : bool - If True, checks for overlap between legend and data. - **kwargs : dict - Additional keyword arguments passed to legend() - """ - import matplotlib.pyplot as plt - - if loc == "outer": - legend = self._axis_mpl.legend( - *args, loc="center left", bbox_to_anchor=(1.02, 0.5), **kwargs - ) - if hasattr(self, "_figure_wrapper") and self._figure_wrapper: - self._figure_wrapper._fig_mpl.tight_layout() - self._figure_wrapper._fig_mpl.subplots_adjust(right=0.85) - return legend - - elif loc == "separate": - handles, labels = self._axis_mpl.get_legend_handles_labels() - if not handles: - logger.warning("No legend handles found.") - return None - - fig = self._axis_mpl.get_figure() - if not hasattr(fig, "_separate_legend_params"): - fig._separate_legend_params = [] - - figsize = kwargs.pop("figsize", (4, 3)) - dpi = kwargs.pop("dpi", 150) - frameon = kwargs.pop("frameon", True) - fancybox = kwargs.pop("fancybox", True) - shadow = kwargs.pop("shadow", True) - - axis_id = self._get_axis_id(fig) - - fig._separate_legend_params.append( - { - "axis": self._axis_mpl, - "axis_id": axis_id, - "handles": handles, - "labels": labels, - "figsize": figsize, - "dpi": dpi, - "frameon": frameon, - "fancybox": fancybox, - "shadow": shadow, - "kwargs": kwargs, - } - ) - - if self._axis_mpl.get_legend(): - self._axis_mpl.get_legend().remove() - - return None - - outside_positions = { - "upper right out": ("center left", (1.15, 0.85)), - "right upper out": ("center left", (1.15, 0.85)), - "center right out": ("center left", (1.15, 0.5)), - "right out": ("center left", (1.15, 0.5)), - "right": ("center left", (1.05, 0.5)), - "lower right out": ("center left", (1.15, 0.15)), - "right lower out": ("center left", (1.15, 0.15)), - "upper left out": ("center right", (-0.25, 0.85)), - "left upper out": ("center right", (-0.25, 0.85)), - "center left out": ("center right", (-0.25, 0.5)), - "left out": ("center right", (-0.25, 0.5)), - "left": ("center right", (-0.15, 0.5)), - "lower left out": ("center right", (-0.25, 0.15)), - "left lower out": ("center right", (-0.25, 0.15)), - "upper center out": ("lower center", (0.5, 1.25)), - "upper out": ("lower center", (0.5, 1.25)), - "lower center out": ("upper center", (0.5, -0.25)), - "lower out": ("upper center", (0.5, -0.25)), - } - - if loc in outside_positions: - location, bbox = outside_positions[loc] - legend_obj = self._axis_mpl.legend( - *args, loc=location, bbox_to_anchor=bbox, **kwargs - ) - else: - legend_obj = self._axis_mpl.legend(*args, loc=loc, **kwargs) - - if check_overlap and legend_obj is not None: - self._check_legend_overlap(legend_obj) - - return legend_obj - - def _get_axis_id(self, fig): - """Get unique axis identifier for separate legend handling.""" - axis_id = None - - try: - fig_axes = fig.get_axes() - for idx, ax in enumerate(fig_axes): - if ax is self._axis_mpl: - axis_id = f"ax_{idx:02d}" - break - except: - pass - - if axis_id is None and hasattr(self._axis_mpl, "get_subplotspec"): - try: - spec = self._axis_mpl.get_subplotspec() - if spec is not None: - gridspec = spec.get_gridspec() - nrows, ncols = gridspec.get_geometry() - rowspan = spec.rowspan - colspan = spec.colspan - row_start = rowspan.start if hasattr(rowspan, "start") else rowspan - col_start = colspan.start if hasattr(colspan, "start") else colspan - flat_idx = row_start * ncols + col_start - axis_id = f"ax_{flat_idx:02d}" - except: - pass - - if axis_id is None: - axis_id = f"ax_{len(fig._separate_legend_params):02d}" - - return axis_id - - def _check_legend_overlap(self, legend_obj): - """Check if legend overlaps with plotted data and issue warning if needed.""" - import warnings - - import matplotlib.transforms as transforms - import numpy as np - - try: - fig = self._axis_mpl.get_figure() - fig.canvas.draw() - - legend_bbox = legend_obj.get_window_extent(fig.canvas.get_renderer()) - inv_transform = self._axis_mpl.transData.inverted() - legend_bbox_data = legend_bbox.transformed(inv_transform) - - data_bboxes = [] - - for line in self._axis_mpl.get_lines(): - if line.get_visible(): - try: - data = line.get_xydata() - if len(data) > 0: - data_bboxes.append(data) - except: - pass - - for collection in self._axis_mpl.collections: - if collection.get_visible(): - try: - offsets = collection.get_offsets() - if len(offsets) > 0: - data_bboxes.append(offsets) - except: - pass - - if data_bboxes: - all_data = np.vstack(data_bboxes) - - x_overlap = (all_data[:, 0] >= legend_bbox_data.x0) & ( - all_data[:, 0] <= legend_bbox_data.x1 - ) - y_overlap = (all_data[:, 1] >= legend_bbox_data.y0) & ( - all_data[:, 1] <= legend_bbox_data.y1 - ) - overlap_points = np.sum(x_overlap & y_overlap) - overlap_pct = (overlap_points / len(all_data)) * 100 - - if overlap_pct > 5: - logger.warning( - f"Legend overlaps with {overlap_pct:.1f}% of data points. " - f"Consider using loc='outer' or loc='separate'." - ) - return True - - except Exception: - pass - - return False - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_metadata.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_metadata.py deleted file mode 100755 index 5cd8f6e9f..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_metadata.py +++ /dev/null @@ -1,214 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _metadata.py - Axis metadata and labels - -"""Mixin for axis labels, titles, and metadata.""" - -import os -from typing import Optional - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -class MetadataMixin: - """Mixin for setting axis labels, titles, and metadata.""" - - def _get_ax_module(self): - """Lazy import ax module to avoid circular imports.""" - from .....plt import ax as ax_module - - return ax_module - - def set_xyt( - self, - x: Optional[str] = None, - y: Optional[str] = None, - t: Optional[str] = None, - format_labels: bool = True, - ) -> None: - """Set xlabel, ylabel, and title.""" - self._axis_mpl = self._get_ax_module().set_xyt( - self._axis_mpl, - x=x, - y=y, - t=t, - format_labels=format_labels, - ) - - def set_xytc( - self, - x: Optional[str] = None, - y: Optional[str] = None, - t: Optional[str] = None, - c: Optional[str] = None, - format_labels: bool = True, - ) -> None: - """Set xlabel, ylabel, title, and caption for automatic saving. - - Parameters - ---------- - x : str, optional - X-axis label - y : str, optional - Y-axis label - t : str, optional - Title - c : str, optional - Caption to be saved automatically with scitex.io.save() - format_labels : bool, optional - Whether to apply automatic formatting, by default True - """ - self._axis_mpl = self._get_ax_module().set_xytc( - self._axis_mpl, - x=x, - y=y, - t=t, - c=c, - format_labels=format_labels, - ) - - if c is not False and c is not None: - self._scitex_caption = c - - def set_supxyt( - self, - xlabel: Optional[str] = None, - ylabel: Optional[str] = None, - title: Optional[str] = None, - format_labels: bool = True, - ) -> None: - """Set figure-level xlabel, ylabel, and title (suptitle).""" - self._axis_mpl = self._get_ax_module().set_supxyt( - self._axis_mpl, - xlabel=xlabel, - ylabel=ylabel, - title=title, - format_labels=format_labels, - ) - - def set_supxytc( - self, - xlabel: Optional[str] = None, - ylabel: Optional[str] = None, - title: Optional[str] = None, - caption: Optional[str] = None, - format_labels: bool = True, - ) -> None: - """Set figure-level xlabel, ylabel, title, and caption. - - Parameters - ---------- - xlabel : str, optional - Figure-level X-axis label - ylabel : str, optional - Figure-level Y-axis label - title : str, optional - Figure-level title (suptitle) - caption : str, optional - Figure-level caption for automatic saving - format_labels : bool, optional - Whether to apply automatic formatting - """ - self._axis_mpl = self._get_ax_module().set_supxytc( - self._axis_mpl, - xlabel=xlabel, - ylabel=ylabel, - title=title, - caption=caption, - format_labels=format_labels, - ) - - if caption is not False and caption is not None: - fig = self._axis_mpl.get_figure() - fig._scitex_main_caption = caption - - def set_meta( - self, - caption=None, - methods=None, - stats=None, - keywords=None, - experimental_details=None, - journal_style=None, - significance=None, - **kwargs, - ) -> None: - """Set comprehensive scientific metadata with YAML export capability. - - Parameters - ---------- - caption : str, optional - Figure caption text - methods : str, optional - Experimental methods description - stats : str, optional - Statistical analysis details - keywords : List[str], optional - Keywords for categorization - experimental_details : Dict[str, Any], optional - Structured experimental parameters - journal_style : str, optional - Target journal style - significance : str, optional - Significance statement - **kwargs : additional metadata - """ - self._axis_mpl = self._get_ax_module().set_meta( - self._axis_mpl, - caption=caption, - methods=methods, - stats=stats, - keywords=keywords, - experimental_details=experimental_details, - journal_style=journal_style, - significance=significance, - **kwargs, - ) - - def set_figure_meta( - self, - caption=None, - methods=None, - stats=None, - significance=None, - funding=None, - conflicts=None, - data_availability=None, - **kwargs, - ) -> None: - """Set figure-level metadata for multi-panel figures. - - Parameters - ---------- - caption : str, optional - Figure-level caption - methods : str, optional - Overall experimental methods - stats : str, optional - Overall statistical approach - significance : str, optional - Significance and implications - funding : str, optional - Funding acknowledgments - conflicts : str, optional - Conflict of interest statement - data_availability : str, optional - Data availability statement - **kwargs : additional metadata - """ - self._axis_mpl = self._get_ax_module().set_figure_meta( - self._axis_mpl, - caption=caption, - methods=methods, - stats=stats, - significance=significance, - funding=funding, - conflicts=conflicts, - data_availability=data_availability, - **kwargs, - ) - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_visual.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_visual.py deleted file mode 100755 index 1ede4662b..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_visual.py +++ /dev/null @@ -1,129 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _visual.py - Visual adjustments (ticks, spines, position) - -"""Mixin for visual adjustments including ticks, spines, and positioning.""" - -import os -from typing import List, Optional, Union - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -class VisualAdjustmentMixin: - """Mixin for visual adjustments to axis appearance.""" - - def _get_ax_module(self): - """Lazy import ax module to avoid circular imports.""" - from .....plt import ax as ax_module - - return ax_module - - def set_ticks( - self, - xvals: Optional[List[Union[int, float]]] = None, - xticks: Optional[List[str]] = None, - yvals: Optional[List[Union[int, float]]] = None, - yticks: Optional[List[str]] = None, - ) -> None: - """Set custom tick positions and labels. - - Parameters - ---------- - xvals : list of numbers, optional - Positions for x-axis ticks - xticks : list of str, optional - Labels for x-axis ticks - yvals : list of numbers, optional - Positions for y-axis ticks - yticks : list of str, optional - Labels for y-axis ticks - """ - self._axis_mpl = self._get_ax_module().set_ticks( - self._axis_mpl, - xvals=xvals, - xticks=xticks, - yvals=yvals, - yticks=yticks, - ) - - def set_n_ticks(self, n_xticks: int = 4, n_yticks: int = 4) -> None: - """Set the number of ticks on each axis. - - Parameters - ---------- - n_xticks : int, optional - Number of ticks on x-axis, by default 4 - n_yticks : int, optional - Number of ticks on y-axis, by default 4 - """ - self._axis_mpl = self._get_ax_module().set_n_ticks( - self._axis_mpl, n_xticks=n_xticks, n_yticks=n_yticks - ) - - def hide_spines( - self, - top: bool = True, - bottom: bool = False, - left: bool = False, - right: bool = True, - ticks: bool = False, - labels: bool = False, - ) -> None: - """Hide specific spines and optionally ticks/labels. - - Parameters - ---------- - top : bool, optional - Hide top spine, by default True - bottom : bool, optional - Hide bottom spine, by default False - left : bool, optional - Hide left spine, by default False - right : bool, optional - Hide right spine, by default True - ticks : bool, optional - Hide all ticks, by default False - labels : bool, optional - Hide all tick labels, by default False - """ - self._axis_mpl = self._get_ax_module().hide_spines( - self._axis_mpl, - top=top, - bottom=bottom, - left=left, - right=right, - ticks=ticks, - labels=labels, - ) - - def extend(self, x_ratio: float = 1.0, y_ratio: float = 1.0) -> None: - """Extend axis limits by a ratio. - - Parameters - ---------- - x_ratio : float, optional - Ratio to extend x-axis by, by default 1.0 - y_ratio : float, optional - Ratio to extend y-axis by, by default 1.0 - """ - self._axis_mpl = self._get_ax_module().extend( - self._axis_mpl, x_ratio=x_ratio, y_ratio=y_ratio - ) - - def shift(self, dx: float = 0, dy: float = 0) -> None: - """Shift axis position. - - Parameters - ---------- - dx : float, optional - Horizontal shift, by default 0 - dy : float, optional - Vertical shift, by default 0 - """ - self._axis_mpl = self._get_ax_module().shift(self._axis_mpl, dx=dx, dy=dy) - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/__init__.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/__init__.py deleted file mode 100755 index fc3f6f10b..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: __init__.py - MatplotlibPlotMixin package - -""" -MatplotlibPlotMixin - Modular plotting mixin for AxisWrapper. - -This package provides plotting functionality split into logical submodules: -- _base: Core helper methods -- _scientific: Scientific/specialized plots (stx_image, stx_kde, stx_conf_mat, etc.) -- _statistical: Statistical plots (stx_line, stx_mean_std, stx_box, stx_violin, hist) -- _stx_aliases: stx_ prefixed aliases for standard matplotlib methods - -API Layer Design: ------------------ -stx_* (SciTeX canonical): - - Full tracking, metadata, and reproducibility support - - Output connects to .plot / .figure format - - Purpose: publication / reproducibility - -mpl_* (Matplotlib compatibility - see _RawMatplotlibMixin): - - Raw matplotlib API without scitex processing - - Purpose: compatibility / low-level control / escape hatch - -sns_* (Seaborn - see _SeabornMixin): - - DataFrame-centric with data=, x=, y=, hue= interface - - Purpose: exploratory / grouped stats -""" - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - -from ._base import PlotBaseMixin -from ._scientific import ScientificPlotMixin -from ._statistical import StatisticalPlotMixin -from ._stx_aliases import StxAliasesMixin - - -class MatplotlibPlotMixin( - PlotBaseMixin, - ScientificPlotMixin, - StatisticalPlotMixin, - StxAliasesMixin, -): - """Mixin class for basic plotting operations. - - Combines multiple specialized mixins: - - PlotBaseMixin: Core helper methods (_get_ax_module, _apply_scitex_postprocess) - - ScientificPlotMixin: Scientific plots (stx_image, stx_kde, stx_conf_mat, etc.) - - StatisticalPlotMixin: Statistical line plots and distributions - - StxAliasesMixin: stx_ prefixed matplotlib aliases - """ - - pass - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_base.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_base.py deleted file mode 100755 index 674cc59b2..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_base.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _base.py - Core helper methods for MatplotlibPlotMixin - -"""Base mixin with core helper methods for plotting.""" - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -class PlotBaseMixin: - """Base mixin with core helper methods for plotting.""" - - def _get_ax_module(self): - """Lazy import ax module to avoid circular imports.""" - from .....plt import ax as ax_module - - return ax_module - - def _apply_scitex_postprocess( - self, method_name, result=None, kwargs=None, args=None - ): - """Apply scitex post-processing styling after plotting. - - This ensures all scitex wrapper methods get the same styling - as matplotlib methods going through __getattr__ (tick locator, spines, etc.). - """ - from scitex.plt.styles import apply_plot_postprocess - - apply_plot_postprocess(method_name, result, self._axis_mpl, kwargs or {}, args) - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_scientific.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_scientific.py deleted file mode 100755 index 67d2c7daa..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_scientific.py +++ /dev/null @@ -1,596 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _scientific.py - Scientific/specialized plot methods - -"""Scientific and domain-specific plotting methods.""" - -import os -from typing import Any, Dict, List, Optional, Tuple - -import matplotlib -import numpy as np -import pandas as pd -from scipy.stats import gaussian_kde - -from scitex.pd import to_xyz -from scitex.types import ArrayLike - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -class ScientificPlotMixin: - """Mixin for scientific and domain-specific plotting methods. - - Provides specialized visualizations for: - - Image display with colorbars - - Kernel density estimation - - Confusion matrices - - Raster plots (spike trains) - - ECDF plots - - Joint distributions (scatter + marginal histograms) - - Heatmaps with annotations - """ - - def stx_image( - self, - data: ArrayLike, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> "Axes": - """Display a 2D array as an image with SciTeX styling. - - Parameters - ---------- - data : array-like - 2D array to display as an image. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the image function. - Common options: cmap, vmin, vmax, aspect, colorbar. - - Returns - ------- - Axes - The axes with the image displayed. - - See Also - -------- - stx_imshow : Lower-level image display. - stx_heatmap : Annotated heatmap. - sns_heatmap : DataFrame-based heatmap. - - Examples - -------- - >>> ax.stx_image(matrix, cmap='viridis', colorbar=True) - """ - method_name = "stx_image" - - with self._no_tracking(): - self._axis_mpl = self._get_ax_module().stx_image( - self._axis_mpl, data, **kwargs - ) - - tracked_dict = {"image_df": pd.DataFrame(data)} - if kwargs.get("xyz", False): - tracked_dict["image_df"] = to_xyz(tracked_dict["image_df"]) - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl - - def stx_kde( - self, - data: ArrayLike, - *, - cumulative: bool = False, - fill: bool = False, - xlim: Optional[Tuple[float, float]] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> "Axes": - """Plot a kernel density estimate of the data. - - Parameters - ---------- - data : array-like - 1D array of values for density estimation. - cumulative : bool, default False - If True, plot cumulative distribution instead of density. - fill : bool, default False - If True, fill the area under the curve. - xlim : tuple of float, optional - Range for the x-axis. If None, uses data range. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the plot function. - Common options: color, linewidth, linestyle, label. - - Returns - ------- - Axes - The axes with the KDE plot. - - See Also - -------- - sns_kdeplot : DataFrame-based KDE plot. - stx_ecdf : Empirical cumulative distribution function. - hist : Histogram alternative. - - Examples - -------- - >>> ax.stx_kde(samples, fill=True, alpha=0.3) - >>> ax.stx_kde(data, cumulative=True, label='CDF') - """ - method_name = "stx_kde" - - n_samples = (~np.isnan(data)).sum() - if kwargs.get("label"): - kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" - - if xlim is None: - xlim = (np.nanmin(data), np.nanmax(data)) - - xx = np.linspace(xlim[0], xlim[1], int(1e3)) - density = gaussian_kde(data)(xx) - density /= density.sum() - - if cumulative: - density = np.cumsum(density) - - with self._no_tracking(): - from scitex.plt.utils import mm_to_pt - - if "linewidth" not in kwargs and "lw" not in kwargs: - kwargs["linewidth"] = mm_to_pt(0.2) - if "color" not in kwargs and "c" not in kwargs: - kwargs["color"] = "black" - if "linestyle" not in kwargs and "ls" not in kwargs: - kwargs["linestyle"] = "--" - - if fill: - self._axis_mpl.fill_between(xx, density, **kwargs) - else: - self._axis_mpl.plot(xx, density, **kwargs) - - tracked_dict = {"x": xx, "kde": density, "n": n_samples} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl - - def stx_conf_mat( - self, - data: ArrayLike, - *, - x_labels: Optional[List[str]] = None, - y_labels: Optional[List[str]] = None, - title: str = "Confusion Matrix", - cmap: str = "Blues", - cbar: bool = True, - cbar_kw: Optional[Dict[str, Any]] = None, - label_rotation_xy: Tuple[float, float] = (15, 15), - x_extend_ratio: float = 1.0, - y_extend_ratio: float = 1.0, - calc_bacc: bool = False, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", Optional[float]]: - """Plot a confusion matrix with optional balanced accuracy calculation. - - Parameters - ---------- - data : array-like - 2D confusion matrix array. - x_labels : list of str, optional - Labels for x-axis (predicted classes). - y_labels : list of str, optional - Labels for y-axis (true classes). - title : str, default 'Confusion Matrix' - Title for the plot. - cmap : str, default 'Blues' - Colormap for the heatmap. - cbar : bool, default True - Whether to show the colorbar. - cbar_kw : dict, optional - Additional keyword arguments for the colorbar. - label_rotation_xy : tuple of float, default (15, 15) - Rotation angles for (x, y) axis labels. - x_extend_ratio : float, default 1.0 - Ratio to extend x-axis limits. - y_extend_ratio : float, default 1.0 - Ratio to extend y-axis limits. - calc_bacc : bool, default False - Whether to calculate and return balanced accuracy. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the heatmap function. - - Returns - ------- - tuple - (Axes, balanced_accuracy) - balanced_accuracy is None if calc_bacc=False. - - Examples - -------- - >>> ax.stx_conf_mat(cm, x_labels=['A', 'B'], y_labels=['A', 'B']) - >>> ax, bacc = ax.stx_conf_mat(cm, calc_bacc=True) - """ - method_name = "stx_conf_mat" - - if cbar_kw is None: - cbar_kw = {} - - with self._no_tracking(): - self._axis_mpl, bacc_val = self._get_ax_module().stx_conf_mat( - self._axis_mpl, - data, - x_labels=x_labels, - y_labels=y_labels, - title=title, - cmap=cmap, - cbar=cbar, - cbar_kw=cbar_kw, - label_rotation_xy=label_rotation_xy, - x_extend_ratio=x_extend_ratio, - y_extend_ratio=y_extend_ratio, - calc_bacc=calc_bacc, - **kwargs, - ) - - tracked_dict = { - "args": [data], - "balanced_accuracy": bacc_val, - "x_labels": x_labels, - "y_labels": y_labels, - } - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, bacc_val - - def stx_raster( - self, - spike_times: List[ArrayLike], - *, - time: Optional[ArrayLike] = None, - labels: Optional[List[str]] = None, - colors: Optional[List[str]] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", pd.DataFrame]: - """Plot a raster plot (spike train visualization). - - Parameters - ---------- - spike_times : list of array-like - List of arrays, each containing spike times for one unit/neuron. - time : array-like, optional - Time axis reference. If None, uses spike time range. - labels : list of str, optional - Labels for each unit/row. - colors : list of str, optional - Colors for each unit/row. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the raster function. - - Returns - ------- - tuple - (Axes, DataFrame) - The axes and digitized raster data. - - Examples - -------- - >>> ax.stx_raster([spikes_unit1, spikes_unit2], labels=['Unit 1', 'Unit 2']) - """ - method_name = "stx_raster" - - with self._no_tracking(): - self._axis_mpl, raster_digit_df = self._get_ax_module().stx_raster( - self._axis_mpl, spike_times, time=time - ) - - tracked_dict = {"raster_digit_df": raster_digit_df} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, raster_digit_df - - def stx_ecdf( - self, - data: ArrayLike, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", pd.DataFrame]: - """Plot an empirical cumulative distribution function (ECDF). - - Parameters - ---------- - data : array-like - 1D array of values. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the ECDF function. - Common options: color, linewidth, label. - - Returns - ------- - tuple - (Axes, DataFrame) - The axes and ECDF data (x, y columns). - - See Also - -------- - stx_kde : Kernel density estimate (continuous). - hist : Histogram (discrete bins). - - Examples - -------- - >>> ax.stx_ecdf(samples, label='Distribution A') - """ - method_name = "stx_ecdf" - - with self._no_tracking(): - self._axis_mpl, ecdf_df = self._get_ax_module().stx_ecdf( - self._axis_mpl, data, **kwargs - ) - - tracked_dict = {"ecdf_df": ecdf_df} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, ecdf_df - - def stx_joyplot( - self, - data: ArrayLike, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> "Axes": - """Plot a joyplot (ridgeline plot) for distribution comparison. - - Parameters - ---------- - data : array-like - 2D array where each row is a distribution to plot. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the joyplot function. - - Returns - ------- - Axes - The axes with the joyplot. - - Examples - -------- - >>> ax.stx_joyplot(distributions_2d, overlap=0.5) - """ - method_name = "stx_joyplot" - - with self._no_tracking(): - self._axis_mpl = self._get_ax_module().stx_joyplot( - self._axis_mpl, data, **kwargs - ) - - tracked_dict = {"joyplot_data": data} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl - - def stx_scatter_hist( - self, - x: ArrayLike, - y: ArrayLike, - *, - hist_bins: int = 20, - scatter_alpha: float = 0.6, - scatter_size: float = 20, - scatter_color: str = "blue", - hist_color_x: str = "blue", - hist_color_y: str = "red", - hist_alpha: float = 0.5, - scatter_ratio: float = 0.8, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", "Axes", "Axes", Dict]: - """Plot a scatter plot with marginal histograms. - - Parameters - ---------- - x : array-like - X coordinates of the scatter points. - y : array-like - Y coordinates of the scatter points. - hist_bins : int, default 20 - Number of bins for the marginal histograms. - scatter_alpha : float, default 0.6 - Transparency of scatter points. - scatter_size : float, default 20 - Size of scatter points. - scatter_color : str, default 'blue' - Color of scatter points. - hist_color_x : str, default 'blue' - Color of x-marginal histogram. - hist_color_y : str, default 'red' - Color of y-marginal histogram. - hist_alpha : float, default 0.5 - Transparency of histograms. - scatter_ratio : float, default 0.8 - Ratio of scatter plot area to total area. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the scatter function. - - Returns - ------- - tuple - (main_ax, hist_x_ax, hist_y_ax, hist_data) - Axes and histogram data. - - See Also - -------- - stx_scatter : Simple scatter plot. - sns_jointplot : Seaborn joint plot. - - Examples - -------- - >>> ax, ax_hx, ax_hy, data = ax.stx_scatter_hist(x, y, hist_bins=30) - """ - method_name = "stx_scatter_hist" - - with self._no_tracking(): - ( - self._axis_mpl, - ax_histx, - ax_histy, - hist_data, - ) = self._get_ax_module().stx_scatter_hist( - self._axis_mpl, - x, - y, - hist_bins=hist_bins, - scatter_alpha=scatter_alpha, - scatter_size=scatter_size, - scatter_color=scatter_color, - hist_color_x=hist_color_x, - hist_color_y=hist_color_y, - hist_alpha=hist_alpha, - scatter_ratio=scatter_ratio, - **kwargs, - ) - - tracked_dict = { - "x": x, - "y": y, - "hist_x": hist_data["hist_x"], - "hist_y": hist_data["hist_y"], - "bin_edges_x": hist_data["bin_edges_x"], - "bin_edges_y": hist_data["bin_edges_y"], - } - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, ax_histx, ax_histy, hist_data - - def stx_heatmap( - self, - data: ArrayLike, - *, - x_labels: Optional[List[str]] = None, - y_labels: Optional[List[str]] = None, - cmap: str = "viridis", - cbar_label: str = "ColorBar Label", - value_format: str = "{x:.1f}", - show_annot: bool = True, - annot_color_lighter: str = "white", - annot_color_darker: str = "black", - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", matplotlib.image.AxesImage, matplotlib.colorbar.Colorbar]: - """Plot an annotated heatmap. - - Parameters - ---------- - data : array-like - 2D array of values to display. - x_labels : list of str, optional - Labels for x-axis (columns). - y_labels : list of str, optional - Labels for y-axis (rows). - cmap : str, default 'viridis' - Colormap name. - cbar_label : str, default 'ColorBar Label' - Label for the colorbar. - value_format : str, default '{x:.1f}' - Format string for cell annotations. - show_annot : bool, default True - Whether to show value annotations in cells. - annot_color_lighter : str, default 'white' - Annotation color for dark backgrounds. - annot_color_darker : str, default 'black' - Annotation color for light backgrounds. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the heatmap function. - - Returns - ------- - tuple - (Axes, AxesImage, Colorbar) - The axes, image, and colorbar objects. - - See Also - -------- - sns_heatmap : DataFrame-based heatmap. - stx_conf_mat : Confusion matrix heatmap. - stx_image : Simple image display. - - Examples - -------- - >>> ax, im, cbar = ax.stx_heatmap(matrix, x_labels=['A', 'B'], cmap='coolwarm') - """ - method_name = "stx_heatmap" - - with self._no_tracking(): - ax, im, cbar = self._get_ax_module().stx_heatmap( - self._axis_mpl, - data, - x_labels=x_labels, - y_labels=y_labels, - cmap=cmap, - cbar_label=cbar_label, - value_format=value_format, - show_annot=show_annot, - annot_color_lighter=annot_color_lighter, - annot_color_darker=annot_color_darker, - **kwargs, - ) - - tracked_dict = { - "data": data, - "x_labels": x_labels, - "y_labels": y_labels, - } - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return ax, im, cbar - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_statistical.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_statistical.py deleted file mode 100755 index 5913ba738..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_statistical.py +++ /dev/null @@ -1,654 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _statistical.py - Statistical plot methods - -"""Statistical plotting methods including line plots, box plots, and violin plots.""" - -import os -from typing import List, Optional, Sequence, Tuple, Union - -import numpy as np -import pandas as pd - -from scitex.types import ArrayLike - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -class StatisticalPlotMixin: - """Mixin for statistical plotting methods. - - Provides methods for: - - Distribution plots (boxplot, violin) - - Line plots with uncertainty (mean±std, mean±CI, median±IQR) - - Histograms with bin alignment - - Geometric shapes (rectangles, filled regions) - """ - - def stx_rectangle( - self, - x: float, - y: float, - width: float, - height: float, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> "Axes": - """Draw a rectangle on the axes. - - Parameters - ---------- - x : float - X coordinate of the lower-left corner. - y : float - Y coordinate of the lower-left corner. - width : float - Width of the rectangle. - height : float - Height of the rectangle. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the rectangle function. - - Returns - ------- - Axes - The axes with the rectangle added. - - Examples - -------- - >>> ax.stx_rectangle(0, 0, 1, 2, color='blue', alpha=0.5) - """ - method_name = "stx_rectangle" - - with self._no_tracking(): - self._axis_mpl = self._get_ax_module().stx_rectangle( - self._axis_mpl, x, y, width, height, **kwargs - ) - - tracked_dict = {"x": x, "y": y, "width": width, "height": height} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl - - def stx_fillv( - self, - starts: ArrayLike, - ends: ArrayLike, - *, - color: str = "red", - alpha: float = 0.2, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> "Axes": - """Fill vertical spans between start and end positions. - - Parameters - ---------- - starts : array-like - Start x-coordinates of each span. - ends : array-like - End x-coordinates of each span. - color : str, default 'red' - Fill color. - alpha : float, default 0.2 - Transparency level. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the fill function. - - Returns - ------- - Axes - The axes with the filled spans added. - - Examples - -------- - >>> ax.stx_fillv([0, 2, 4], [1, 3, 5], color='green') - """ - method_name = "stx_fillv" - - with self._no_tracking(): - self._axis_mpl = self._get_ax_module().stx_fillv( - self._axis_mpl, starts, ends, color=color, alpha=alpha - ) - - tracked_dict = {"starts": starts, "ends": ends} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl - - def stx_box( - self, - data: Union[ArrayLike, Sequence[ArrayLike]], - *, - colors: Optional[List[str]] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> dict: - """Create a boxplot with SciTeX styling and tracking. - - Parameters - ---------- - data : array-like or sequence of array-like - Data for the boxplot. Can be a single array or list of arrays - where each array represents a group. - colors : list of str, optional - Colors for each box. If None, uses default palette. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.boxplot`. - - Returns - ------- - dict - Dictionary mapping component names ('boxes', 'whiskers', etc.) - to lists of Line2D or Patch artists. - - See Also - -------- - stx_boxplot : Alias for this method. - sns_boxplot : DataFrame-based boxplot. - stx_violin : Violin plot alternative. - - Examples - -------- - >>> ax.stx_box([data1, data2, data3], labels=['A', 'B', 'C']) - >>> ax.stx_box(data, notch=True, patch_artist=True) - """ - method_name = "stx_box" - - _data = data.copy() - - if kwargs.get("label"): - n_per_group = [len(g) for g in data] - n_min, n_max = min(n_per_group), max(n_per_group) - n_str = str(n_min) if n_min == n_max else f"{n_min}-{n_max}" - kwargs["label"] = kwargs["label"] + f" ($n$={n_str})" - - if "patch_artist" not in kwargs: - kwargs["patch_artist"] = True - - with self._no_tracking(): - result = self._axis_mpl.boxplot(data, **kwargs) - - n_per_group = [len(g) for g in data] - tracked_dict = {"data": _data, "n": n_per_group} - self._track(track, id, method_name, tracked_dict, None) - - from scitex.plt.ax import style_boxplot - - style_boxplot(result, colors=colors) - - self._apply_scitex_postprocess(method_name, result) - - return result - - def hist( - self, - x: ArrayLike, - *, - bins: Union[int, str, ArrayLike] = 10, - range: Optional[Tuple[float, float]] = None, - align_bins: bool = True, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple[np.ndarray, np.ndarray, "BarContainer"]: - """Plot a histogram with optional bin alignment across multiple histograms. - - Parameters - ---------- - x : array-like - Input data for the histogram. - bins : int, str, or array-like, default 10 - Number of bins, binning strategy ('auto', 'fd', etc.), or bin edges. - range : tuple of float, optional - Lower and upper range of the bins. If None, uses data range. - align_bins : bool, default True - When True, aligns bins across multiple histograms on the same axes. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.hist`. - - Returns - ------- - tuple - (counts, bin_edges, patches) from matplotlib hist. - - See Also - -------- - sns_histplot : DataFrame-based histogram with KDE support. - - Examples - -------- - >>> ax.hist(data, bins=20, density=True) - >>> ax.hist(data, bins='auto', alpha=0.7, label='Group A') - """ - method_name = "hist" - - axis_id = str(hash(self._axis_mpl)) - hist_id = id if id is not None else str(self.id) - - if align_bins: - from .....plt.utils import histogram_bin_manager - - bins, range = histogram_bin_manager.register_histogram( - axis_id, hist_id, x, bins, range - ) - - with self._no_tracking(): - hist_data = self._axis_mpl.hist(x, bins=bins, range=range, **kwargs) - - tracked_dict = { - "args": (x,), - "hist_result": (hist_data[0], hist_data[1]), - "bins": bins, - "range": range, - } - self._track(track, id, method_name, tracked_dict, kwargs) - self._apply_scitex_postprocess(method_name, hist_data) - - return hist_data - - def stx_violin( - self, - data: Union[pd.DataFrame, List, ArrayLike], - *, - x: Optional[str] = None, - y: Optional[str] = None, - hue: Optional[str] = None, - labels: Optional[List[str]] = None, - colors: Optional[List[str]] = None, - half: bool = False, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> "Axes": - """Create a violin plot with SciTeX styling and tracking. - - Parameters - ---------- - data : DataFrame, list, or array-like - Data for the violin plot. Can be: - - List of arrays (one per violin) - - DataFrame with columns specified by x, y, hue - x : str, optional - Column name for x-axis grouping (DataFrame input). - y : str, optional - Column name for y-axis values (DataFrame input). - hue : str, optional - Column name for color grouping (DataFrame input). - labels : list of str, optional - Labels for each violin (list input). - colors : list of str, optional - Colors for each violin. - half : bool, default False - If True, draw half-violins (useful for paired comparisons). - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the violin function. - - Returns - ------- - Axes - The axes with the violin plot. - - See Also - -------- - stx_violinplot : Alias for this method. - sns_violinplot : DataFrame-based violin plot. - stx_box : Boxplot alternative. - - Examples - -------- - >>> ax.stx_violin([data1, data2], labels=['A', 'B']) - >>> ax.stx_violin(df, x='group', y='value', hue='category') - """ - method_name = "stx_violin" - - with self._no_tracking(): - if isinstance(data, list) and all( - isinstance(item, (list, np.ndarray)) for item in data - ): - self._axis_mpl = self._get_ax_module().stx_violin( - self._axis_mpl, - values_list=data, - labels=labels, - colors=colors, - half=half, - **kwargs, - ) - else: - self._axis_mpl = self._get_ax_module().stx_violin( - self._axis_mpl, - data=data, - x=x, - y=y, - hue=hue, - half=half, - **kwargs, - ) - - tracked_dict = { - "data": data, - "x": x, - "y": y, - "hue": hue, - "half": half, - "labels": labels, - "colors": colors, - } - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl - - def stx_line( - self, - y: ArrayLike, - *, - x: Optional[ArrayLike] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", pd.DataFrame]: - """Plot a simple line with SciTeX styling. - - Parameters - ---------- - y : array-like - Y values for the line. - x : array-like, optional - X values for the line. If None, uses integer indices. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the line plot function. - - Returns - ------- - tuple - (Axes, DataFrame) - The axes and a DataFrame with the plotted data. - - See Also - -------- - stx_mean_std : Line with standard deviation shading. - stx_shaded_line : Line with custom shaded region. - sns_lineplot : DataFrame-based line plot. - - Examples - -------- - >>> ax.stx_line(y_values) - >>> ax.stx_line(y, x=x, label='Series A', color='blue') - """ - method_name = "stx_line" - - with self._no_tracking(): - self._axis_mpl, plot_df = self._get_ax_module().stx_line( - self._axis_mpl, y, xx=x, **kwargs - ) - - tracked_dict = {"plot_df": plot_df} - self._track(track, id, method_name, tracked_dict, kwargs) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, plot_df - - def stx_mean_std( - self, - data: ArrayLike, - *, - x: Optional[ArrayLike] = None, - sd: float = 1.0, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", pd.DataFrame]: - """Plot mean line with standard deviation shading. - - Parameters - ---------- - data : array-like - 2D array where each row is an observation and columns are time points. - x : array-like, optional - X values. If None, uses integer indices. - sd : float, default 1.0 - Number of standard deviations for the shaded region. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the plot function. - - Returns - ------- - tuple - (Axes, DataFrame) - The axes and a DataFrame with mean, upper, lower. - - See Also - -------- - stx_mean_ci : Mean with confidence interval. - stx_median_iqr : Median with interquartile range. - stx_shaded_line : Custom shaded line. - - Examples - -------- - >>> ax.stx_mean_std(data_2d, sd=2, label='Mean±2SD') - """ - method_name = "stx_mean_std" - - with self._no_tracking(): - self._axis_mpl, plot_df = self._get_ax_module().stx_mean_std( - self._axis_mpl, data, xx=x, sd=sd, **kwargs - ) - - tracked_dict = {"plot_df": plot_df} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, plot_df - - def stx_mean_ci( - self, - data: ArrayLike, - *, - x: Optional[ArrayLike] = None, - ci: float = 95.0, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", pd.DataFrame]: - """Plot mean line with confidence interval shading. - - Parameters - ---------- - data : array-like - 2D array where each row is an observation and columns are time points. - x : array-like, optional - X values. If None, uses integer indices. - ci : float, default 95.0 - Confidence interval percentage (e.g., 95 for 95% CI). - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the plot function. - - Returns - ------- - tuple - (Axes, DataFrame) - The axes and a DataFrame with mean, upper, lower. - - See Also - -------- - stx_mean_std : Mean with standard deviation. - stx_median_iqr : Median with interquartile range. - - Examples - -------- - >>> ax.stx_mean_ci(data_2d, ci=99, label='Mean±99%CI') - """ - method_name = "stx_mean_ci" - - with self._no_tracking(): - self._axis_mpl, plot_df = self._get_ax_module().stx_mean_ci( - self._axis_mpl, data, xx=x, perc=ci, **kwargs - ) - - tracked_dict = {"plot_df": plot_df} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, plot_df - - def stx_median_iqr( - self, - data: ArrayLike, - *, - x: Optional[ArrayLike] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", pd.DataFrame]: - """Plot median line with interquartile range shading. - - Parameters - ---------- - data : array-like - 2D array where each row is an observation and columns are time points. - x : array-like, optional - X values. If None, uses integer indices. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the plot function. - - Returns - ------- - tuple - (Axes, DataFrame) - The axes and a DataFrame with median, Q1, Q3. - - See Also - -------- - stx_mean_std : Mean with standard deviation. - stx_mean_ci : Mean with confidence interval. - - Examples - -------- - >>> ax.stx_median_iqr(data_2d, label='Median±IQR') - """ - method_name = "stx_median_iqr" - - with self._no_tracking(): - self._axis_mpl, plot_df = self._get_ax_module().stx_median_iqr( - self._axis_mpl, data, xx=x, **kwargs - ) - - tracked_dict = {"plot_df": plot_df} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, plot_df - - def stx_shaded_line( - self, - x: ArrayLike, - y_lower: ArrayLike, - y_middle: ArrayLike, - y_upper: ArrayLike, - *, - color: Optional[Union[str, List[str]]] = None, - label: Optional[Union[str, List[str]]] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> Tuple["Axes", pd.DataFrame]: - """Plot a line with shaded area between lower and upper bounds. - - Parameters - ---------- - x : array-like - X coordinates. - y_lower : array-like - Lower bound of the shaded region. - y_middle : array-like - Center line values. - y_upper : array-like - Upper bound of the shaded region. - color : str or list of str, optional - Color(s) for the line and shading. - label : str or list of str, optional - Label(s) for the legend. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the plot function. - - Returns - ------- - tuple - (Axes, DataFrame) - The axes and a DataFrame with the plotted data. - - See Also - -------- - stx_mean_std : Mean with standard deviation. - stx_fill_between : Simple fill between curves. - - Examples - -------- - >>> ax.stx_shaded_line(x, lower, mean, upper, color='blue', label='Result') - """ - method_name = "stx_shaded_line" - - with self._no_tracking(): - self._axis_mpl, plot_df = self._get_ax_module().stx_shaded_line( - self._axis_mpl, - x, - y_lower, - y_middle, - y_upper, - color=color, - label=label, - **kwargs, - ) - - tracked_dict = {"plot_df": plot_df} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name) - - return self._axis_mpl, plot_df - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_stx_aliases.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_stx_aliases.py deleted file mode 100755 index c8aafe960..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_stx_aliases.py +++ /dev/null @@ -1,527 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _stx_aliases.py - stx_ aliases for standard matplotlib methods - -"""stx_ prefixed aliases for standard matplotlib methods with tracking support.""" - -import os -from typing import List, Optional, Sequence, Union - -import numpy as np -import pandas as pd -from matplotlib.collections import PathCollection -from matplotlib.container import BarContainer -from matplotlib.contour import QuadContourSet -from matplotlib.image import AxesImage - -from scitex.types import ArrayLike - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -class StxAliasesMixin: - """Mixin providing stx_ aliases for standard matplotlib methods. - - These methods wrap standard matplotlib plotting functions with: - - SciTeX styling applied automatically - - Data tracking for reproducibility - - Sample size annotations in labels - """ - - def stx_bar( - self, - x: ArrayLike, - height: ArrayLike, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> BarContainer: - """Create a bar plot with SciTeX styling and tracking. - - Parameters - ---------- - x : array-like - X coordinates of the bars. - height : array-like - Heights of the bars. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.bar`. - - Returns - ------- - BarContainer - Container with all the bars. - - See Also - -------- - stx_barh : Horizontal bar plot. - mpl_bar : Raw matplotlib bar without styling. - - Examples - -------- - >>> ax.stx_bar([1, 2, 3], [4, 5, 6]) - >>> ax.stx_bar(x, height, label="Group A", color="blue") - """ - method_name = "stx_bar" - - if kwargs.get("label"): - n_samples = len(x) - kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" - - with self._no_tracking(): - result = self._axis_mpl.bar(x, height, **kwargs) - - tracked_dict = {"bar_df": pd.DataFrame({"x": x, "height": height})} - self._track(track, id, method_name, tracked_dict, None) - - from scitex.plt.ax import style_barplot - - style_barplot(result) - - self._apply_scitex_postprocess(method_name, result) - - return result - - def stx_barh( - self, - y: ArrayLike, - width: ArrayLike, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> BarContainer: - """Create a horizontal bar plot with SciTeX styling and tracking. - - Parameters - ---------- - y : array-like - Y coordinates of the bars. - width : array-like - Widths of the bars. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.barh`. - - Returns - ------- - BarContainer - Container with all the bars. - - See Also - -------- - stx_bar : Vertical bar plot. - mpl_barh : Raw matplotlib barh without styling. - - Examples - -------- - >>> ax.stx_barh([1, 2, 3], [4, 5, 6]) - """ - method_name = "stx_barh" - - if kwargs.get("label"): - n_samples = len(y) - kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" - - with self._no_tracking(): - result = self._axis_mpl.barh(y, width, **kwargs) - - tracked_dict = {"barh_df": pd.DataFrame({"y": y, "width": width})} - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name, result) - - return result - - def stx_scatter( - self, - x: ArrayLike, - y: ArrayLike, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> PathCollection: - """Create a scatter plot with SciTeX styling and tracking. - - Parameters - ---------- - x : array-like - X coordinates of the data points. - y : array-like - Y coordinates of the data points. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.scatter`. - - Returns - ------- - PathCollection - Collection of scatter points. - - See Also - -------- - sns_scatterplot : DataFrame-based scatter plot. - mpl_scatter : Raw matplotlib scatter without styling. - - Examples - -------- - >>> ax.stx_scatter(x, y, label="Data", s=50) - """ - method_name = "stx_scatter" - - if kwargs.get("label"): - n_samples = len(x) - kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" - - with self._no_tracking(): - result = self._axis_mpl.scatter(x, y, **kwargs) - - tracked_dict = {"scatter_df": pd.DataFrame({"x": x, "y": y})} - self._track(track, id, method_name, tracked_dict, None) - - from scitex.plt.ax import style_scatter - - style_scatter(result) - - self._apply_scitex_postprocess(method_name, result) - - return result - - def stx_errorbar( - self, - x: ArrayLike, - y: ArrayLike, - *, - yerr: Optional[ArrayLike] = None, - xerr: Optional[ArrayLike] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create an error bar plot with SciTeX styling and tracking. - - Parameters - ---------- - x : array-like - X coordinates of the data points. - y : array-like - Y coordinates of the data points. - yerr : array-like, optional - Error values for y-axis (symmetric or asymmetric). - xerr : array-like, optional - Error values for x-axis (symmetric or asymmetric). - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.errorbar`. - - Returns - ------- - ErrorbarContainer - Container with the plotted errorbar lines. - - See Also - -------- - stx_mean_std : Mean line with standard deviation shading. - stx_mean_ci : Mean line with confidence interval shading. - - Examples - -------- - >>> ax.stx_errorbar(x, y, yerr=std, fmt='o-') - """ - method_name = "stx_errorbar" - - if kwargs.get("label"): - n_samples = len(x) - kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" - - with self._no_tracking(): - result = self._axis_mpl.errorbar(x, y, yerr=yerr, xerr=xerr, **kwargs) - - df_dict = {"x": x, "y": y} - if yerr is not None: - df_dict["yerr"] = yerr - if xerr is not None: - df_dict["xerr"] = xerr - tracked_dict = {"errorbar_df": pd.DataFrame(df_dict)} - self._track(track, id, method_name, tracked_dict, None) - - from scitex.plt.ax import style_errorbar - - style_errorbar(result) - - self._apply_scitex_postprocess(method_name, result) - - return result - - def stx_fill_between( - self, - x: ArrayLike, - y1: ArrayLike, - y2: Union[float, ArrayLike] = 0, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Fill the area between two curves with SciTeX styling and tracking. - - Parameters - ---------- - x : array-like - X coordinates for the fill region. - y1 : array-like - First y-boundary curve. - y2 : float or array-like, default 0 - Second y-boundary curve. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.fill_between`. - - Returns - ------- - PolyCollection - Collection representing the filled area. - - See Also - -------- - stx_shaded_line : Line plot with shaded confidence region. - - Examples - -------- - >>> ax.stx_fill_between(x, y_lower, y_upper, alpha=0.3) - """ - method_name = "stx_fill_between" - - with self._no_tracking(): - result = self._axis_mpl.fill_between(x, y1, y2, **kwargs) - - tracked_dict = { - "fill_between_df": pd.DataFrame( - { - "x": x, - "y1": y1, - "y2": y2 if hasattr(y2, "__len__") else [y2] * len(x), - } - ) - } - self._track(track, id, method_name, tracked_dict, None) - self._apply_scitex_postprocess(method_name, result) - - return result - - def stx_contour( - self, - X: ArrayLike, - Y: ArrayLike, - Z: ArrayLike, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> QuadContourSet: - """Create a contour plot with SciTeX styling and tracking. - - Parameters - ---------- - X : array-like - X coordinates of the grid (2D array or 1D for meshgrid). - Y : array-like - Y coordinates of the grid (2D array or 1D for meshgrid). - Z : array-like - Values at each grid point (2D array). - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.contour`. - - Returns - ------- - QuadContourSet - The contour set object. - - See Also - -------- - stx_imshow : Display data as an image. - mpl_contour : Raw matplotlib contour without styling. - - Examples - -------- - >>> ax.stx_contour(X, Y, Z, levels=10) - """ - method_name = "stx_contour" - - with self._no_tracking(): - result = self._axis_mpl.contour(X, Y, Z, **kwargs) - - tracked_dict = { - "contour_df": pd.DataFrame( - {"X": np.ravel(X), "Y": np.ravel(Y), "Z": np.ravel(Z)} - ) - } - self._track(track, id, method_name, tracked_dict, None) - - self._apply_scitex_postprocess(method_name, result) - - return result - - def stx_imshow( - self, - data: ArrayLike, - *, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> AxesImage: - """Display data as an image with SciTeX styling and tracking. - - Parameters - ---------- - data : array-like - Image data (2D or 3D array for RGB/RGBA). - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.imshow`. - - Returns - ------- - AxesImage - The image object. - - See Also - -------- - stx_image : Scientific image display with colorbar. - mpl_imshow : Raw matplotlib imshow without styling. - - Examples - -------- - >>> ax.stx_imshow(image_array, cmap='viridis') - """ - method_name = "stx_imshow" - - with self._no_tracking(): - result = self._axis_mpl.imshow(data, **kwargs) - - if hasattr(data, "shape") and len(data.shape) == 2: - n_rows, n_cols = data.shape - df = pd.DataFrame(data, columns=[f"col_{i}" for i in range(n_cols)]) - else: - df = pd.DataFrame(data) - tracked_dict = {"imshow_df": df} - self._track(track, id, method_name, tracked_dict, None) - - self._apply_scitex_postprocess(method_name, result) - - return result - - def stx_boxplot( - self, - data: Union[ArrayLike, Sequence[ArrayLike]], - *, - colors: Optional[List[str]] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> dict: - """Create a boxplot with SciTeX styling and tracking. - - This is an alias for :meth:`stx_box`. - - Parameters - ---------- - data : array-like or sequence of array-like - Data for the boxplot. Can be a single array or list of arrays. - colors : list of str, optional - Colors for each box. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `matplotlib.axes.Axes.boxplot`. - - Returns - ------- - dict - Dictionary mapping component names to artists. - - See Also - -------- - stx_box : Primary boxplot method. - sns_boxplot : DataFrame-based boxplot. - stx_violin : Violin plot alternative. - - Examples - -------- - >>> ax.stx_boxplot([data1, data2, data3], labels=['A', 'B', 'C']) - """ - return self.stx_box(data, colors=colors, track=track, id=id, **kwargs) - - def stx_violinplot( - self, - data: Union[ArrayLike, Sequence[ArrayLike]], - *, - colors: Optional[List[str]] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ) -> "Axes": - """Create a violin plot with SciTeX styling and tracking. - - This is an alias for :meth:`stx_violin`. - - Parameters - ---------- - data : array-like or sequence of array-like - Data for the violin plot. Can be a single array or list of arrays. - colors : list of str, optional - Colors for each violin. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the violin plot function. - - Returns - ------- - Axes - The axes with the violin plot. - - See Also - -------- - stx_violin : Primary violin plot method. - sns_violinplot : DataFrame-based violin plot. - stx_box : Boxplot alternative. - - Examples - -------- - >>> ax.stx_violinplot([data1, data2], labels=['A', 'B']) - """ - return self.stx_violin(data, colors=colors, track=track, id=id, **kwargs) - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_RawMatplotlibMixin.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_RawMatplotlibMixin.py deleted file mode 100755 index bc1d22c7d..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_RawMatplotlibMixin.py +++ /dev/null @@ -1,349 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_RawMatplotlibMixin.py - -""" -Matplotlib aliases (mpl_xxx) for explicit matplotlib-style API. - -Provides consistent naming convention: -- stx_xxx: scitex-specific methods (ArrayLike input, tracked) -- sns_xxx: seaborn wrappers (DataFrame input, tracked) -- mpl_xxx: matplotlib methods (matplotlib-style input, tracked) - -All three API layers track data for reproducibility. - -Usage: - ax.stx_line(y) # ArrayLike input - ax.sns_boxplot(data=df, x="group", y="value") # DataFrame input - ax.mpl_plot(x, y) # matplotlib-style input - ax.plot(x, y) # Same as mpl_plot -""" - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -class RawMatplotlibMixin: - """Mixin providing mpl_xxx aliases for matplotlib-style API. - - These methods are identical to calling ax.plot(), ax.scatter(), etc. - They go through SciTeX's __getattr__ wrapper and are fully tracked. - - The mpl_* prefix provides: - - Explicit naming convention (mpl_* vs stx_* vs sns_*) - - Programmatic access via MPL_METHODS registry - - Same tracking and styling as regular matplotlib calls - """ - - # ========================================================================= - # Helper to call through __getattr__ wrapper (enables tracking) - # ========================================================================= - def _mpl_call(self, method_name, *args, **kwargs): - """Call matplotlib method through __getattr__ wrapper for tracking.""" - # Use object.__getattribute__ to get the __getattr__ from AxisWrapper - # Then call it with the method name to get the tracked wrapper - wrapper_class = type(self) - # Walk up MRO to find __getattr__ in AxisWrapper - for cls in wrapper_class.__mro__: - if "__getattr__" in cls.__dict__: - return cls.__getattr__(self, method_name)(*args, **kwargs) - # Fallback to direct call if no __getattr__ found - return getattr(self._axes_mpl, method_name)(*args, **kwargs) - - # ========================================================================= - # Line plots - # ========================================================================= - def mpl_plot(self, *args, **kwargs): - """Matplotlib plot() - tracked, identical to ax.plot().""" - return self._mpl_call("plot", *args, **kwargs) - - def mpl_step(self, *args, **kwargs): - """Matplotlib step() - tracked, identical to ax.step().""" - return self._mpl_call("step", *args, **kwargs) - - def mpl_stem(self, *args, **kwargs): - """Matplotlib stem() - tracked, identical to ax.stem().""" - return self._mpl_call("stem", *args, **kwargs) - - # ========================================================================= - # Scatter plots - # ========================================================================= - def mpl_scatter(self, *args, **kwargs): - """Matplotlib scatter() - tracked, identical to ax.scatter().""" - return self._mpl_call("scatter", *args, **kwargs) - - # ========================================================================= - # Bar plots - # ========================================================================= - def mpl_bar(self, *args, **kwargs): - """Matplotlib bar() - tracked, identical to ax.bar().""" - return self._mpl_call("bar", *args, **kwargs) - - def mpl_barh(self, *args, **kwargs): - """Matplotlib barh() - tracked, identical to ax.barh().""" - return self._mpl_call("barh", *args, **kwargs) - - def mpl_bar3d(self, *args, **kwargs): - """Matplotlib bar3d() (3D axes) - tracked.""" - return self._mpl_call("bar3d", *args, **kwargs) - - # ========================================================================= - # Histograms - # ========================================================================= - def mpl_hist(self, *args, **kwargs): - """Matplotlib hist() - tracked, identical to ax.hist().""" - return self._mpl_call("hist", *args, **kwargs) - - def mpl_hist2d(self, *args, **kwargs): - """Matplotlib hist2d() - tracked, identical to ax.hist2d().""" - return self._mpl_call("hist2d", *args, **kwargs) - - def mpl_hexbin(self, *args, **kwargs): - """Matplotlib hexbin() - tracked, identical to ax.hexbin().""" - return self._mpl_call("hexbin", *args, **kwargs) - - # ========================================================================= - # Statistical plots - # ========================================================================= - def mpl_boxplot(self, *args, **kwargs): - """Matplotlib boxplot() - tracked, identical to ax.boxplot().""" - return self._mpl_call("boxplot", *args, **kwargs) - - def mpl_violinplot(self, *args, **kwargs): - """Matplotlib violinplot() - tracked, identical to ax.violinplot().""" - return self._mpl_call("violinplot", *args, **kwargs) - - def mpl_errorbar(self, *args, **kwargs): - """Matplotlib errorbar() - tracked, identical to ax.errorbar().""" - return self._mpl_call("errorbar", *args, **kwargs) - - def mpl_eventplot(self, *args, **kwargs): - """Matplotlib eventplot() - tracked, identical to ax.eventplot().""" - return self._mpl_call("eventplot", *args, **kwargs) - - # ========================================================================= - # Fill and area plots - # ========================================================================= - def mpl_fill(self, *args, **kwargs): - """Matplotlib fill() - tracked, identical to ax.fill().""" - return self._mpl_call("fill", *args, **kwargs) - - def mpl_fill_between(self, *args, **kwargs): - """Matplotlib fill_between() - tracked, identical to ax.fill_between().""" - return self._mpl_call("fill_between", *args, **kwargs) - - def mpl_fill_betweenx(self, *args, **kwargs): - """Matplotlib fill_betweenx() - tracked, identical to ax.fill_betweenx().""" - return self._mpl_call("fill_betweenx", *args, **kwargs) - - def mpl_stackplot(self, *args, **kwargs): - """Matplotlib stackplot() - tracked, identical to ax.stackplot().""" - return self._mpl_call("stackplot", *args, **kwargs) - - # ========================================================================= - # Contour and heatmap plots - # ========================================================================= - def mpl_contour(self, *args, **kwargs): - """Matplotlib contour() - tracked, identical to ax.contour().""" - return self._mpl_call("contour", *args, **kwargs) - - def mpl_contourf(self, *args, **kwargs): - """Matplotlib contourf() - tracked, identical to ax.contourf().""" - return self._mpl_call("contourf", *args, **kwargs) - - def mpl_imshow(self, *args, **kwargs): - """Matplotlib imshow() - tracked, identical to ax.imshow().""" - return self._mpl_call("imshow", *args, **kwargs) - - def mpl_pcolormesh(self, *args, **kwargs): - """Matplotlib pcolormesh() - tracked, identical to ax.pcolormesh().""" - return self._mpl_call("pcolormesh", *args, **kwargs) - - def mpl_pcolor(self, *args, **kwargs): - """Matplotlib pcolor() - tracked, identical to ax.pcolor().""" - return self._mpl_call("pcolor", *args, **kwargs) - - def mpl_matshow(self, *args, **kwargs): - """Matplotlib matshow() - tracked, identical to ax.matshow().""" - return self._mpl_call("matshow", *args, **kwargs) - - # ========================================================================= - # Vector field plots - # ========================================================================= - def mpl_quiver(self, *args, **kwargs): - """Matplotlib quiver() - tracked, identical to ax.quiver().""" - return self._mpl_call("quiver", *args, **kwargs) - - def mpl_streamplot(self, *args, **kwargs): - """Matplotlib streamplot() - tracked, identical to ax.streamplot().""" - return self._mpl_call("streamplot", *args, **kwargs) - - def mpl_barbs(self, *args, **kwargs): - """Matplotlib barbs() - tracked, identical to ax.barbs().""" - return self._mpl_call("barbs", *args, **kwargs) - - # ========================================================================= - # Pie and polar plots - # ========================================================================= - def mpl_pie(self, *args, **kwargs): - """Matplotlib pie() - tracked, identical to ax.pie().""" - return self._mpl_call("pie", *args, **kwargs) - - # ========================================================================= - # Text and annotations - # ========================================================================= - def mpl_text(self, *args, **kwargs): - """Matplotlib text() - tracked, identical to ax.text().""" - return self._mpl_call("text", *args, **kwargs) - - def mpl_annotate(self, *args, **kwargs): - """Matplotlib annotate() - tracked, identical to ax.annotate().""" - return self._mpl_call("annotate", *args, **kwargs) - - # ========================================================================= - # Lines and spans - # ========================================================================= - def mpl_axhline(self, *args, **kwargs): - """Matplotlib axhline() - tracked, identical to ax.axhline().""" - return self._mpl_call("axhline", *args, **kwargs) - - def mpl_axvline(self, *args, **kwargs): - """Matplotlib axvline() - tracked, identical to ax.axvline().""" - return self._mpl_call("axvline", *args, **kwargs) - - def mpl_axhspan(self, *args, **kwargs): - """Matplotlib axhspan() - tracked, identical to ax.axhspan().""" - return self._mpl_call("axhspan", *args, **kwargs) - - def mpl_axvspan(self, *args, **kwargs): - """Matplotlib axvspan() - tracked, identical to ax.axvspan().""" - return self._mpl_call("axvspan", *args, **kwargs) - - # ========================================================================= - # Patches and shapes - # ========================================================================= - def mpl_add_patch(self, patch, **kwargs): - """Matplotlib add_patch() - tracked, identical to ax.add_patch().""" - return self._mpl_call("add_patch", patch, **kwargs) - - def mpl_add_artist(self, artist, **kwargs): - """Matplotlib add_artist() - tracked, identical to ax.add_artist().""" - return self._mpl_call("add_artist", artist, **kwargs) - - def mpl_add_collection(self, collection, **kwargs): - """Matplotlib add_collection() - tracked, identical to ax.add_collection().""" - return self._mpl_call("add_collection", collection, **kwargs) - - # ========================================================================= - # 3D plotting (if available) - # ========================================================================= - def mpl_plot_surface(self, *args, **kwargs): - """Matplotlib plot_surface() (3D axes) - tracked.""" - return self._mpl_call("plot_surface", *args, **kwargs) - - def mpl_plot_wireframe(self, *args, **kwargs): - """Matplotlib plot_wireframe() (3D axes) - tracked.""" - return self._mpl_call("plot_wireframe", *args, **kwargs) - - def mpl_contour3D(self, *args, **kwargs): - """Matplotlib contour3D() (3D axes) - tracked.""" - return self._mpl_call("contour3D", *args, **kwargs) - - def mpl_scatter3D(self, *args, **kwargs): - """Matplotlib scatter3D() (3D axes) - tracked.""" - return self._mpl_call("scatter3D", *args, **kwargs) - - # ========================================================================= - # Utility method to get raw axes - # ========================================================================= - @property - def mpl_axes(self): - """Direct access to underlying matplotlib axes object.""" - return self._axes_mpl - - def mpl_raw(self, method_name, *args, **kwargs): - """Call any matplotlib method by name without scitex processing. - - Parameters - ---------- - method_name : str - Name of matplotlib axes method to call - *args, **kwargs - Arguments to pass to the method - - Returns - ------- - result - Result from matplotlib method - - Example - ------- - >>> ax.mpl_raw("tricontour", x, y, z, levels=10) - """ - method = getattr(self._axes_mpl, method_name) - return method(*args, **kwargs) - - -# Registry of mpl_xxx methods for programmatic access -MPL_METHODS = [ - # Line plots - "mpl_plot", - "mpl_step", - "mpl_stem", - # Scatter - "mpl_scatter", - # Bar - "mpl_bar", - "mpl_barh", - "mpl_bar3d", - # Histograms - "mpl_hist", - "mpl_hist2d", - "mpl_hexbin", - # Statistical - "mpl_boxplot", - "mpl_violinplot", - "mpl_errorbar", - "mpl_eventplot", - # Fill/area - "mpl_fill", - "mpl_fill_between", - "mpl_fill_betweenx", - "mpl_stackplot", - # Contour/heatmap - "mpl_contour", - "mpl_contourf", - "mpl_imshow", - "mpl_pcolormesh", - "mpl_pcolor", - "mpl_matshow", - # Vector fields - "mpl_quiver", - "mpl_streamplot", - "mpl_barbs", - # Pie - "mpl_pie", - # Text/annotations - "mpl_text", - "mpl_annotate", - # Lines/spans - "mpl_axhline", - "mpl_axvline", - "mpl_axhspan", - "mpl_axvspan", - # Patches - "mpl_add_patch", - "mpl_add_artist", - "mpl_add_collection", - # 3D - "mpl_plot_surface", - "mpl_plot_wireframe", - "mpl_contour3D", - "mpl_scatter3D", -] - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/__init__.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/__init__.py deleted file mode 100755 index 634f95729..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: __init__.py - SeabornMixin package - -""" -SeabornMixin - Modular seaborn integration mixin for AxisWrapper. - -This package provides seaborn plotting functionality: -- _base: Helper methods and data preparation -- _wrappers: Individual seaborn plot wrappers (sns_xxx) -""" - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - -from ._base import SeabornBaseMixin, sns_copy_doc -from ._wrappers import SeabornWrappersMixin - - -class SeabornMixin(SeabornBaseMixin, SeabornWrappersMixin): - """Mixin class for seaborn plotting integration. - - Combines: - - SeabornBaseMixin: Helper methods for tracking and data preparation - - SeabornWrappersMixin: Individual sns_ prefixed seaborn wrappers - """ - - pass - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_base.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_base.py deleted file mode 100755 index b8bd5299e..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_base.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _base.py - Base seaborn functionality - -"""Base seaborn mixin with helper methods for tracking and data preparation.""" - -import os -from functools import wraps - -import numpy as np -import pandas as pd -import seaborn as sns - -import scitex - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -def sns_copy_doc(func): - """Decorator to copy docstring from seaborn function.""" - - @wraps(func) - def wrapper(self, *args, **kwargs): - return func(self, *args, **kwargs) - - sns_method_name = func.__name__.split("sns_")[-1] - wrapper.__doc__ = getattr(sns, sns_method_name).__doc__ - return wrapper - - -class SeabornBaseMixin: - """Base mixin for seaborn integration with tracking support.""" - - def _sns_base( - self, method_name, *args, track=True, track_obj=None, id=None, **kwargs - ): - """Execute seaborn plot method with tracking support.""" - sns_method_name = method_name.split("sns_")[-1] - - with self._no_tracking(): - sns_plot_fn = getattr(sns, sns_method_name) - - if kwargs.get("hue_colors"): - kwargs = scitex.gen.alternate_kwarg( - kwargs, primary_key="palette", alternate_key="hue_colors" - ) - - import warnings - - from scitex import logging - - mpl_logger = logging.getLogger("matplotlib") - original_level = mpl_logger.level - mpl_logger.setLevel(logging.WARNING) - - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - message=".*categorical units.*parsable as floats or dates.*", - category=UserWarning, - ) - warnings.filterwarnings( - "ignore", - message=".*Using categorical units.*", - module="matplotlib.*", - ) - warnings.simplefilter("ignore", UserWarning) - - self._axis_mpl = sns_plot_fn(ax=self._axis_mpl, *args, **kwargs) - finally: - mpl_logger.setLevel(original_level) - - # Post-processing for histplot with kde=True - if sns_method_name == "histplot" and kwargs.get("kde", False): - from scitex.plt.utils import mm_to_pt - - kde_lw = mm_to_pt(0.2) - for line in self._axis_mpl.get_lines(): - line.set_linewidth(kde_lw) - line.set_color("black") - line.set_linestyle("--") - - # Post-processing for histplot alpha - if sns_method_name == "histplot" and "alpha" not in kwargs: - for patch in self._axis_mpl.patches: - patch.set_alpha(1.0) - - track_obj = track_obj if track_obj is not None else args - tracked_dict = { - "data": track_obj, - "args": args, - } - self._track(track, id, method_name, tracked_dict, kwargs) - - def _sns_base_xyhue(self, method_name, *args, track=True, id=None, **kwargs): - """Execute seaborn plot with x/y/hue data preparation.""" - df = kwargs.get("data") - x, y, hue = kwargs.get("x"), kwargs.get("y"), kwargs.get("hue") - - track_obj = self._sns_prepare_xyhue(df, x, y, hue) if df is not None else None - self._sns_base( - method_name, - *args, - track=track, - track_obj=track_obj, - id=id, - **kwargs, - ) - - def _sns_prepare_xyhue(self, data=None, x=None, y=None, hue=None, **kwargs): - """Prepare data for tracking based on x/y/hue configuration.""" - data = data.reset_index() - - if hue is not None: - if x is None and y is None: - return data - elif x is None: - agg_dict = {} - for hh in data[hue].unique(): - agg_dict[hh] = data.loc[data[hue] == hh, y] - df = scitex.pd.force_df(agg_dict) - return df - elif y is None: - df = pd.concat( - [data.loc[data[hue] == hh, x] for hh in data[hue].unique()], - axis=1, - ) - return df - else: - pivoted_data = data.pivot_table( - values=y, - index=data.index, - columns=[x, hue], - aggfunc="first", - ) - pivoted_data.columns = [ - f"{col[0]}-{col[1]}" for col in pivoted_data.columns - ] - return pivoted_data - else: - if x is None and y is None: - return data - elif x is None: - return data[[y]] - elif y is None: - return data[[x]] - else: - return data.pivot_table( - values=y, index=data.index, columns=x, aggfunc="first" - ) - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py deleted file mode 100755 index 36add3820..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py +++ /dev/null @@ -1,595 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-12-13 (ywatanabe)" -# File: _wrappers.py - Seaborn plot wrappers - -"""Seaborn plot wrappers with SciTeX integration.""" - -import os -from typing import Optional, Union - -import numpy as np -import pandas as pd -import seaborn as sns - -from scitex.types import ArrayLike - -from ._base import sns_copy_doc - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - - -class SeabornWrappersMixin: - """Mixin providing sns_ prefixed seaborn wrappers. - - All methods use the seaborn DataFrame-centric interface: - - data: DataFrame containing the data - - x, y: Column names for axes - - hue: Column name for color grouping - - These methods integrate with SciTeX tracking and styling. - """ - - def _get_ax_module(self): - """Lazy import ax module to avoid circular imports.""" - from .....plt import ax as ax_module - - return ax_module - - @sns_copy_doc - def sns_barplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a bar plot showing point estimates and error bars. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis categories. - y : str, optional - Column name for y-axis values. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.barplot`. - - See Also - -------- - stx_bar : Array-based bar plot. - """ - self._sns_base_xyhue( - "sns_barplot", data=data, x=x, y=y, track=track, id=id, **kwargs - ) - - @sns_copy_doc - def sns_boxplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - strip: bool = False, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a box plot showing distributions with quartiles. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis grouping. - y : str, optional - Column name for y-axis values. - strip : bool, default False - If True, overlay a stripplot showing individual points. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.boxplot`. - - See Also - -------- - stx_box : Array-based boxplot. - sns_violinplot : Violin plot alternative. - """ - self._sns_base_xyhue( - "sns_boxplot", data=data, x=x, y=y, track=track, id=id, **kwargs - ) - - # Post-processing: Style boxplot elements (0.2mm black lines) - from scitex.plt.utils import mm_to_pt - - lw_pt = mm_to_pt(0.2) - for line in self._axis_mpl.get_lines(): - line.set_linewidth(lw_pt) - line.set_color("black") - for patch in self._axis_mpl.patches: - patch.set_linewidth(lw_pt) - patch.set_edgecolor("black") - - if strip: - strip_kwargs = kwargs.copy() - strip_kwargs.pop("notch", None) - strip_kwargs.pop("whis", None) - self.sns_stripplot( - data=data, - x=x, - y=y, - track=False, - id=f"{id}_strip", - **strip_kwargs, - ) - - @sns_copy_doc - def sns_heatmap( - self, - data: Union[pd.DataFrame, ArrayLike], - *, - xyz: bool = False, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a heatmap from rectangular data. - - Parameters - ---------- - data : DataFrame or array-like - 2D dataset for the heatmap. - xyz : bool, default False - If True, convert data to XYZ format before plotting. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.heatmap`. - - See Also - -------- - stx_heatmap : Array-based annotated heatmap. - stx_image : Simple image display. - """ - import scitex - - method_name = "sns_heatmap" - df = data - if xyz: - df = scitex.pd.to_xyz(df) - self._sns_base(method_name, df, track=track, track_obj=df, id=id, **kwargs) - - @sns_copy_doc - def sns_histplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - bins: int = 10, - align_bins: bool = True, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a histogram with optional kernel density estimate. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis values. - y : str, optional - Column name for y-axis values. - bins : int, default 10 - Number of histogram bins. - align_bins : bool, default True - Align bins across multiple histograms on same axes. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.histplot`. - Common options: kde, stat, element, hue. - - See Also - -------- - hist : Array-based histogram. - stx_kde : Kernel density estimate. - """ - method_name = "sns_histplot" - - plot_data = None - if data is not None and x is not None: - plot_data = ( - data[x].values - if hasattr(data, "columns") and x in data.columns - else None - ) - - axis_id = str(hash(self._axis_mpl)) - hist_id = id if id is not None else str(self.id) - range_value = kwargs.get("binrange", None) - - if align_bins and plot_data is not None: - from .....plt.utils import histogram_bin_manager - - bins_val, range_val = histogram_bin_manager.register_histogram( - axis_id, hist_id, plot_data, bins, range_value - ) - kwargs["bins"] = bins_val - if range_value is not None: - kwargs["binrange"] = range_val - - with self._no_tracking(): - sns_plot = sns.histplot(data=data, x=x, y=y, ax=self._axis_mpl, **kwargs) - - hist_result = None - if hasattr(sns_plot, "patches") and sns_plot.patches: - patches = sns_plot.patches - if patches: - counts = np.array([p.get_height() for p in patches]) - bin_edges = [] - for p in patches: - bin_edges.append(p.get_x()) - if patches: - bin_edges.append(patches[-1].get_x() + patches[-1].get_width()) - hist_result = (counts, np.array(bin_edges)) - - track_obj = self._sns_prepare_xyhue(data, x, y, kwargs.get("hue")) - tracked_dict = { - "data": track_obj, - "args": (data, x, y), - "hist_result": hist_result, - } - self._track(track, id, method_name, tracked_dict, kwargs) - - return sns_plot - - @sns_copy_doc - def sns_kdeplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - xlim: Optional[tuple] = None, - ylim: Optional[tuple] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a kernel density estimate plot. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis values. - y : str, optional - Column name for y-axis values. - xlim : tuple, optional - Limits for x-axis KDE range. - ylim : tuple, optional - Limits for y-axis KDE range. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to the KDE function. - - See Also - -------- - stx_kde : Array-based KDE plot. - sns_histplot : Histogram with optional KDE. - """ - hue_col = kwargs.pop("hue", None) - - if hue_col: - hues = data[hue_col] - if x is not None: - lim = xlim - for hue in np.unique(hues): - _data = data.loc[hues == hue, x] - self.stx_kde(_data, xlim=lim, label=hue, id=hue, **kwargs) - if y is not None: - lim = ylim - for hue in np.unique(hues): - _data = data.loc[hues == hue, y] - self.stx_kde(_data, xlim=lim, label=hue, id=hue, **kwargs) - else: - if x is not None: - _data, lim = data[x], xlim - if y is not None: - _data, lim = data[y], ylim - self.stx_kde(_data, xlim=lim, **kwargs) - - @sns_copy_doc - def sns_pairplot( - self, - *args, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a grid of pairwise relationships in a dataset. - - Parameters - ---------- - *args - Positional arguments passed to `seaborn.pairplot`. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.pairplot`. - """ - self._sns_base("sns_pairplot", *args, track=track, id=id, **kwargs) - - @sns_copy_doc - def sns_scatterplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a scatter plot with semantic mappings. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis values. - y : str, optional - Column name for y-axis values. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.scatterplot`. - Common options: hue, size, style. - - See Also - -------- - stx_scatter : Array-based scatter plot. - """ - self._sns_base_xyhue( - "sns_scatterplot", - data=data, - x=x, - y=y, - track=track, - id=id, - **kwargs, - ) - - @sns_copy_doc - def sns_lineplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a line plot with semantic mappings. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis values. - y : str, optional - Column name for y-axis values. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.lineplot`. - Common options: hue, size, style, estimator. - - See Also - -------- - stx_line : Array-based line plot. - stx_mean_std : Line with uncertainty shading. - """ - self._sns_base_xyhue( - "sns_lineplot", - data=data, - x=x, - y=y, - track=track, - id=id, - **kwargs, - ) - - @sns_copy_doc - def sns_swarmplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a categorical scatter plot with non-overlapping points. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis grouping. - y : str, optional - Column name for y-axis values. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.swarmplot`. - - See Also - -------- - sns_stripplot : Jittered categorical scatter. - sns_boxplot : Box plot for distributions. - """ - self._sns_base_xyhue( - "sns_swarmplot", data=data, x=x, y=y, track=track, id=id, **kwargs - ) - - @sns_copy_doc - def sns_stripplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a categorical scatter plot with jittered points. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis grouping. - y : str, optional - Column name for y-axis values. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.stripplot`. - - See Also - -------- - sns_swarmplot : Non-overlapping categorical scatter. - sns_boxplot : Often combined with stripplot. - """ - self._sns_base_xyhue( - "sns_stripplot", data=data, x=x, y=y, track=track, id=id, **kwargs - ) - - @sns_copy_doc - def sns_violinplot( - self, - data: Optional[pd.DataFrame] = None, - *, - x: Optional[str] = None, - y: Optional[str] = None, - half: bool = False, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a violin plot combining box plot with kernel density. - - Parameters - ---------- - data : DataFrame, optional - Input data structure. - x : str, optional - Column name for x-axis grouping. - y : str, optional - Column name for y-axis values. - half : bool, default False - If True, draw half-violins (useful for paired comparisons). - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.violinplot`. - - See Also - -------- - stx_violin : Array-based violin plot. - sns_boxplot : Box plot alternative. - """ - if half: - with self._no_tracking(): - self._axis_mpl = self._get_ax_module().plot_half_violin( - self._axis_mpl, data=data, x=x, y=y, **kwargs - ) - else: - self._sns_base_xyhue( - "sns_violinplot", - data=data, - x=x, - y=y, - track=track, - id=id, - **kwargs, - ) - - track_obj = self._sns_prepare_xyhue(data, x, y, kwargs.get("hue")) - self._track(track, id, "sns_violinplot", track_obj, kwargs) - - return self._axis_mpl - - @sns_copy_doc - def sns_jointplot( - self, - *args, - track: bool = True, - id: Optional[str] = None, - **kwargs, - ): - """Create a figure with joint and marginal distributions. - - Parameters - ---------- - *args - Positional arguments passed to `seaborn.jointplot`. - track : bool, default True - Enable data tracking for reproducibility. - id : str, optional - Unique identifier for this plot element. - **kwargs - Additional arguments passed to `seaborn.jointplot`. - - See Also - -------- - stx_scatter_hist : Array-based scatter with marginal histograms. - """ - self._sns_base("sns_jointplot", *args, track=track, id=id, **kwargs) - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py deleted file mode 100755 index f8bedac6b..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-04-30 18:40:59 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -""" -Functionality: - * Handles tracking and history management for matplotlib plot operations -Input: - * Plot method calls, their arguments, and tracking configuration -Output: - * Tracked plotting history and DataFrame export for analysis -Prerequisites: - * pandas, matplotlib -""" - -from contextlib import contextmanager - -import pandas as pd - -from .._export_as_csv import export_as_csv as _export_as_csv - - -class TrackingMixin: - """Mixin class for tracking matplotlib plotting operations. - - Example - ------- - >>> fig, ax = plt.subplots() - >>> ax.track = True - >>> ax.id = 0 - >>> ax._ax_history = OrderedDict() - >>> ax.plot([1, 2, 3], [4, 5, 6], id="plot1") - >>> print(ax.history) - {'plot1': ('plot1', 'plot', {'plot_df': DataFrame, ...}, {})} - """ - - def _track(self, track, id, method_name, tracked_dict, kwargs=None): - """Track plotting operation with auto-generated IDs. - - Args: - track: Whether to track this operation - id: Identifier for the plot (can be None) - method_name: Name of the plotting method - tracked_dict: Dictionary of tracked data - kwargs: Original keyword arguments - """ - # Extract id from kwargs and remove it before passing to matplotlib - if kwargs is not None and hasattr(kwargs, "get") and "id" in kwargs: - id = kwargs.pop("id") - - # Default kwargs to empty dict if None - if kwargs is None: - kwargs = {} - - if track is None: - track = self.track - - if track: - # Get axes position from _scitex_metadata if available - ax_row, ax_col = 0, 0 - if hasattr(self, "_axis_mpl") and hasattr( - self._axis_mpl, "_scitex_metadata" - ): - meta = self._axis_mpl._scitex_metadata - if "position_in_grid" in meta: - ax_row, ax_col = meta["position_in_grid"] - - # If no ID was provided, generate one using method_name + counter - if id is None: - # Initialize method counters if not exist - if not hasattr(self, "_method_counters"): - self._method_counters = {} - - # Get current counter value for this method and increment it - counter = self._method_counters.get(method_name, 0) - self._method_counters[method_name] = counter + 1 - - # Format ID with axes position: ax_RC_method_counter - # e.g., ax_00_plot_0, ax_01_bar_1, ax_10_scatter_2 - id = f"ax_{ax_row}{ax_col}_{method_name}_{counter}" - else: - # User-provided ID - prepend axes position - # e.g., ax_00_sine, ax_01_my-data - id = f"ax_{ax_row}{ax_col}_{id}" - - # For backward compatibility - self.id += 1 - - # Store the tracking record - self._ax_history[id] = (id, method_name, tracked_dict, kwargs) - - @contextmanager - def _no_tracking(self): - """Context manager to temporarily disable tracking.""" - original_track = self.track - self.track = False - try: - yield - finally: - self.track = original_track - - @property - def history(self): - return {k: self._ax_history[k] for k in self._ax_history} - - @property - def flat(self): - if isinstance(self._axis_mpl, list): - return self._axis_mpl - else: - return [self._axis_mpl] - - def reset_history(self): - self._ax_history = {} - - def export_as_csv(self): - """ - Export tracked plotting data to a DataFrame. - """ - df = _export_as_csv(self.history) - - return df if df is not None else pd.DataFrame() - - def export_as_csv_for_sigmaplot(self, include_visual_params=True): - """ - Export tracked plotting data to a DataFrame in SigmaPlot format. - - Parameters - ---------- - include_visual_params : bool, optional - Whether to include visual parameters (xlabel, ylabel, scales, etc.) - at the top of the CSV. Default is True. - - Returns - ------- - pandas.DataFrame - DataFrame containing the plotted data formatted for SigmaPlot. - - Examples - -------- - >>> fig, ax = scitex.plt.subplots() - >>> ax.plot([1, 2, 3], [4, 5, 6]) - >>> ax.scatter([1, 2, 3], [7, 8, 9]) - >>> df = ax.export_as_csv_for_sigmaplot() - >>> df.to_csv('for_sigmaplot.csv', index=False) - """ - df = _export_as_csv(self.history) - - return df if df is not None else pd.DataFrame() - - # def _track( - # self, - # track: Optional[bool], - # plot_id: Optional[str], - # method_name: str, - # tracked_dict: Any, - # kwargs: Dict[str, Any] - # ) -> None: - # """Tracks plotting operation if tracking is enabled.""" - # if track is None: - # track = self.track - # if track: - # plot_id = plot_id if plot_id is not None else self.id - # self.id += 1 - # self._ax_history[plot_id] = (plot_id, method_name, tracked_dict, kwargs) - - # @contextmanager - # def _no_tracking(self) -> None: - # """Temporarily disables tracking within a context.""" - # original_track = self.track - # self.track = False - # try: - # yield - # finally: - # self.track = original_track - - # @property - # def history(self) -> Dict[str, Tuple]: - # """Returns the plotting history.""" - # return dict(self._ax_history) - - # def reset_history(self) -> None: - # """Clears the plotting history.""" - # self._ax_history = OrderedDict() - - # def export_as_csv(self) -> pd.DataFrame: - # """Converts plotting history to a SigmaPlot-compatible DataFrame.""" - # df = _export_as_csv(self.history) - # return df if df is not None else pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py deleted file mode 100755 index e1c781dfa..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py +++ /dev/null @@ -1,449 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-08-01 10:35:00 (ywatanabe)" -# File: /home/ywatanabe/proj/SciTeX-Code/src/scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py -# ---------------------------------------- - -""" -Unit-Aware Plotting Mixin -========================= - -This mixin adds unit handling capabilities to the AxisWrapper class, -ensuring scientific validity in plots. - -Features: -- Automatic unit tracking for axes -- Unit validation for data compatibility -- Automatic unit conversion -- Unit-aware axis labels -""" - -import re -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np - -import scitex.logging as logging -from scitex.logging import SciTeXError, UnitWarning -from scitex.logging import warn as _warn -from scitex.units import Q, Unit, Units - -# Valid dimensionless/special unit markers -_VALID_DIMENSIONLESS = { - "[-]", - "[a.u.]", - "[arb. units]", - "[dimensionless]", - "[1]", - "[A.U.]", -} - - -def _convert_to_negative_exponent(unit: str) -> str: - """Convert unit with / to negative exponent format. - - Examples: - m/s -> m·s⁻¹ - kg/m^2 -> kg·m⁻² - W/m^2/K -> W·m⁻²·K⁻¹ - """ - superscript = str.maketrans("0123456789-", "⁰¹²³⁴⁵⁶⁷⁸⁹⁻") - - parts = unit.split("/") - if len(parts) < 2: - return unit - - result = parts[0] - for part in parts[1:]: - exp_match = re.match(r"([a-zA-Z]+)\^?(\d+)?", part) - if exp_match: - base = exp_match.group(1) - exp = exp_match.group(2) or "1" - neg_exp = f"-{exp}".translate(superscript) - result += f"·{base}{neg_exp}" - else: - result += f"·{part}⁻¹" - - return result - - -def validate_axis_label(label: str, axis_name: str = "axis") -> str: - """Validate and warn about axis label units (educational for scientific standards). - - Checks for: - - Missing units - - Non-standard format (prefer [] over ()) - - Suggests ^-1 format over / - - Parameters - ---------- - label : str - Axis label to validate - axis_name : str - Name for warning messages (e.g., "X axis", "Y axis") - - Returns - ------- - str - Original label (warnings are educational, not auto-correcting) - """ - if not label: - return label - - # Check for units in brackets [] or parentheses () - has_square_brackets = bool(re.search(r"\[.*?\]", label)) - has_parentheses = bool(re.search(r"\(.*?\)", label)) - - unit_match_square = re.search(r"\[(.*?)\]", label) - unit_match_paren = re.search(r"\((.*?)\)", label) - - if not has_square_brackets and not has_parentheses: - _warn( - f"{axis_name} label '{label}' has no units. " - f"Consider: '{label} [unit]' or '{label} [-]' for dimensionless", - UnitWarning, - stacklevel=3, - ) - return label - - if has_parentheses and not has_square_brackets: - unit = unit_match_paren.group(1) if unit_match_paren else "" - suggested = re.sub(r"\((.*?)\)", f"[{unit}]", label) - _warn( - f"{axis_name} label '{label}' uses parentheses. " - f"SI convention prefers: '{suggested}'", - UnitWarning, - stacklevel=3, - ) - - unit_content = None - if unit_match_square: - unit_content = unit_match_square.group(1) - elif unit_match_paren: - unit_content = unit_match_paren.group(1) - - if unit_content and "/" in unit_content: - suggested_unit = _convert_to_negative_exponent(unit_content) - if suggested_unit != unit_content: - suggested_label = label.replace(f"[{unit_content}]", f"[{suggested_unit}]") - _warn( - f"{axis_name} uses '/' in units. Consider: '{suggested_label}'", - UnitWarning, - stacklevel=3, - ) - - return label - - -class UnitMismatchError(SciTeXError): - """Raised when units are incompatible for an operation.""" - - pass - - -class UnitAwareMixin: - """Mixin that adds unit awareness to plotting operations.""" - - def __init__(self): - """Initialize unit tracking.""" - self._x_unit: Optional[Unit] = None - self._y_unit: Optional[Unit] = None - self._z_unit: Optional[Unit] = None - self._unit_validation_enabled: bool = True - - def set_unit_validation(self, enabled: bool) -> None: - """Enable or disable unit validation.""" - self._unit_validation_enabled = enabled - - def set_x_unit(self, unit: Union[str, Unit]) -> None: - """Set the unit for the x-axis.""" - if isinstance(unit, str): - unit_obj = getattr(Units, unit, None) - if unit_obj is None: - raise ValueError(f"Unknown unit: {unit}") - unit = unit_obj - self._x_unit = unit - self._update_xlabel_with_unit() - - def set_y_unit(self, unit: Union[str, Unit]) -> None: - """Set the unit for the y-axis.""" - if isinstance(unit, str): - unit_obj = getattr(Units, unit, None) - if unit_obj is None: - raise ValueError(f"Unknown unit: {unit}") - unit = unit_obj - self._y_unit = unit - self._update_ylabel_with_unit() - - def set_z_unit(self, unit: Union[str, Unit]) -> None: - """Set the unit for the z-axis (for 3D plots).""" - if isinstance(unit, str): - unit_obj = getattr(Units, unit, None) - if unit_obj is None: - raise ValueError(f"Unknown unit: {unit}") - unit = unit_obj - self._z_unit = unit - self._update_zlabel_with_unit() - - def get_x_unit(self) -> Optional[Unit]: - """Get the current x-axis unit.""" - return self._x_unit - - def get_y_unit(self) -> Optional[Unit]: - """Get the current y-axis unit.""" - return self._y_unit - - def get_z_unit(self) -> Optional[Unit]: - """Get the current z-axis unit.""" - return self._z_unit - - def _update_xlabel_with_unit(self) -> None: - """Update x-axis label to include unit.""" - if self._x_unit and hasattr(self, "_axes_mpl"): - current_label = self._axes_mpl.get_xlabel() - # Remove existing unit if present - if "[" in current_label and "]" in current_label: - current_label = current_label.split("[")[0].strip() - if current_label: - self._axes_mpl.set_xlabel(f"{current_label} [{self._x_unit.symbol}]") - - def _update_ylabel_with_unit(self) -> None: - """Update y-axis label to include unit.""" - if self._y_unit and hasattr(self, "_axes_mpl"): - current_label = self._axes_mpl.get_ylabel() - # Remove existing unit if present - if "[" in current_label and "]" in current_label: - current_label = current_label.split("[")[0].strip() - if current_label: - self._axes_mpl.set_ylabel(f"{current_label} [{self._y_unit.symbol}]") - - def _update_zlabel_with_unit(self) -> None: - """Update z-axis label to include unit (for 3D plots).""" - if ( - self._z_unit - and hasattr(self, "_axes_mpl") - and hasattr(self._axes_mpl, "set_zlabel") - ): - current_label = self._axes_mpl.get_zlabel() - # Remove existing unit if present - if "[" in current_label and "]" in current_label: - current_label = current_label.split("[")[0].strip() - if current_label: - self._axes_mpl.set_zlabel(f"{current_label} [{self._z_unit.symbol}]") - - def plot_with_units(self, x, y, x_unit=None, y_unit=None, **kwargs): - """Plot with automatic unit handling. - - Parameters - ---------- - x : array-like or Quantity - X-axis data - y : array-like or Quantity - Y-axis data - x_unit : str or Unit, optional - Unit for x-axis (overrides detected unit) - y_unit : str or Unit, optional - Unit for y-axis (overrides detected unit) - **kwargs : dict - Additional plotting parameters - - Returns - ------- - lines : list of Line2D - The plotted lines - """ - # Extract values and units from Quantity objects - x_val, x_detected_unit = self._extract_value_and_unit(x) - y_val, y_detected_unit = self._extract_value_and_unit(y) - - # Use provided units or detected units - if x_unit: - self.set_x_unit(x_unit) - elif x_detected_unit and not self._x_unit: - self.set_x_unit(x_detected_unit) - - if y_unit: - self.set_y_unit(y_unit) - elif y_detected_unit and not self._y_unit: - self.set_y_unit(y_detected_unit) - - # Validate units if enabled - if self._unit_validation_enabled: - self._validate_unit_compatibility(x_detected_unit, self._x_unit, "x") - self._validate_unit_compatibility(y_detected_unit, self._y_unit, "y") - - # Plot using the standard method - return self.plot(x_val, y_val, **kwargs) - - def _extract_value_and_unit(self, data) -> Tuple[np.ndarray, Optional[Unit]]: - """Extract numerical value and unit from data.""" - if hasattr(data, "value") and hasattr(data, "unit"): - # It's a Quantity object - return data.value, data.unit - else: - # Regular array - return np.asarray(data), None - - def _validate_unit_compatibility( - self, data_unit: Optional[Unit], axis_unit: Optional[Unit], axis_name: str - ) -> None: - """Validate that data unit is compatible with axis unit.""" - if not self._unit_validation_enabled: - return - - if data_unit and axis_unit: - # Check if units have same dimensions - if data_unit.dimensions != axis_unit.dimensions: - raise UnitMismatchError( - f"Unit mismatch on {axis_name}-axis: " - f"data has unit {data_unit.symbol} {data_unit.dimensions}, " - f"but axis expects {axis_unit.symbol} {axis_unit.dimensions}" - ) - - def convert_x_units( - self, new_unit: Union[str, Unit], update_data: bool = True - ) -> float: - """Convert x-axis to new units. - - Parameters - ---------- - new_unit : str or Unit - Target unit - update_data : bool - Whether to update plotted data - - Returns - ------- - float - Conversion factor applied - """ - if isinstance(new_unit, str): - new_unit = getattr(Units, new_unit) - - if not self._x_unit: - raise ValueError("No x-axis unit set") - - # Calculate conversion factor - factor = self._x_unit.scale / new_unit.scale - - if update_data and hasattr(self, "_axes_mpl"): - # Update all line data - for line in self._axes_mpl.lines: - xdata = line.get_xdata() - line.set_xdata(xdata * factor) - - # Update x-axis limits - xlim = self._axes_mpl.get_xlim() - self._axes_mpl.set_xlim([x * factor for x in xlim]) - - # Update unit - self.set_x_unit(new_unit) - - return factor - - def convert_y_units( - self, new_unit: Union[str, Unit], update_data: bool = True - ) -> float: - """Convert y-axis to new units. - - Parameters - ---------- - new_unit : str or Unit - Target unit - update_data : bool - Whether to update plotted data - - Returns - ------- - float - Conversion factor applied - """ - if isinstance(new_unit, str): - new_unit = getattr(Units, new_unit) - - if not self._y_unit: - raise ValueError("No y-axis unit set") - - # Calculate conversion factor - factor = self._y_unit.scale / new_unit.scale - - if update_data and hasattr(self, "_axes_mpl"): - # Update all line data - for line in self._axes_mpl.lines: - ydata = line.get_ydata() - line.set_ydata(ydata * factor) - - # Update y-axis limits - ylim = self._axes_mpl.get_ylim() - self._axes_mpl.set_ylim([y * factor for y in ylim]) - - # Update unit - self.set_y_unit(new_unit) - - return factor - - def set_xlabel(self, label: str, unit: Optional[Union[str, Unit]] = None) -> None: - """Set x-axis label with optional unit. - - Parameters - ---------- - label : str - Axis label text - unit : str or Unit, optional - Unit to display - """ - if unit: - self.set_x_unit(unit) - - if self._x_unit: - label = f"{label} [{self._x_unit.symbol}]" - - # Validate units (educational warnings for scientific standards) - validate_axis_label(label, "X axis") - - self._axes_mpl.set_xlabel(label) - - def set_ylabel(self, label: str, unit: Optional[Union[str, Unit]] = None) -> None: - """Set y-axis label with optional unit. - - Parameters - ---------- - label : str - Axis label text - unit : str or Unit, optional - Unit to display - """ - if unit: - self.set_y_unit(unit) - - if self._y_unit: - label = f"{label} [{self._y_unit.symbol}]" - - # Validate units (educational warnings for scientific standards) - validate_axis_label(label, "Y axis") - - self._axes_mpl.set_ylabel(label) - - def set_zlabel(self, label: str, unit: Optional[Union[str, Unit]] = None) -> None: - """Set z-axis label with optional unit (for 3D plots). - - Parameters - ---------- - label : str - Axis label text - unit : str or Unit, optional - Unit to display - """ - if not hasattr(self._axes_mpl, "set_zlabel"): - raise ValueError("Z-axis labels only available for 3D plots") - - if unit: - self.set_z_unit(unit) - - if self._z_unit: - label = f"{label} [{self._z_unit.symbol}]" - - # Validate units (educational warnings for scientific standards) - validate_axis_label(label, "Z axis") - - self._axes_mpl.set_zlabel(label) diff --git a/src/scitex/plt/_subplots/_AxisWrapperMixins/__init__.py b/src/scitex/plt/_subplots/_AxisWrapperMixins/__init__.py deleted file mode 100755 index a47dde3df..000000000 --- a/src/scitex/plt/_subplots/_AxisWrapperMixins/__init__.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/__init__.py - -""" -AxisWrapper Mixins - Modular plotting API for SciTeX. - -API Layers -========== - -SciTeX provides three distinct API layers for plotting, each with different -purposes and trade-offs: - -stx_* (SciTeX Canonical) ------------------------- -- Input: ArrayLike / List / ndarray -- Output: (Axes, tracked_df, meta) -- Purpose: publication / reproducibility -- Features: - * Full tracking and metadata support - * Output connects to .plot / .figure formats - * Automatic styling according to SciTeX style - * Primary API - recommended for final figures - - Examples: - ax.stx_bar(x, height) - ax.stx_scatter(x, y, label="Data") - ax.stx_kde(data) - -mpl_* (Matplotlib Compatibility) --------------------------------- -- Input: Same as matplotlib -- Output: matplotlib artists -- Purpose: compatibility / low-level control / escape hatch -- Features: - * No tracking, no scitex processing - * Direct matplotlib API access - * Use for unsupported operations or migration - - Examples: - ax.mpl_plot(x, y) - ax.mpl_scatter(x, y) - ax.mpl_raw("some_method", *args) # Call any matplotlib method - -sns_* (Seaborn / DataFrame-Centric) ------------------------------------ -- Input: DataFrame + column names (data=, x=, y=, hue=) -- Output: Axes (+ summarized df) -- Purpose: exploratory / grouped statistics -- Features: - * DataFrame-centric interface - * Statistical summaries and grouping - * Familiar seaborn UX - - Examples: - ax.sns_boxplot(data=df, x="group", y="value") - ax.sns_histplot(data=df, x="measurement", hue="category") - -Choosing an API Layer -===================== - -Use stx_*: - - For publication-ready figures - - When you need reproducibility and tracking - - As your default choice - -Use mpl_*: - - For low-level matplotlib control - - When migrating existing matplotlib code - - For matplotlib features not yet wrapped - -Use sns_*: - - For exploratory data analysis - - When input is a DataFrame - - For statistical visualization with grouping -""" - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) - -from ._AdjustmentMixin import AdjustmentMixin -from ._MatplotlibPlotMixin import MatplotlibPlotMixin -from ._RawMatplotlibMixin import MPL_METHODS, RawMatplotlibMixin -from ._SeabornMixin import SeabornMixin -from ._TrackingMixin import TrackingMixin -from ._UnitAwareMixin import UnitAwareMixin - -# EOF diff --git a/src/scitex/plt/_subplots/_FigWrapper.py b/src/scitex/plt/_subplots/_FigWrapper.py deleted file mode 100755 index d8019c733..000000000 --- a/src/scitex/plt/_subplots/_FigWrapper.py +++ /dev/null @@ -1,475 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-05-19 02:53:28 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_FigWrapper.py.new -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/_subplots/_FigWrapper.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import warnings -from functools import wraps - -import pandas as pd - -from scitex import logging - -logger = logging.getLogger(__name__) - - -class FigWrapper: - def __init__(self, fig_mpl): - self._fig_mpl = fig_mpl - self._axes = [] # Keep track of axes for synchronization - self._last_saved_info = None - self._not_saved_yet_flag = True - self._called_from_mng_io_save = False - - @property - def figure( - self, - ): - return self._fig_mpl - - def __getattr__(self, attr): - # print(f"Attribute of FigWrapper: {attr}") - attr_mpl = getattr(self._fig_mpl, attr) - - if callable(attr_mpl): - - @wraps(attr_mpl) - def wrapper(*args, track=None, id=None, **kwargs): - # Suppress constrained_layout warnings for certain operations - import warnings - - with warnings.catch_warnings(): - if attr in ["subplots_adjust", "tight_layout"]: - warnings.filterwarnings( - "ignore", - message=".*constrained_layout.*", - category=UserWarning, - ) - warnings.filterwarnings( - "ignore", - message=".*layout engine.*incompatible.*", - category=UserWarning, - ) - results = attr_mpl(*args, **kwargs) - # self._track(track, id, attr, args, kwargs) - return results - - return wrapper - - else: - return attr_mpl - - def __dir__(self): - # Combine attributes from both self and the wrapped matplotlib figure - attrs = set(dir(self.__class__)) - attrs.update(object.__dir__(self)) - attrs.update(dir(self._fig_mpl)) - return sorted(attrs) - - def savefig(self, fname, *args, embed_metadata=True, metadata=None, **kwargs): - """ - Save figure with automatic metadata embedding. - - Parameters - ---------- - fname : str - Output file path - embed_metadata : bool, optional - Automatically embed dimension/style metadata in PNG/JPEG/TIFF/PDF (default: True) - metadata : dict, optional - Additional custom metadata to merge with auto-collected metadata - *args, **kwargs - Passed to scitex.io.save_image or matplotlib savefig - - Notes - ----- - For PNG/JPEG/TIFF/PDF formats, metadata is automatically embedded including: - - Software versions (scitex, matplotlib) - - Timestamp - - Figure/axes dimensions (mm, inch, px) - - DPI settings - - Styling parameters (if available via _scitex_metadata) - - Mode (display/publication) - - For other formats (SVG, etc.), delegates to matplotlib's savefig. - - When facecolor is specified (and is not 'none'), axes with transparent - patches will temporarily have their alpha set to 1.0 to ensure the - facecolor is visible. - - Examples - -------- - >>> fig, ax = splt.subplots(fig_mm={'width': 35, 'height': 24.5}) - >>> ax.plot(x, y) - >>> fig.savefig('result.png', dpi=300) # Metadata embedded automatically! - - >>> # Add custom metadata - >>> fig.savefig('result.png', dpi=300, metadata={'experiment': 'test_001'}) - - >>> # Disable metadata embedding - >>> fig.savefig('result.png', embed_metadata=False) - - >>> # Override transparent background with white - >>> fig.savefig('result.png', facecolor='white') - """ - # Handle facecolor override for transparent figures - # When facecolor is specified (not 'none'), temporarily make axes and figure opaque - facecolor = kwargs.get("facecolor", None) - patches_backup = [] # List of (patch, original_alpha, original_facecolor) - - if facecolor is not None: - # Check if facecolor indicates a non-transparent background - is_opaque_facecolor = True - if isinstance(facecolor, str): - if facecolor.lower() in ("none", "transparent"): - is_opaque_facecolor = False - - if is_opaque_facecolor: - # Backup and set figure patch to opaque - fig_patch = self._fig_mpl.patch - fig_alpha = fig_patch.get_alpha() - fig_fc = fig_patch.get_facecolor() - patches_backup.append((fig_patch, fig_alpha, fig_fc)) - fig_patch.set_alpha(1.0) - fig_patch.set_facecolor(facecolor) - - # Backup and set axes patches to opaque - for ax_mpl in self._fig_mpl.axes: - ax_patch = ax_mpl.patch - original_alpha = ax_patch.get_alpha() - original_fc = ax_patch.get_facecolor() - patches_backup.append((ax_patch, original_alpha, original_fc)) - ax_patch.set_alpha(1.0) - # Set axes facecolor to match figure facecolor if it was transparent - if original_alpha == 0.0 or original_alpha is None: - ax_patch.set_facecolor(facecolor) - - # Ensure transparent=False so matplotlib respects the facecolor - if "transparent" not in kwargs: - kwargs["transparent"] = False - # Wrap save logic in try/finally to restore axes alpha - try: - # Check if this is a format that can have metadata (PNG/JPEG/TIFF/PDF) - # Handle both string paths and file-like objects (e.g., BytesIO) - if isinstance(fname, str): - is_image_format = fname.lower().endswith( - (".png", ".jpg", ".jpeg", ".tiff", ".tif", ".pdf") - ) - else: - # For file-like objects, check the 'format' kwarg if provided - # Otherwise default to False (no metadata embedding for BytesIO etc.) - fmt = kwargs.get("format", "").lower() if kwargs.get("format") else "" - is_image_format = fmt in ("png", "jpg", "jpeg", "tiff", "tif", "pdf") - - if is_image_format and embed_metadata: - # Collect automatic metadata - auto_metadata = None - - # Get first axes if available - # Keep the scitex AxisWrapper (for history tracking) separate from matplotlib axes - ax = None - ax_scitex = None # scitex AxisWrapper with history - if hasattr(self, "axes"): - try: - # Try to get first axes from various wrapper types - if hasattr(self.axes, "_ax"): # AxisWrapper - ax = self.axes._ax - ax_scitex = self.axes # Keep the wrapper for history - elif hasattr(self.axes, "_axis_mpl"): # Alternative - ax = self.axes._axis_mpl - ax_scitex = self.axes - elif hasattr(self.axes, "flatten"): # AxesWrapper - flat = list(self.axes.flatten()) - if flat and hasattr(flat[0], "_ax"): - ax = flat[0]._ax - ax_scitex = flat[0] # Keep the wrapper for history - elif flat and hasattr(flat[0], "_axis_mpl"): - ax = flat[0]._axis_mpl - ax_scitex = flat[0] - except Exception: - pass - - # If still no axes, try from figure - if ( - ax is None - and hasattr(self._fig_mpl, "axes") - and len(self._fig_mpl.axes) > 0 - ): - ax = self._fig_mpl.axes[0] - - # Collect metadata - # Pass ax_scitex if available (has history for plot type detection) - try: - from scitex.plt.utils import collect_figure_metadata - - auto_metadata = collect_figure_metadata( - self._fig_mpl, ax_scitex if ax_scitex else ax - ) - - # Merge with custom metadata - if metadata: - if "custom" not in auto_metadata: - auto_metadata["custom"] = {} - auto_metadata["custom"].update(metadata) - except Exception as e: - # If metadata collection fails, warn but continue - logger.warning(f"Could not collect metadata: {e}") - auto_metadata = metadata - - # Use scitex.io.save_image for metadata embedding - try: - from scitex.io._save_modules import save_image - - save_image( - self._fig_mpl, fname, metadata=auto_metadata, *args, **kwargs - ) - except Exception as e: - # Fallback to regular matplotlib savefig - logger.warning( - f"Metadata embedding failed, using regular savefig: {e}" - ) - self._fig_mpl.savefig(fname, *args, **kwargs) - else: - # For non-image formats or when metadata disabled, use regular savefig - self._fig_mpl.savefig(fname, *args, **kwargs) - finally: - # Restore patch alpha and facecolor values if they were modified - for patch, original_alpha, original_fc in patches_backup: - patch.set_alpha(original_alpha) - patch.set_facecolor(original_fc) - - def export_as_csv(self): - """Export plotted data from all axes. - - This method collects data from all axes in the figure and combines - them into a single DataFrame with appropriate axis identifiers in - the column names. - - Returns - ------- - pd.DataFrame: Combined DataFrame with data from all axes, - with axis ID prefixes for each column. - """ - dfs = [] - - # Use the _traverse_axes helper method to iterate through all axes - # regardless of their structure (single, array, list, etc.) - for ii, ax in enumerate(self._traverse_axes()): - # Try different ways to access the export_as_csv method - df = None - try: - if hasattr(ax, "_axis_mpl") and hasattr(ax._axis_mpl, "export_as_csv"): - # If it's a nested structure with _axis_mpl having export_as_csv - df = ax._axis_mpl.export_as_csv() - elif hasattr(ax, "export_as_csv"): - # Direct AxisWrapper object - df = ax.export_as_csv() - else: - # Skip if no export method available - continue - except Exception: - continue - - # Process the DataFrame if it's not empty - if df is not None and not df.empty: - # Column names already include axis position via get_csv_column_name - # (single source of truth from _csv_column_naming.py) - # Only handle duplicates by adding a counter - new_cols = [] - col_counts = {} - for col in df.columns: - col_str = str(col) - - # Handle duplicates by adding a counter - if col_str in col_counts: - col_counts[col_str] += 1 - col_str = f"{col_str}_{col_counts[col_str]}" - else: - col_counts[col_str] = 0 - - new_cols.append(col_str) - - df.columns = new_cols - dfs.append(df) - - # Return concatenated DataFrame or empty DataFrame if no data - return pd.concat(dfs, axis=1) if dfs else pd.DataFrame() - - def colorbar(self, mappable, ax=None, **kwargs): - """Add a colorbar to the figure, automatically unwrapping SciTeX axes. - - This method properly handles both regular matplotlib axes and SciTeX - AxisWrapper objects when creating colorbars. - - Parameters - ---------- - mappable : ScalarMappable - The image, contour set, etc. to which the colorbar applies - ax : Axes or AxisWrapper, optional - The axes to attach the colorbar to. If not specified, uses current axes. - **kwargs : dict - Additional keyword arguments passed to matplotlib's colorbar - - Returns - ------- - Colorbar - The created colorbar object - """ - # Unwrap axes if it's a SciTeX AxisWrapper - if ax is not None: - ax_mpl = ax._axis_mpl if hasattr(ax, "_axis_mpl") else ax - else: - ax_mpl = None - - # Call matplotlib's colorbar with the unwrapped axes - return self._fig_mpl.colorbar(mappable, ax=ax_mpl, **kwargs) - - def _traverse_axes(self): - """Helper method to traverse all axis wrappers in the figure.""" - if hasattr(self, "axes"): - # Check if we're dealing with an AxesWrapper instance - if hasattr(self.axes, "_axes_scitex") and hasattr( - self.axes._axes_scitex, "flat" - ): - # This is an AxesWrapper, get the individual AxisWrapper objects - for ax in self.axes._axes_scitex.flat: - yield ax - elif not hasattr(self.axes, "__iter__"): - # Single axis case - yield self.axes - else: - # Multiple axes case - if hasattr(self.axes, "flat"): - # 2D array of axes - for ax in self.axes.flat: - yield ax - elif hasattr(self.axes, "ravel"): - # Numpy array - for ax in self.axes.ravel(): - yield ax - elif isinstance(self.axes, (list, tuple)): - # List of axes - for ax in self.axes: - yield ax - - @property - def history(self): - """Aggregate tracking history from all axes in the figure. - - Returns a combined OrderedDict of all tracking records from all axes, - enabling FTS bundle creation to build encoding from plot operations. - """ - from collections import OrderedDict - - combined = OrderedDict() - for ax in self._traverse_axes(): - if hasattr(ax, "history") and ax.history: - combined.update(ax.history) - return combined - - def legend(self, *args, loc="best", **kwargs): - """Legend with 'best' automatic placement by default for all axes.""" - for ax in self._traverse_axes(): - try: - ax.legend(*args, loc=loc, **kwargs) - except Exception: - pass - - def supxyt(self, x=False, y=False, t=False): - """Wrapper for supxlabel, supylabel, and suptitle""" - if x is not False: - self._fig_mpl.supxlabel(x) - if y is not False: - self._fig_mpl.supylabel(y) - if t is not False: - self._fig_mpl.suptitle(t) - return self._fig_mpl - - def tight_layout(self, *, rect=[0, 0.03, 1, 0.95], **kwargs): - """Wrapper for tight_layout with rect=[0, 0.03, 1, 0.95] by default. - - Handles cases where certain axes (like colorbars) are incompatible - with tight_layout. If the figure is using constrained_layout, this - method does nothing as constrained_layout handles spacing automatically. - """ - # Check if figure is already using constrained_layout - if ( - hasattr(self._fig_mpl, "get_constrained_layout") - and self._fig_mpl.get_constrained_layout() - ): - # Figure is using constrained_layout, which handles colorbars better - # No need to call tight_layout - return - - try: - with warnings.catch_warnings(): - # Suppress the specific warning about incompatible axes - warnings.filterwarnings( - "ignore", - message="This figure includes Axes that are not compatible with tight_layout", - ) - self._fig_mpl.tight_layout(rect=rect, **kwargs) - except Exception: - # If tight_layout fails completely, try constrained_layout as fallback - try: - self._fig_mpl.set_constrained_layout(True) - self._fig_mpl.set_constrained_layout_pads(w_pad=0.04, h_pad=0.04) - except Exception: - # If both fail, do nothing - figure will use default layout - pass - - def adjust_layout(self, **kwargs): - """Adjust the constrained layout parameters. - - Parameters - ---------- - w_pad : float, optional - Width padding around axes (default: 0.05) - h_pad : float, optional - Height padding around axes (default: 0.05) - wspace : float, optional - Width space between subplots (default: 0.02) - hspace : float, optional - Height space between subplots (default: 0.02) - rect : list of 4 floats, optional - Rectangle in normalized figure coordinates to fit the whole layout - [left, bottom, right, top] (default: [0, 0, 1, 1]) - """ - if ( - hasattr(self._fig_mpl, "get_constrained_layout") - and self._fig_mpl.get_constrained_layout() - ): - # Update constrained layout parameters - self._fig_mpl.set_constrained_layout_pads(**kwargs) - else: - # Fall back to tight_layout with rect parameter if provided - if "rect" in kwargs: - self.tight_layout(rect=kwargs["rect"]) - - def close(self): - """Close the underlying matplotlib figure""" - import matplotlib.pyplot as plt - - plt.close(self._fig_mpl) - - @property - def number(self): - """Return the figure number for matplotlib.pyplot.close() compatibility""" - return self._fig_mpl.number - - def __del__(self): - """Cleanup when FigWrapper is deleted""" - try: - import matplotlib.pyplot as plt - - plt.close(self._fig_mpl) - except: - pass - - -# EOF diff --git a/src/scitex/plt/_subplots/_SubplotsWrapper.py b/src/scitex/plt/_subplots/_SubplotsWrapper.py deleted file mode 100755 index e3d11d40d..000000000 --- a/src/scitex/plt/_subplots/_SubplotsWrapper.py +++ /dev/null @@ -1,331 +0,0 @@ -#!/usr/bin/env python3 -"""SubplotsWrapper: Monitor data plotted using matplotlib for CSV export.""" - -import os -from collections import OrderedDict - -import matplotlib.pyplot as plt - -__FILE__ = "./src/scitex/plt/_subplots/_SubplotsWrapper.py" -__DIR__ = os.path.dirname(__FILE__) - -import os - -import matplotlib as mpl - -# Register Arial fonts at module import -import matplotlib.font_manager as fm - -# Configure fonts at import -from ._fonts import _arial_enabled # noqa: F401 -from ._mm_layout import create_with_mm_control - -_arial_enabled = False - -# Try to find Arial -try: - fm.findfont("Arial", fallback_to_default=False) - _arial_enabled = True -except Exception: - # Search for Arial font files and register them - arial_paths = [ - f - for f in fm.findSystemFonts() - if os.path.basename(f).lower().startswith("arial") - ] - - if arial_paths: - for path in arial_paths: - try: - fm.fontManager.addfont(path) - except Exception: - pass - - # Verify Arial is now available - try: - fm.findfont("Arial", fallback_to_default=False) - _arial_enabled = True - except Exception: - pass - -# Configure matplotlib to use Arial if available -if _arial_enabled: - mpl.rcParams["font.family"] = "Arial" - mpl.rcParams["font.sans-serif"] = [ - "Arial", - "Helvetica", - "DejaVu Sans", - "Liberation Sans", - ] -else: - # Warn about missing Arial - from scitex import logging as _logging - - _logger = _logging.getLogger(__name__) - _logger.warning( - "Arial font not found. Using fallback fonts (Helvetica/DejaVu Sans). " - "For publication figures with Arial: sudo apt-get install ttf-mscorefonts-installer && fc-cache -fv" - ) - - -class SubplotsWrapper: - """ - A wrapper class monitors data plotted using the ax methods from matplotlib.pyplot. - This data can be converted into a CSV file formatted for SigmaPlot compatibility. - - Supports optional figrecipe integration for reproducible figures. - When figrecipe is available and `use_figrecipe=True`, figures are created - with recipe recording capability for later reproduction. - """ - - def __init__(self): - self._subplots_wrapper_history = OrderedDict() - self._fig_scitex = None - self._counter_part = plt.subplots - self._figrecipe_available = None # Lazy check - - def _check_figrecipe(self): - """Check if figrecipe is available (lazy, cached).""" - if self._figrecipe_available is None: - try: - import figrecipe # noqa: F401 - - self._figrecipe_available = True - except ImportError: - self._figrecipe_available = False - return self._figrecipe_available - - def __call__( - self, - *args, - track=True, - sharex=False, - sharey=False, - constrained_layout=None, - use_figrecipe=None, # NEW: Enable figrecipe recording - # MM-control parameters (unified style system) - axes_width_mm=None, - axes_height_mm=None, - margin_left_mm=None, - margin_right_mm=None, - margin_bottom_mm=None, - margin_top_mm=None, - space_w_mm=None, - space_h_mm=None, - axes_thickness_mm=None, - tick_length_mm=None, - tick_thickness_mm=None, - trace_thickness_mm=None, - marker_size_mm=None, - axis_font_size_pt=None, - tick_font_size_pt=None, - title_font_size_pt=None, - legend_font_size_pt=None, - suptitle_font_size_pt=None, - n_ticks=None, - mode=None, - dpi=None, - styles=None, - transparent=None, - theme=None, - **kwargs, - ): - """ - Create figure and axes with optional millimeter-based control. - - Parameters - ---------- - *args : int - nrows, ncols passed to matplotlib.pyplot.subplots - track : bool, optional - Track plotting operations for CSV export (default: True) - use_figrecipe : bool or None, optional - If True, use figrecipe for recipe recording. - If None (default), auto-detect figrecipe availability. - If False, disable figrecipe even if available. - - MM-Control Parameters - --------------------- - axes_width_mm, axes_height_mm : float or list - Axes dimensions in mm - margin_*_mm : float - Figure margins in mm - space_w_mm, space_h_mm : float - Spacing between axes in mm - mode : str - 'publication' or 'display' - - Returns - ------- - fig : FigWrapper - Wrapped matplotlib Figure (with optional RecordingFigure) - ax or axes : AxisWrapper or AxesWrapper - Wrapped matplotlib Axes - """ - # Resolve style values - from scitex.plt.styles import SCITEX_STYLE as _S - from scitex.plt.styles import resolve_style_value as _resolve - - axes_width_mm = _resolve( - "axes.width_mm", axes_width_mm, _S.get("axes_width_mm") - ) - axes_height_mm = _resolve( - "axes.height_mm", axes_height_mm, _S.get("axes_height_mm") - ) - margin_left_mm = _resolve( - "margins.left_mm", margin_left_mm, _S.get("margin_left_mm") - ) - margin_right_mm = _resolve( - "margins.right_mm", margin_right_mm, _S.get("margin_right_mm") - ) - margin_bottom_mm = _resolve( - "margins.bottom_mm", margin_bottom_mm, _S.get("margin_bottom_mm") - ) - margin_top_mm = _resolve( - "margins.top_mm", margin_top_mm, _S.get("margin_top_mm") - ) - space_w_mm = _resolve("spacing.horizontal_mm", space_w_mm, _S.get("space_w_mm")) - space_h_mm = _resolve("spacing.vertical_mm", space_h_mm, _S.get("space_h_mm")) - axes_thickness_mm = _resolve( - "axes.thickness_mm", axes_thickness_mm, _S.get("axes_thickness_mm") - ) - tick_length_mm = _resolve( - "ticks.length_mm", tick_length_mm, _S.get("tick_length_mm") - ) - tick_thickness_mm = _resolve( - "ticks.thickness_mm", tick_thickness_mm, _S.get("tick_thickness_mm") - ) - trace_thickness_mm = _resolve( - "lines.trace_mm", trace_thickness_mm, _S.get("trace_thickness_mm") - ) - marker_size_mm = _resolve( - "markers.size_mm", marker_size_mm, _S.get("marker_size_mm") - ) - axis_font_size_pt = _resolve( - "fonts.axis_label_pt", axis_font_size_pt, _S.get("axis_font_size_pt") - ) - tick_font_size_pt = _resolve( - "fonts.tick_label_pt", tick_font_size_pt, _S.get("tick_font_size_pt") - ) - title_font_size_pt = _resolve( - "fonts.title_pt", title_font_size_pt, _S.get("title_font_size_pt") - ) - legend_font_size_pt = _resolve( - "fonts.legend_pt", legend_font_size_pt, _S.get("legend_font_size_pt") - ) - suptitle_font_size_pt = _resolve( - "fonts.suptitle_pt", suptitle_font_size_pt, _S.get("suptitle_font_size_pt") - ) - n_ticks = _resolve("ticks.n_ticks", n_ticks, _S.get("n_ticks"), int) - dpi = _resolve("output.dpi", dpi, _S.get("dpi"), int) - - if transparent is None: - transparent = _S.get("transparent", True) - if mode is None: - mode = _S.get("mode", "publication") - if theme is None: - theme = _resolve("theme.mode", None, "light", str) - - # Determine figrecipe usage - if use_figrecipe is None: - use_figrecipe = self._check_figrecipe() - - # Create figure with mm-control - fig, axes = create_with_mm_control( - *args, - track=track, - sharex=sharex, - sharey=sharey, - axes_width_mm=axes_width_mm, - axes_height_mm=axes_height_mm, - margin_left_mm=margin_left_mm, - margin_right_mm=margin_right_mm, - margin_bottom_mm=margin_bottom_mm, - margin_top_mm=margin_top_mm, - space_w_mm=space_w_mm, - space_h_mm=space_h_mm, - axes_thickness_mm=axes_thickness_mm, - tick_length_mm=tick_length_mm, - tick_thickness_mm=tick_thickness_mm, - trace_thickness_mm=trace_thickness_mm, - marker_size_mm=marker_size_mm, - axis_font_size_pt=axis_font_size_pt, - tick_font_size_pt=tick_font_size_pt, - title_font_size_pt=title_font_size_pt, - legend_font_size_pt=legend_font_size_pt, - suptitle_font_size_pt=suptitle_font_size_pt, - n_ticks=n_ticks, - mode=mode, - dpi=dpi, - styles=styles, - transparent=transparent, - theme=theme, - **kwargs, - ) - - # If figrecipe enabled, create recording layer - if use_figrecipe: - self._attach_figrecipe_recorder(fig) - - self._fig_scitex = fig - return fig, axes - - def _attach_figrecipe_recorder(self, fig_wrapper): - """Attach figrecipe recorder to FigWrapper for recipe export. - - This creates a RecordingFigure layer that wraps the underlying - matplotlib figure, enabling save_recipe() on the FigWrapper. - """ - try: - from figrecipe._recorder import Recorder - - # Get the underlying matplotlib figure - mpl_fig = fig_wrapper._fig_mpl - - # Create recorder - recorder = Recorder() - figsize = mpl_fig.get_size_inches() - dpi_val = mpl_fig.dpi - recorder.start_figure(figsize=tuple(figsize), dpi=int(dpi_val)) - - # Store recorder on FigWrapper for later recipe export - fig_wrapper._figrecipe_recorder = recorder - fig_wrapper._figrecipe_enabled = True - - # Store style info from scitex in the recipe - if hasattr(mpl_fig, "_scitex_theme"): - recorder.figure_record.style = {"theme": mpl_fig._scitex_theme} - - except Exception: - # Silently fail - figrecipe is optional - fig_wrapper._figrecipe_enabled = False - - def __dir__(self): - """Provide combined directory for tab completion.""" - local_attrs = set(super().__dir__()) - try: - counterpart_attrs = set(dir(self._counter_part)) - except Exception: - counterpart_attrs = set() - return sorted(local_attrs.union(counterpart_attrs)) - - -# Instantiate the wrapper -subplots = SubplotsWrapper() - - -if __name__ == "__main__": - import matplotlib - - import scitex - - matplotlib.use("TkAgg") - - fig, ax = subplots() - ax.plot([1, 2, 3], [4, 5, 6], id="plot1") - ax.plot([4, 5, 6], [1, 2, 3], id="plot2") - scitex.io.save(fig, "/tmp/subplots_demo/plots.png") - - print(ax.export_as_csv()) - -# EOF diff --git a/src/scitex/plt/_subplots/__init__.py b/src/scitex/plt/_subplots/__init__.py deleted file mode 100755 index 0654cdbd6..000000000 --- a/src/scitex/plt/_subplots/__init__.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 05:22:40 (ywatanabe)" -# File: /home/ywatanabe/proj/_scitex_repo/src/scitex/plt/_subplots/__init__.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/_subplots/__init__.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from ._AxesWrapper import AxesWrapper -from ._AxisWrapper import AxisWrapper - -# Import wrapper classes -from ._FigWrapper import FigWrapper - -# Backward-compatible aliases -_FigWrapper = FigWrapper -_AxisWrapper = AxisWrapper -_AxesWrapper = AxesWrapper - -# Import export_as_csv module functions -from ._export_as_csv import export_as_csv, format_record - -# Import formatters for backward compatibility -from ._export_as_csv_formatters import ( - _format_bar, - _format_boxplot, - _format_errorbar, - _format_fill_between, - _format_hist, - _format_imshow, - _format_plot, - _format_plot_conf_mat, - _format_plot_ecdf, - _format_plot_joyplot, - _format_plot_kde, - _format_plot_line, - _format_plot_mean_std, - _format_plot_raster, - _format_scatter, - _format_sns_boxplot, - _format_violin, -) - -# import importlib -# import inspect - -# # Get the current directory -# current_dir = os.path.dirname(__file__) - -# # Iterate through all Python files in the current directory -# for filename in os.listdir(current_dir): -# if filename.endswith(".py") and not filename.startswith("__"): -# module_name = filename[:-3] # Remove .py extension -# module = importlib.import_module(f".{module_name}", package=__name__) - -# # Import only functions and classes from the module -# for name, obj in inspect.getmembers(module): -# if inspect.isfunction(obj) or inspect.isclass(obj): -# if not name.startswith("_"): -# globals()[name] = obj - -# # Clean up temporary variables -# del ( -# os, -# importlib, -# inspect, -# current_dir, -# filename, -# module_name, -# module, -# name, -# obj, -# ) - -# ################################################################################ -# # For Matplotlib Compatibility -# ################################################################################ -# import matplotlib.pyplot.subplots as counter_part - -# _local_module_attributes = list(globals().keys()) -# print(_local_module_attributes) - - -# def __getattr__(name): -# """ -# Fallback to fetch attributes from matplotlib.pyplot -# if they are not defined directly in this module. -# """ -# try: -# # Get the attribute from matplotlib.pyplot -# return getattr(counter_part, name) -# except AttributeError: -# # Raise the standard error if not found in pyplot either -# raise AttributeError( -# f"module '{__name__}' nor matplotlib.pyplot has attribute '{name}'" -# ) from None - - -# def __dir__(): -# """ -# Provide combined directory for tab completion, including -# attributes from this module and matplotlib.pyplot. -# """ -# # Get attributes defined explicitly in this module -# local_attrs = set(_local_module_attributes) -# # Get attributes from matplotlib.pyplot -# pyplot_attrs = set(dir(counter_part)) -# # Return the sorted union -# return sorted(local_attrs.union(pyplot_attrs)) - - -""" -import matplotlib.pyplot as plt -import scitex.plt as mplt - -print(set(dir(mplt.subplots)) - set(dir(plt.subplots))) -""" - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv.py b/src/scitex/plt/_subplots/_export_as_csv.py deleted file mode 100755 index 90833a492..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv.py +++ /dev/null @@ -1,464 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-09-21 01:52:22 (ywatanabe)" -# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/_subplots/_export_as_csv.py -# ---------------------------------------- -from __future__ import annotations - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex import logging -from scitex.pd import to_xyz - -logger = logging.getLogger(__name__) - -# Global warning registry to track which warnings have been shown -_warning_registry = set() - -# Mapping of matplotlib/seaborn methods to their scitex equivalents -_METHOD_ALTERNATIVES = { - # Matplotlib methods - "imshow": "plot_imshow", - "plot": "plot", # already tracked - "scatter": "plot_scatter", # already tracked - "bar": "plot_bar", # already tracked - "barh": "plot_barh", # already tracked - "hist": "hist", # already tracked - "boxplot": "stx_box or plot_boxplot", - "violinplot": "stx_violin or plot_violinplot", - "fill_between": "plot_fill_between", - "errorbar": "plot_errorbar", - "contour": "plot_contour", - "heatmap": "stx_heatmap", - # Seaborn methods (accessed via ax.sns_*) - "scatterplot": "sns_scatterplot", - "lineplot": "sns_lineplot", - "barplot": "sns_barplot", - "boxplot_sns": "sns_boxplot", - "violinplot_sns": "sns_violinplot", - "stripplot": "sns_stripplot", - "swarmplot": "sns_swarmplot", - "histplot": "sns_histplot", - "kdeplot": "sns_kdeplot", - "heatmap_sns": "sns_heatmap", - "jointplot": "sns_jointplot", - "pairplot": "sns_pairplot", -} - - -def _warn_once(message, category=UserWarning): - """Show a warning only once per runtime. - - Args: - message: Warning message to display - category: Warning category (default: UserWarning) - """ - if message not in _warning_registry: - _warning_registry.add(message) - logger.warning(message) - - -from ._export_as_csv_formatters import ( # Standard matplotlib formatters; Custom scitex formatters; stx_ aliases formatters; Seaborn formatters - _format_annotate, - _format_bar, - _format_barh, - _format_boxplot, - _format_contour, - _format_contourf, - _format_errorbar, - _format_eventplot, - _format_fill, - _format_fill_between, - _format_hexbin, - _format_hist, - _format_hist2d, - _format_imshow, - _format_imshow2d, - _format_matshow, - _format_pcolormesh, - _format_pie, - _format_plot, - _format_plot_box, - _format_plot_conf_mat, - _format_plot_ecdf, - _format_plot_fillv, - _format_plot_heatmap, - _format_plot_image, - _format_plot_imshow, - _format_plot_joyplot, - _format_plot_kde, - _format_plot_line, - _format_plot_mean_ci, - _format_plot_mean_std, - _format_plot_median_iqr, - _format_plot_raster, - _format_plot_rectangle, - _format_plot_scatter, - _format_plot_scatter_hist, - _format_plot_shaded_line, - _format_plot_violin, - _format_quiver, - _format_scatter, - _format_sns_barplot, - _format_sns_boxplot, - _format_sns_heatmap, - _format_sns_histplot, - _format_sns_jointplot, - _format_sns_kdeplot, - _format_sns_lineplot, - _format_sns_pairplot, - _format_sns_scatterplot, - _format_sns_stripplot, - _format_sns_swarmplot, - _format_sns_violinplot, - _format_stackplot, - _format_stem, - _format_step, - _format_streamplot, - _format_stx_bar, - _format_stx_barh, - _format_stx_contour, - _format_stx_errorbar, - _format_stx_imshow, - _format_stx_scatter, - _format_text, - _format_violin, - _format_violinplot, -) - -# Registry mapping method names to their formatter functions -_FORMATTER_REGISTRY = { - # Standard matplotlib methods - "annotate": _format_annotate, - "bar": _format_bar, - "barh": _format_barh, - "boxplot": _format_boxplot, - "contour": _format_contour, - "contourf": _format_contourf, - "errorbar": _format_errorbar, - "eventplot": _format_eventplot, - "fill": _format_fill, - "fill_between": _format_fill_between, - "stackplot": _format_stackplot, - "pcolormesh": _format_pcolormesh, - "pcolor": _format_pcolormesh, - "hexbin": _format_hexbin, - "hist": _format_hist, - "hist2d": _format_hist2d, - "imshow": _format_imshow, - "imshow2d": _format_imshow2d, - "matshow": _format_matshow, - "pie": _format_pie, - "plot": _format_plot, - "quiver": _format_quiver, - "scatter": _format_scatter, - "stem": _format_stem, - "step": _format_step, - "streamplot": _format_streamplot, - "text": _format_text, - "violin": _format_violin, - "violinplot": _format_violinplot, - # Custom scitex methods - "stx_box": _format_plot_box, - "stx_conf_mat": _format_plot_conf_mat, - "stx_contour": _format_stx_contour, - "stx_ecdf": _format_plot_ecdf, - "stx_fillv": _format_plot_fillv, - "stx_heatmap": _format_plot_heatmap, - "stx_image": _format_plot_image, - "plot_imshow": _format_plot_imshow, - "stx_imshow": _format_stx_imshow, - "stx_joyplot": _format_plot_joyplot, - "stx_kde": _format_plot_kde, - "stx_line": _format_plot_line, - "stx_mean_ci": _format_plot_mean_ci, - "stx_mean_std": _format_plot_mean_std, - "stx_median_iqr": _format_plot_median_iqr, - "stx_raster": _format_plot_raster, - "stx_rectangle": _format_plot_rectangle, - "plot_scatter": _format_plot_scatter, - "stx_scatter_hist": _format_plot_scatter_hist, - "stx_shaded_line": _format_plot_shaded_line, - "stx_violin": _format_plot_violin, - # stx_ aliases - "stx_scatter": _format_stx_scatter, - "stx_bar": _format_stx_bar, - "stx_barh": _format_stx_barh, - "stx_errorbar": _format_stx_errorbar, - # Seaborn methods (sns_ prefix) - "sns_barplot": _format_sns_barplot, - "sns_boxplot": _format_sns_boxplot, - "sns_heatmap": _format_sns_heatmap, - "sns_histplot": _format_sns_histplot, - "sns_jointplot": _format_sns_jointplot, - "sns_kdeplot": _format_sns_kdeplot, - "sns_lineplot": _format_sns_lineplot, - "sns_pairplot": _format_sns_pairplot, - "sns_scatterplot": _format_sns_scatterplot, - "sns_stripplot": _format_sns_stripplot, - "sns_swarmplot": _format_sns_swarmplot, - "sns_violinplot": _format_sns_violinplot, -} - - -def _to_numpy(data): - """Convert various data types to numpy array. - - Handles torch tensors, pandas Series/DataFrame, and other array-like objects. - - Parameters - ---------- - data : array-like - Data to convert to numpy array - - Returns - ------- - numpy.ndarray - Data as numpy array - """ - if hasattr(data, "numpy"): # torch tensor - return data.detach().numpy() if hasattr(data, "detach") else data.numpy() - elif hasattr(data, "values"): # pandas series/dataframe - return data.values - else: - return np.asarray(data) - - -def export_as_csv(history_records): - """Convert plotting history records to a combined DataFrame suitable for CSV export. - - Args: - history_records (dict): Dictionary of plotting records. - - Returns: - pd.DataFrame: Combined DataFrame containing all plotting data. - - Raises: - ValueError: If no plotting records are found or they cannot be combined. - """ - if len(history_records) <= 0: - logger.warning("Plotting records not found. Cannot export empty data.") - return pd.DataFrame() # Return empty DataFrame instead of None - - dfs = [] - failed_methods = set() # Track failed methods for helpful warnings - - for record_index, record in enumerate(list(history_records.values())): - try: - formatted_df = format_record(record, record_index=record_index) - if formatted_df is not None and not formatted_df.empty: - dfs.append(formatted_df) - else: - # Track the method that failed to format - method_name = record[1] if len(record) > 1 else "unknown" - failed_methods.add(method_name) - except Exception as e: - method_name = record[1] if len(record) > 1 else "unknown" - failed_methods.add(method_name) - - # If no valid dataframes were created, provide helpful suggestions - if not dfs and failed_methods: - for method in failed_methods: - if method in _METHOD_ALTERNATIVES: - alternative = _METHOD_ALTERNATIVES[method] - message = ( - f"Matplotlib method '{method}()' does not support full data tracking for CSV export. " - f"Consider using 'ax.{alternative}()' instead for better data export support." - ) - else: - message = ( - f"Method '{method}()' does not support data tracking for CSV export. " - f"Consider using scitex plot methods (e.g., stx_image, plot_imshow) for data export support." - ) - _warn_once(message) - return pd.DataFrame() - - try: - # Reset index for each dataframe to avoid alignment issues - dfs_reset = [df.reset_index(drop=True) for df in dfs] - df = pd.concat(dfs_reset, axis=1) - return df - except Exception as e: - logger.warning(f"Failed to combine plotting records: {e}") - # Return a DataFrame with metadata about what records were attempted - meta_df = pd.DataFrame( - { - "record_id": [r[0] for r in history_records.values()], - "method": [r[1] for r in history_records.values()], - "has_data": [ - "Yes" if r[2] and r[2] != {} else "No" - for r in history_records.values() - ], - } - ) - return meta_df - - -def format_record(record, record_index=0): - """Route record to the appropriate formatting function based on plot method. - - Args: - record (tuple): Plotting record tuple (id, method, tracked_dict, kwargs). - record_index (int): Index of this record in the history (used as fallback - for trace_id when user doesn't provide an explicit id= kwarg). - - Returns: - pd.DataFrame: Formatted data for the plot record. - """ - id, method, tracked_dict, kwargs = record - - # Basic Matplotlib functions - if method == "plot": - return _format_plot(id, tracked_dict, kwargs) - elif method == "scatter": - return _format_scatter(id, tracked_dict, kwargs) - elif method == "bar": - return _format_bar(id, tracked_dict, kwargs) - elif method == "barh": - return _format_barh(id, tracked_dict, kwargs) - elif method == "hist": - return _format_hist(id, tracked_dict, kwargs) - elif method == "boxplot": - return _format_boxplot(id, tracked_dict, kwargs) - elif method == "contour": - return _format_contour(id, tracked_dict, kwargs) - elif method == "contourf": - return _format_contourf(id, tracked_dict, kwargs) - elif method == "errorbar": - return _format_errorbar(id, tracked_dict, kwargs) - elif method == "eventplot": - return _format_eventplot(id, tracked_dict, kwargs) - elif method == "fill": - return _format_fill(id, tracked_dict, kwargs) - elif method == "fill_between": - return _format_fill_between(id, tracked_dict, kwargs) - elif method == "stackplot": - return _format_stackplot(id, tracked_dict, kwargs) - elif method == "pcolormesh": - return _format_pcolormesh(id, tracked_dict, kwargs) - elif method == "pcolor": - return _format_pcolormesh(id, tracked_dict, kwargs) - elif method == "hexbin": - return _format_hexbin(id, tracked_dict, kwargs) - elif method == "hist2d": - return _format_hist2d(id, tracked_dict, kwargs) - elif method == "imshow": - return _format_imshow(id, tracked_dict, kwargs) - elif method == "imshow2d": - return _format_imshow2d(id, tracked_dict, kwargs) - elif method == "matshow": - return _format_matshow(id, tracked_dict, kwargs) - elif method == "pie": - return _format_pie(id, tracked_dict, kwargs) - elif method == "quiver": - return _format_quiver(id, tracked_dict, kwargs) - elif method == "stem": - return _format_stem(id, tracked_dict, kwargs) - elif method == "step": - return _format_step(id, tracked_dict, kwargs) - elif method == "streamplot": - return _format_streamplot(id, tracked_dict, kwargs) - elif method == "violin": - return _format_violin(id, tracked_dict, kwargs) - elif method == "violinplot": - return _format_violinplot(id, tracked_dict, kwargs) - elif method == "text": - return _format_text(id, tracked_dict, kwargs) - elif method == "annotate": - return _format_annotate(id, tracked_dict, kwargs) - - # Custom plotting functions - elif method == "stx_box": - return _format_plot_box(id, tracked_dict, kwargs) - elif method == "stx_conf_mat": - return _format_plot_conf_mat(id, tracked_dict, kwargs) - elif method == "stx_contour": - return _format_stx_contour(id, tracked_dict, kwargs) - elif method == "stx_ecdf": - return _format_plot_ecdf(id, tracked_dict, kwargs) - elif method == "stx_fillv": - return _format_plot_fillv(id, tracked_dict, kwargs) - elif method == "stx_heatmap": - return _format_plot_heatmap(id, tracked_dict, kwargs) - elif method == "stx_image": - return _format_plot_image(id, tracked_dict, kwargs) - elif method == "plot_imshow": - return _format_plot_imshow(id, tracked_dict, kwargs) - elif method == "stx_imshow": - return _format_stx_imshow(id, tracked_dict, kwargs) - elif method == "stx_joyplot": - return _format_plot_joyplot(id, tracked_dict, kwargs) - elif method == "stx_kde": - return _format_plot_kde(id, tracked_dict, kwargs) - elif method == "stx_line": - return _format_plot_line(id, tracked_dict, kwargs) - elif method == "stx_mean_ci": - return _format_plot_mean_ci(id, tracked_dict, kwargs) - elif method == "stx_mean_std": - return _format_plot_mean_std(id, tracked_dict, kwargs) - elif method == "stx_median_iqr": - return _format_plot_median_iqr(id, tracked_dict, kwargs) - elif method == "stx_raster": - return _format_plot_raster(id, tracked_dict, kwargs) - elif method == "stx_rectangle": - return _format_plot_rectangle(id, tracked_dict, kwargs) - elif method == "plot_scatter": - return _format_plot_scatter(id, tracked_dict, kwargs) - elif method == "stx_scatter_hist": - return _format_plot_scatter_hist(id, tracked_dict, kwargs) - elif method == "stx_shaded_line": - return _format_plot_shaded_line(id, tracked_dict, kwargs) - elif method == "stx_violin": - return _format_plot_violin(id, tracked_dict, kwargs) - - # stx_ aliases - elif method == "stx_scatter": - return _format_stx_scatter(id, tracked_dict, kwargs) - elif method == "stx_bar": - return _format_stx_bar(id, tracked_dict, kwargs) - elif method == "stx_barh": - return _format_stx_barh(id, tracked_dict, kwargs) - elif method == "stx_errorbar": - return _format_stx_errorbar(id, tracked_dict, kwargs) - - # Seaborn functions (sns_ prefix) - elif method == "sns_barplot": - return _format_sns_barplot(id, tracked_dict, kwargs) - elif method == "sns_boxplot": - return _format_sns_boxplot(id, tracked_dict, kwargs) - elif method == "sns_heatmap": - return _format_sns_heatmap(id, tracked_dict, kwargs) - elif method == "sns_histplot": - return _format_sns_histplot(id, tracked_dict, kwargs) - elif method == "sns_jointplot": - return _format_sns_jointplot(id, tracked_dict, kwargs) - elif method == "sns_kdeplot": - return _format_sns_kdeplot(id, tracked_dict, kwargs) - elif method == "sns_lineplot": - return _format_sns_lineplot(id, tracked_dict, kwargs) - elif method == "sns_pairplot": - return _format_sns_pairplot(id, tracked_dict, kwargs) - elif method == "sns_scatterplot": - return _format_sns_scatterplot(id, tracked_dict, kwargs) - elif method == "sns_stripplot": - return _format_sns_stripplot(id, tracked_dict, kwargs) - elif method == "sns_swarmplot": - return _format_sns_swarmplot(id, tracked_dict, kwargs) - elif method == "sns_violinplot": - return _format_sns_violinplot(id, tracked_dict, kwargs) - else: - # Unknown or unimplemented method - raise NotImplementedError( - f"CSV export for plot method '{method}' is not yet implemented in the scitex.plt module. " - f"Check the feature-request-export-as-csv-functions.md for implementation status." - ) - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/__init__.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/__init__.py deleted file mode 100755 index 90ffbd5ba..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/__init__.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from ._format_annotate import _format_annotate -from ._format_bar import _format_bar -from ._format_barh import _format_barh -from ._format_boxplot import _format_boxplot -from ._format_contour import _format_contour -from ._format_contourf import _format_contourf -from ._format_errorbar import _format_errorbar -from ._format_eventplot import _format_eventplot -from ._format_fill import _format_fill -from ._format_fill_between import _format_fill_between -from ._format_hexbin import _format_hexbin -from ._format_hist import _format_hist -from ._format_hist2d import _format_hist2d -from ._format_imshow import _format_imshow -from ._format_imshow2d import _format_imshow2d -from ._format_matshow import _format_matshow -from ._format_pcolormesh import _format_pcolormesh -from ._format_pie import _format_pie - -# Standard matplotlib formatters -from ._format_plot import _format_plot -from ._format_plot_box import _format_plot_box - -# Custom plotting formatters -from ._format_plot_imshow import _format_plot_imshow -from ._format_plot_kde import _format_plot_kde -from ._format_plot_scatter import _format_plot_scatter -from ._format_quiver import _format_quiver -from ._format_scatter import _format_scatter - -# Seaborn formatters (sns_ prefix) -from ._format_sns_barplot import _format_sns_barplot -from ._format_sns_boxplot import _format_sns_boxplot -from ._format_sns_heatmap import _format_sns_heatmap -from ._format_sns_histplot import _format_sns_histplot -from ._format_sns_jointplot import _format_sns_jointplot -from ._format_sns_kdeplot import _format_sns_kdeplot -from ._format_sns_lineplot import _format_sns_lineplot -from ._format_sns_pairplot import _format_sns_pairplot -from ._format_sns_scatterplot import _format_sns_scatterplot -from ._format_sns_stripplot import _format_sns_stripplot -from ._format_sns_swarmplot import _format_sns_swarmplot -from ._format_sns_violinplot import _format_sns_violinplot -from ._format_stackplot import _format_stackplot -from ._format_stem import _format_stem -from ._format_step import _format_step -from ._format_streamplot import _format_streamplot -from ._format_stx_bar import _format_stx_bar -from ._format_stx_barh import _format_stx_barh -from ._format_stx_conf_mat import _format_plot_conf_mat -from ._format_stx_contour import _format_stx_contour -from ._format_stx_ecdf import _format_plot_ecdf -from ._format_stx_errorbar import _format_stx_errorbar -from ._format_stx_fillv import _format_plot_fillv -from ._format_stx_heatmap import _format_plot_heatmap -from ._format_stx_image import _format_plot_image -from ._format_stx_imshow import _format_stx_imshow -from ._format_stx_joyplot import _format_plot_joyplot -from ._format_stx_line import _format_plot_line -from ._format_stx_mean_ci import _format_plot_mean_ci -from ._format_stx_mean_std import _format_plot_mean_std -from ._format_stx_median_iqr import _format_plot_median_iqr -from ._format_stx_raster import _format_plot_raster -from ._format_stx_rectangle import _format_plot_rectangle - -# stx_ aliases formatters -from ._format_stx_scatter import _format_stx_scatter -from ._format_stx_scatter_hist import _format_plot_scatter_hist -from ._format_stx_shaded_line import _format_plot_shaded_line -from ._format_stx_violin import _format_plot_violin -from ._format_text import _format_text -from ._format_violin import _format_violin -from ._format_violinplot import _format_violinplot diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py deleted file mode 100755 index 6d592ae03..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-10-04 02:30:00 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/SciTeX-Code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py -# ---------------------------------------- -from __future__ import annotations - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_annotate(id, tracked_dict, kwargs): - """Format data from an annotate call. - - matplotlib annotate signature: annotate(text, xy, xytext=None, **kwargs) - - text: The text of the annotation - - xy: The point (x, y) to annotate - - xytext: The position (x, y) to place the text at (optional) - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Extract text and xy coordinates if available - if len(args) >= 2: - text_content = args[0] - xy = args[1] - - # xy should be a tuple (x, y) - if hasattr(xy, "__len__") and len(xy) >= 2: - x, y = xy[0], xy[1] - else: - return pd.DataFrame() - - data = { - get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): [x], - get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): [y], - get_csv_column_name("content", ax_row, ax_col, trace_id=trace_id): [ - text_content - ], - } - - # Check if xytext was provided (either as third arg or in kwargs) - xytext = None - if len(args) >= 3: - xytext = args[2] - elif "xytext" in kwargs: - xytext = kwargs["xytext"] - - if xytext is not None and hasattr(xytext, "__len__") and len(xytext) >= 2: - data[get_csv_column_name("text_x", ax_row, ax_col, trace_id=trace_id)] = [ - xytext[0] - ] - data[get_csv_column_name("text_y", ax_row, ax_col, trace_id=trace_id)] = [ - xytext[1] - ] - - # Create DataFrame with proper column names (use dict with list values) - df = pd.DataFrame(data) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py deleted file mode 100755 index 10c2cead3..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-19 15:45:51 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_bar(id, tracked_dict, kwargs): - """Format data from a bar call for CSV export. - - Includes x, y values and optional yerr for error bars. - - Args: - id: The identifier for the plot - tracked_dict: Dictionary of tracked data - kwargs: Original keyword arguments (may contain yerr) - - Returns: - pd.DataFrame: Formatted data ready for CSV export with x, y, and optional yerr - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get structured column names - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - col_yerr = get_csv_column_name("yerr", ax_row, ax_col, trace_id=trace_id) - - # Extract yerr from kwargs if present - yerr = kwargs.get("yerr") if kwargs else None - - # Check if we have the newer format with bar_data - if "bar_data" in tracked_dict and isinstance( - tracked_dict["bar_data"], pd.DataFrame - ): - # Use the pre-formatted DataFrame but keep only x and height (y) - df = tracked_dict["bar_data"].copy() - - # Keep only essential columns - essential_cols = [col for col in df.columns if col in ["x", "height"]] - if essential_cols: - df = df[essential_cols] - - # Rename using structured naming - rename_map = {} - if "x" in df.columns: - rename_map["x"] = col_x - if "height" in df.columns: - rename_map["height"] = col_y - - df = df.rename(columns=rename_map) - - # Add yerr if present - if yerr is not None: - try: - yerr_array = np.asarray(yerr) - if len(yerr_array) == len(df): - df[col_yerr] = yerr_array - except (TypeError, ValueError): - pass - - return df - - # Legacy format - get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Extract x and y data if available - if len(args) >= 2: - x, y = args[0], args[1] - - # Convert to arrays if possible for consistent handling - try: - x_array = np.asarray(x) - y_array = np.asarray(y) - - # Create DataFrame with structured column names - data = { - col_x: x_array, - col_y: y_array, - } - - # Add yerr if present - if yerr is not None: - try: - yerr_array = np.asarray(yerr) - if len(yerr_array) == len(x_array): - data[col_yerr] = yerr_array - except (TypeError, ValueError): - pass - - return pd.DataFrame(data) - - except (TypeError, ValueError): - # Fall back to direct values if conversion fails - result = {col_x: x, col_y: y} - if yerr is not None: - result[col_yerr] = yerr - return pd.DataFrame(result) - - # If we have tracked data in another format (like our MatplotlibPlotMixin bar method) - result = {} - - # Check for x position (might be in different keys) - for x_key in ["x", "xs", "positions"]: - if x_key in tracked_dict: - result[col_x] = tracked_dict[x_key] - break - - # Check for y values (might be in different keys) - for y_key in ["y", "ys", "height", "heights", "values"]: - if y_key in tracked_dict: - result[col_y] = tracked_dict[y_key] - break - - # Add yerr if present in kwargs - if yerr is not None and result: - try: - yerr_array = np.asarray(yerr) - result[col_yerr] = yerr_array - except (TypeError, ValueError): - pass - - return pd.DataFrame(result) if result else pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py deleted file mode 100755 index 29ae9b25b..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_barh(id, tracked_dict, kwargs): - """Format data from a barh call.""" - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Extract x and y data if available - if len(args) >= 2: - # Note: in barh, first arg is y positions, second is widths (x values) - y_pos, x_width = args[0], args[1] - - # Get xerr from kwargs - xerr = kwargs.get("xerr") - - # Convert single values to Series - if isinstance(y_pos, (int, float)): - y_pos = pd.Series(y_pos, name="y") - if isinstance(x_width, (int, float)): - x_width = pd.Series(x_width, name="x") - else: - # Not enough arguments - return pd.DataFrame() - - # Use structured column naming: ax-row-{row}-col-{col}_trace-id-{id}_variable-{var} - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - - df = pd.DataFrame({col_y: y_pos, col_x: x_width}) - - if xerr is not None: - if isinstance(xerr, (int, float)): - xerr = pd.Series(xerr, name="xerr") - col_xerr = get_csv_column_name("xerr", ax_row, ax_col, trace_id=trace_id) - df[col_xerr] = xerr - return df diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py deleted file mode 100755 index d4cd41762..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py +++ /dev/null @@ -1,81 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_boxplot(id, tracked_dict, kwargs): - """Format data from a boxplot call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to boxplot - - Returns: - pd.DataFrame: Formatted data from boxplot - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", []) - call_kwargs = tracked_dict.get("kwargs", {}) - - # Get labels if provided (for consistent naming with stats) - labels = call_kwargs.get("labels", None) - - if len(args) >= 1: - x = args[0] - - # One box plot - from scitex.types import is_listed_X as scitex_types_is_listed_X - - if isinstance(x, np.ndarray) or scitex_types_is_listed_X(x, [float, int]): - df = pd.DataFrame(x) - # Use label if single box and labels provided - if labels and len(labels) == 1: - col_name = get_csv_column_name( - labels[0], ax_row, ax_col, trace_id=trace_id - ) - else: - col_name = get_csv_column_name( - "data-0", ax_row, ax_col, trace_id=trace_id - ) - df.columns = [col_name] - else: - # Multiple boxes - import scitex.pd - - df = scitex.pd.force_df({i_x: _x for i_x, _x in enumerate(x)}) - - # Use labels if provided, otherwise use numeric indices - if labels and len(labels) == len(df.columns): - df.columns = [ - get_csv_column_name(label, ax_row, ax_col, trace_id=trace_id) - for label in labels - ] - else: - df.columns = [ - get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - for col in range(len(df.columns)) - ] - - df = df.apply(lambda col: col.dropna().reset_index(drop=True)) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py deleted file mode 100755 index 6508aadd9..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_contour(id, tracked_dict, kwargs): - """Format data from a contour call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to contour - - Returns: - pd.DataFrame: Formatted data from contour plot (flattened X, Y, Z grids) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", []) - - # Typical args: X, Y, Z where X and Y are 2D coordinate arrays and Z is the height array - if len(args) >= 3: - X, Y, Z = args[:3] - X_flat = np.asarray(X).flatten() - Y_flat = np.asarray(Y).flatten() - Z_flat = np.asarray(Z).flatten() - - # Get column names from single source of truth - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - col_z = get_csv_column_name("z", ax_row, ax_col, trace_id=trace_id) - - df = pd.DataFrame({col_x: X_flat, col_y: Y_flat, col_z: Z_flat}) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py deleted file mode 100755 index b490e440a..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_contourf(id, tracked_dict, kwargs): - """Format data from a filled contour plot call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to contourf - - Returns: - pd.DataFrame: Formatted data from contourf (flattened X, Y, Z grids) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple): - # contourf can be called as: - # contourf(Z) - Z is 2D - # contourf(X, Y, Z) - X, Y are 1D or 2D, Z is 2D - if len(args) == 1: - Z = np.asarray(args[0]) - X, Y = np.meshgrid(np.arange(Z.shape[1]), np.arange(Z.shape[0])) - elif len(args) >= 3: - X = np.asarray(args[0]) - Y = np.asarray(args[1]) - Z = np.asarray(args[2]) - # If X, Y are 1D, create meshgrid - if X.ndim == 1 and Y.ndim == 1: - X, Y = np.meshgrid(X, Y) - else: - return pd.DataFrame() - - # Get column names from single source of truth - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - col_z = get_csv_column_name("z", ax_row, ax_col, trace_id=trace_id) - - df = pd.DataFrame( - {col_x: X.flatten(), col_y: Y.flatten(), col_z: Z.flatten()} - ) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py deleted file mode 100755 index 18b800cab..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_errorbar(id, tracked_dict, kwargs): - """Format data from an errorbar call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to errorbar - - Returns: - pd.DataFrame: Formatted data from errorbar plot - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", []) - - if len(args) >= 2: - x, y = args[:2] - xerr = kwargs.get("xerr") - yerr = kwargs.get("yerr") - - # Get column names from single source of truth - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - - data = {col_x: x, col_y: y} - - if xerr is not None: - if isinstance(xerr, (list, tuple)) and len(xerr) == 2: - col_xerr_neg = get_csv_column_name( - "xerr-neg", ax_row, ax_col, trace_id=trace_id - ) - col_xerr_pos = get_csv_column_name( - "xerr-pos", ax_row, ax_col, trace_id=trace_id - ) - data[col_xerr_neg] = xerr[0] - data[col_xerr_pos] = xerr[1] - else: - col_xerr = get_csv_column_name( - "xerr", ax_row, ax_col, trace_id=trace_id - ) - data[col_xerr] = xerr - - if yerr is not None: - if isinstance(yerr, (list, tuple)) and len(yerr) == 2: - col_yerr_neg = get_csv_column_name( - "yerr-neg", ax_row, ax_col, trace_id=trace_id - ) - col_yerr_pos = get_csv_column_name( - "yerr-pos", ax_row, ax_col, trace_id=trace_id - ) - data[col_yerr_neg] = yerr[0] - data[col_yerr_pos] = yerr[1] - else: - col_yerr = get_csv_column_name( - "yerr", ax_row, ax_col, trace_id=trace_id - ) - data[col_yerr] = yerr - - # Handle different length arrays by padding - max_len = max( - len(arr) if hasattr(arr, "__len__") else 1 - for arr in data.values() - if arr is not None - ) - - for key, value in list(data.items()): - if value is None: - continue - if not hasattr(value, "__len__"): - data[key] = [value] * max_len - elif len(value) < max_len: - data[key] = np.pad( - np.asarray(value), - (0, max_len - len(value)), - mode="constant", - constant_values=np.nan, - ) - - return pd.DataFrame(data) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py deleted file mode 100755 index b13cd7f9c..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -import scitex -from scitex import logging -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - -logger = logging.getLogger(__name__) - - -def _format_eventplot(id, tracked_dict, kwargs): - """Format data from an eventplot call.""" - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Eventplot displays multiple sets of events as parallel lines - if len(args) >= 1: - positions = args[0] - - try: - # Try using scitex.pd.force_df if available - try: - import scitex.pd - - # If positions is a single array - if isinstance(positions, (list, np.ndarray)) and not isinstance( - positions[0], (list, np.ndarray) - ): - col_name = get_csv_column_name( - "eventplot-events", ax_row, ax_col, trace_id=trace_id - ) - return pd.DataFrame({col_name: positions}) - - # If positions is a list of arrays (multiple event sets) - elif isinstance(positions, (list, np.ndarray)): - data = {} - for i, events in enumerate(positions): - col_name = get_csv_column_name( - f"eventplot-events{i:02d}", - ax_row, - ax_col, - trace_id=f"{trace_id}-{i}", - ) - data[col_name] = events - - # Use force_df to handle different length arrays - return scitex.pd.force_df(data) - - except (ImportError, AttributeError): - # Fall back to pandas with manual Series creation - # If positions is a single array - if isinstance(positions, (list, np.ndarray)) and not isinstance( - positions[0], (list, np.ndarray) - ): - col_name = get_csv_column_name( - "eventplot-events", ax_row, ax_col, trace_id=trace_id - ) - return pd.DataFrame({col_name: positions}) - - # If positions is a list of arrays (multiple event sets) - elif isinstance(positions, (list, np.ndarray)): - # Create a DataFrame where each column is a Series that can handle varying lengths - df = pd.DataFrame() - for i, events in enumerate(positions): - col_name = get_csv_column_name( - f"eventplot-events{i:02d}", - ax_row, - ax_col, - trace_id=f"{trace_id}-{i}", - ) - df[col_name] = pd.Series(events) - return df - except Exception as e: - # If all else fails, return an empty DataFrame - logger.warning(f"Error formatting eventplot data: {str(e)}") - return pd.DataFrame() - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py deleted file mode 100755 index 090d7a202..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_fill(id, tracked_dict, kwargs): - """Format data from a fill call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to fill - - Returns: - pd.DataFrame: Formatted data from fill plot - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", []) - - # Fill creates a polygon based on points - if len(args) >= 2: - x = args[0] - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - data = {col_x: x} - - for i, y in enumerate(args[1:]): - col_y = get_csv_column_name(f"y{i:02d}", ax_row, ax_col, trace_id=trace_id) - data[col_y] = y - - return pd.DataFrame(data) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py deleted file mode 100755 index 7db801eed..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_fill_between(id, tracked_dict, kwargs): - """Format data from a fill_between call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to fill_between - - Returns: - pd.DataFrame: Formatted data from fill_between plot - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", []) - - # Typical args: x, y1, y2 - if len(args) >= 3: - x, y1, y2 = args[:3] - - # Get column names from single source of truth - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y1 = get_csv_column_name("y1", ax_row, ax_col, trace_id=trace_id) - col_y2 = get_csv_column_name("y2", ax_row, ax_col, trace_id=trace_id) - - df = pd.DataFrame({col_x: x, col_y1: y1, col_y2: y2}) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py deleted file mode 100755 index f62f27da0..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_hexbin(id, tracked_dict, kwargs): - """Format data from a hexbin call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to hexbin - - Returns: - pd.DataFrame: Formatted data from hexbin (input x, y data) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) >= 2: - x = np.asarray(args[0]).flatten() - y = np.asarray(args[1]).flatten() - - # Ensure same length - min_len = min(len(x), len(y)) - x = x[:min_len] - y = y[:min_len] - - # Get column names from single source of truth - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - - df = pd.DataFrame({col_x: x, col_y: y}) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py deleted file mode 100755 index 8e7890818..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_hist(id, tracked_dict, kwargs): - """ - Format data from a hist call as a bar plot representation. - - This formatter extracts both the raw data and the binned data from histogram plots, - returning them in a format that can be visualized as a bar plot. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to hist - - Returns: - pd.DataFrame: DataFrame containing both raw data and bin information - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Check if histogram result (bin counts and edges) is available in tracked_dict - hist_result = tracked_dict.get("hist_result", None) - - columns = {} - - # Extract raw data if available - if len(args) >= 1: - x = args[0] - col_raw = get_csv_column_name("raw-data", ax_row, ax_col, trace_id=trace_id) - columns[col_raw] = x - - # If we have histogram result (counts and bin edges) - if hist_result is not None: - counts, bin_edges = hist_result - - # Calculate bin centers for bar plot representation - bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) - bin_widths = bin_edges[1:] - bin_edges[:-1] - - # Use structured column naming - col_centers = get_csv_column_name( - "bin-centers", ax_row, ax_col, trace_id=trace_id - ) - col_counts = get_csv_column_name( - "bin-counts", ax_row, ax_col, trace_id=trace_id - ) - col_widths = get_csv_column_name( - "bin-widths", ax_row, ax_col, trace_id=trace_id - ) - col_left = get_csv_column_name( - "bin-edges-left", ax_row, ax_col, trace_id=trace_id - ) - col_right = get_csv_column_name( - "bin-edges-right", ax_row, ax_col, trace_id=trace_id - ) - - # Add bin information to DataFrame - columns[col_centers] = bin_centers - columns[col_counts] = counts - columns[col_widths] = bin_widths - columns[col_left] = bin_edges[:-1] - columns[col_right] = bin_edges[1:] - - # Create DataFrame with aligned length - max_length = max(len(value) for value in columns.values()) - for key, value in list(columns.items()): - if len(value) < max_length: - # Pad with NaN if needed - convert to float first for NaN support - arr = np.asarray(value, dtype=float) - padded = np.full(max_length, np.nan) - padded[: len(arr)] = arr - columns[key] = padded - - # Return DataFrame or empty DataFrame if no data - if columns: - return pd.DataFrame(columns) - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py deleted file mode 100755 index ccabba0ff..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_hist2d(id, tracked_dict, kwargs): - """Format data from a 2D histogram call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to hist2d - - Returns: - pd.DataFrame: Formatted data from 2D histogram (input x, y data) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) >= 2: - x = np.asarray(args[0]).flatten() - y = np.asarray(args[1]).flatten() - - # Ensure same length - min_len = min(len(x), len(y)) - x = x[:min_len] - y = y[:min_len] - - # Get column names from single source of truth - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - - df = pd.DataFrame({col_x: x, col_y: y}) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py deleted file mode 100755 index c0d4c249a..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_imshow(id, tracked_dict, kwargs): - """Format data from an imshow call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to imshow - - Returns: - pd.DataFrame: Formatted data from imshow (flattened image with row, col indices) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Check for pre-formatted image_df (from plot_imshow wrapper) - if tracked_dict.get("image_df") is not None: - return tracked_dict.get("image_df") - - # Handle raw args from __getattr__ proxied calls - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) > 0: - img = np.asarray(args[0]) - - # Handle 2D grayscale image - if img.ndim == 2: - rows, cols = img.shape - row_indices, col_indices = np.meshgrid( - range(rows), range(cols), indexing="ij" - ) - - # Get column names from single source of truth - col_row = get_csv_column_name("row", ax_row, ax_col, trace_id=trace_id) - col_col = get_csv_column_name("col", ax_row, ax_col, trace_id=trace_id) - col_value = get_csv_column_name( - "value", ax_row, ax_col, trace_id=trace_id - ) - - df = pd.DataFrame( - { - col_row: row_indices.flatten(), - col_col: col_indices.flatten(), - col_value: img.flatten(), - } - ) - return df - - # Handle RGB/RGBA images (3D array) - elif img.ndim == 3: - rows, cols, channels = img.shape - row_indices, col_indices = np.meshgrid( - range(rows), range(cols), indexing="ij" - ) - - # Get column names from single source of truth - col_row = get_csv_column_name("row", ax_row, ax_col, trace_id=trace_id) - col_col = get_csv_column_name("col", ax_row, ax_col, trace_id=trace_id) - - data = { - col_row: row_indices.flatten(), - col_col: col_indices.flatten(), - } - - # Add channel data (R, G, B, A) - channel_names = ["r", "g", "b", "a"][:channels] - for c, name in enumerate(channel_names): - col_channel = get_csv_column_name( - name, ax_row, ax_col, trace_id=trace_id - ) - data[col_channel] = img[:, :, c].flatten() - - return pd.DataFrame(data) - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py deleted file mode 100755 index 1487eb695..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_imshow2d(id, tracked_dict, kwargs): - """Format data from an imshow2d call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to imshow2d - - Returns: - pd.DataFrame: Formatted data from imshow2d - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Extract data if available - if len(args) >= 1 and isinstance(args[0], pd.DataFrame): - df = args[0].copy() - # Rename columns using the single source of truth - renamed_cols = {} - for col in df.columns: - renamed_cols[col] = get_csv_column_name( - f"imshow2d_{col}", ax_row, ax_col, trace_id=trace_id - ) - df = df.rename(columns=renamed_cols) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py deleted file mode 100755 index 0ec0d495a..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_matshow(id, tracked_dict, kwargs): - """Format data from a matshow call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to matshow - - Returns: - pd.DataFrame: Formatted data from matshow (flattened matrix with row, col indices) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) > 0: - Z = np.asarray(args[0]) - - # Create row/col indices - rows, cols = np.indices(Z.shape) - - df = pd.DataFrame( - { - get_csv_column_name( - "row", ax_row, ax_col, trace_id=trace_id - ): rows.flatten(), - get_csv_column_name( - "col", ax_row, ax_col, trace_id=trace_id - ): cols.flatten(), - get_csv_column_name( - "value", ax_row, ax_col, trace_id=trace_id - ): Z.flatten(), - } - ) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pcolormesh.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pcolormesh.py deleted file mode 100755 index b6757b5f0..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pcolormesh.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-21 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pcolormesh.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_pcolormesh(id, tracked_dict, kwargs): - """Format data from a pcolormesh call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to pcolormesh - - Returns: - pd.DataFrame: Formatted data from pcolormesh (x, y, value columns) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", ()) - - if len(args) == 0: - return pd.DataFrame() - - # pcolormesh can be called as: - # pcolormesh(C) - just color values - # pcolormesh(X, Y, C) - with coordinates - if len(args) == 1: - # Just C provided - C = np.asarray(args[0]) - rows, cols = C.shape - Y, X = np.meshgrid(range(rows), range(cols), indexing="ij") - elif len(args) >= 3: - # X, Y, C provided - X = np.asarray(args[0]) - Y = np.asarray(args[1]) - C = np.asarray(args[2]) - else: - return pd.DataFrame() - - # Get column names - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) - - # Flatten for CSV format - df = pd.DataFrame( - { - col_x: X.flatten(), - col_y: Y.flatten(), - col_value: C.flatten(), - } - ) - - return df - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py deleted file mode 100755 index 9e07675c3..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_pie(id, tracked_dict, kwargs): - """Format data from a pie chart call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to pie - - Returns: - pd.DataFrame: Formatted data from pie chart - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) > 0: - x = np.asarray(args[0]) - - # Get column names from single source of truth - col_values = get_csv_column_name( - "values", ax_row, ax_col, trace_id=trace_id - ) - data = {col_values: x} - - # Add labels if provided - labels = kwargs.get("labels", None) - if labels is not None: - col_labels = get_csv_column_name( - "labels", ax_row, ax_col, trace_id=trace_id - ) - data[col_labels] = labels - - df = pd.DataFrame(data) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py deleted file mode 100755 index 4312d2510..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py +++ /dev/null @@ -1,218 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-08 18:45:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py - -"""CSV formatter for matplotlib plot() calls.""" - -from collections import OrderedDict -from typing import Any, Dict, Optional - -import numpy as np -import pandas as pd -import xarray as xr - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - - -def _parse_tracking_id(id: str, record_index: int = 0) -> tuple: - """Parse tracking ID to extract axes position and trace ID. - - Parameters - ---------- - id : str - Tracking ID like "ax_00_plot_0", "ax_00_stim-box", "plot_0", - or user-provided like "sine" - record_index : int - Index of this record in the history (fallback for trace_id) - - Returns - ------- - tuple - (ax_row, ax_col, trace_id) - trace_id is a string - either the user-provided ID (e.g., "sine") - or the record_index as string (e.g., "0") - - Note - ---- - When user provides a custom ID like "sine", that ID is preserved in the - column names for clarity and traceability. - - Examples - -------- - >>> _parse_tracking_id("ax_00_plot_0") - (0, 0, 'plot_0') - >>> _parse_tracking_id("ax_00_stim-box") - (0, 0, 'stim-box') - >>> _parse_tracking_id("ax_12_text_0") - (1, 2, 'text_0') - >>> _parse_tracking_id("ax_10_violin") - (1, 0, 'violin') - """ - ax_row, ax_col = 0, 0 - trace_id = str(record_index) # Default to record_index as string - - if id.startswith("ax_"): - parts = id.split("_") - if len(parts) >= 2: - ax_pos = parts[1] - if len(ax_pos) >= 2: - try: - ax_row = int(ax_pos[0]) - ax_col = int(ax_pos[1]) - except ValueError: - pass - # Extract trace ID from parts[2:] (everything after "ax_XX_") - # e.g., "ax_00_stim-box" -> parts = ["ax", "00", "stim-box"] -> trace_id = "stim-box" - # e.g., "ax_00_plot_0" -> parts = ["ax", "00", "plot", "0"] -> trace_id = "plot_0" - # e.g., "ax_12_text_0" -> parts = ["ax", "12", "text", "0"] -> trace_id = "text_0" - if len(parts) >= 3: - trace_id = "_".join(parts[2:]) - elif id.startswith("plot_"): - # Extract everything after "plot_" as the trace_id - trace_id = id[5:] if len(id) > 5 else str(record_index) - else: - # User-provided ID like "sine", "cosine" - use it directly - trace_id = id - - return ax_row, ax_col, trace_id - - -def _format_plot( - id: str, - tracked_dict: Optional[Dict[str, Any]], - kwargs: Dict[str, Any], -) -> pd.DataFrame: - """Format data from a plot() call for CSV export. - - Handles various input formats including: - - Pre-formatted plot_df from scitex wrappers - - Raw args from __getattr__ proxied matplotlib calls - - Single array: plot(y) generates x from indices - - Two arrays: plot(x, y) - - 2D arrays: creates multiple x/y column pairs - - Parameters - ---------- - id : str - Identifier prefix for the output columns (e.g., "ax_00_plot_0"). - tracked_dict : dict or None - Dictionary containing tracked data. May include: - - 'plot_df': Pre-formatted DataFrame from wrapper - - 'args': Raw positional arguments (x, y) from plot() - kwargs : dict - Keyword arguments passed to plot (currently unused). - - Returns - ------- - pd.DataFrame - Formatted data with columns using single source of truth naming. - Format: ax-row_0_ax-col_0_trace-id_sine_variable_x - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # For stx_line, we expect a 'plot_df' key - if "plot_df" in tracked_dict: - plot_df = tracked_dict["plot_df"] - if isinstance(plot_df, pd.DataFrame): - # Rename columns using single source of truth - renamed = {} - for col in plot_df.columns: - if col == "plot_x": - renamed[col] = get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ) - elif col == "plot_y": - renamed[col] = get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ) - else: - # For other columns, use simplified naming - renamed[col] = get_csv_column_name( - col, ax_row, ax_col, trace_id=trace_id - ) - return plot_df.rename(columns=renamed) - - # Handle raw args from __getattr__ proxied calls - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) > 0: - # Get column names from single source of truth - x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - - # Handle single argument: plot(y) or plot(data_2d) - if len(args) == 1: - args_value = args[0] - - # Convert to numpy for consistent handling - if hasattr(args_value, "values"): # pandas Series/DataFrame - args_value = args_value.values - args_value = np.asarray(args_value) - - # 2D array: extract x and y columns - if hasattr(args_value, "ndim") and args_value.ndim == 2: - x, y = args_value[:, 0], args_value[:, 1] - df = pd.DataFrame({x_col: x, y_col: y}) - return df - - # 1D array: generate x from indices (common case: plot(y)) - elif hasattr(args_value, "ndim") and args_value.ndim == 1: - x = np.arange(len(args_value)) - y = args_value - df = pd.DataFrame({x_col: x, y_col: y}) - return df - - # Handle two arguments: plot(x, y) - elif len(args) >= 2: - x_arg, y_arg = args[0], args[1] - - # Convert to numpy - x = np.asarray(x_arg.values if hasattr(x_arg, "values") else x_arg) - y = np.asarray(y_arg.values if hasattr(y_arg, "values") else y_arg) - - # Handle 2D y array (multiple lines) - if hasattr(y, "ndim") and y.ndim == 2: - out = OrderedDict() - for ii in range(y.shape[1]): - x_col_i = get_csv_column_name( - f"x{ii:02d}", ax_row, ax_col, trace_id=f"{trace_id}-{ii}" - ) - y_col_i = get_csv_column_name( - f"y{ii:02d}", ax_row, ax_col, trace_id=f"{trace_id}-{ii}" - ) - out[x_col_i] = x - out[y_col_i] = y[:, ii] - df = pd.DataFrame(out) - return df - - # Handle DataFrame y - if isinstance(y_arg, pd.DataFrame): - result = {x_col: x} - for ii, col in enumerate(y_arg.columns): - y_col_i = get_csv_column_name( - f"y{ii:02d}", ax_row, ax_col, trace_id=f"{trace_id}-{ii}" - ) - result[y_col_i] = np.array(y_arg[col]) - df = pd.DataFrame(result) - return df - - # Handle 1D arrays (most common case: plot(x, y)) - if hasattr(y, "ndim") and y.ndim == 1: - # Flatten x if needed - x_flat = np.ravel(x) - y_flat = np.ravel(y) - df = pd.DataFrame({x_col: x_flat, y_col: y_flat}) - return df - - # Fallback for list-like y - df = pd.DataFrame({x_col: np.ravel(x), y_col: np.ravel(y)}) - return df - - # Default empty DataFrame if we can't process the input - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py deleted file mode 100755 index a9f2df172..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py - -"""CSV formatter for stx_box() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_box(id, tracked_dict, kwargs): - """Format data from a stx_box call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_box - - Returns: - pd.DataFrame: Formatted box plot data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # First try to get data directly from tracked_dict - data = tracked_dict.get("data") - - # If no data key, get from args - if data is None: - args = tracked_dict.get("args", []) - if len(args) >= 1: - data = args[0] - else: - return pd.DataFrame() - - # If data is a simple array or list of values - if isinstance(data, (np.ndarray, list)) and len(data) > 0: - try: - # Check if it's a simple list of values or a list of lists - if isinstance(data[0], (int, float, np.number)): - col_name = get_csv_column_name( - "data", ax_row, ax_col, trace_id=trace_id - ) - return pd.DataFrame({col_name: data}) - - # If data is a list of arrays (multiple box plots) - elif isinstance(data, (list, tuple)) and all( - isinstance(x, (list, np.ndarray)) for x in data - ): - result = pd.DataFrame() - for i, values in enumerate(data): - try: - col_name = get_csv_column_name( - f"data-{i}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = pd.Series(values) - except Exception: - pass - return result - except (IndexError, TypeError): - pass - - # If data is a dictionary - elif isinstance(data, dict): - result = pd.DataFrame() - for label, values in data.items(): - try: - col_name = get_csv_column_name( - f"data-{label}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = pd.Series(values) - except Exception: - pass - return result - - # If data is a DataFrame - elif isinstance(data, pd.DataFrame): - result = pd.DataFrame() - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - # Default case: return empty DataFrame if nothing could be processed - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_imshow.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_imshow.py deleted file mode 100755 index 355dbc7bf..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_imshow.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-11-18 11:40:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_imshow.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_imshow(id, tracked_dict, kwargs): - """Format data from a plot_imshow call. - - Args: - id: Plot identifier - tracked_dict: Dictionary containing tracked data with key "imshow_df" - kwargs: Additional keyword arguments - - Returns: - pd.DataFrame: Formatted image data for CSV export - """ - # Check for imshow_df in tracked_dict - if tracked_dict.get("imshow_df") is not None: - df = tracked_dict["imshow_df"] - - # Add id prefix to column names if id is provided - if id is not None: - # Parse tracking ID to extract axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Use standardized column naming for each column - df = df.copy() - renamed_cols = {} - for col in df.columns: - # Create column name like "plot_imshow_row" or "plot_imshow_col" - renamed_cols[col] = get_csv_column_name( - f"plot_imshow_{col}", ax_row, ax_col, trace_id=trace_id - ) - df.rename(columns=renamed_cols, inplace=True) - - return df - - # Fallback: return empty DataFrame - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py deleted file mode 100755 index 1eeefc45d..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pandas as pd - -from scitex.pd import force_df -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_kde(id, tracked_dict, kwargs): - """Format data from a stx_kde call. - - Processes kernel density estimation plot data. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing 'x', 'kde', and 'n' keys - kwargs (dict): Keyword arguments passed to stx_kde - - Returns: - pd.DataFrame: Formatted KDE data - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - x = tracked_dict.get("x") - kde = tracked_dict.get("kde") - n = tracked_dict.get("n") - - if x is None or kde is None: - return pd.DataFrame() - - # Parse tracking ID to extract axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Use standardized column naming - x_col = get_csv_column_name("kde_x", ax_row, ax_col, trace_id=trace_id) - density_col = get_csv_column_name("kde_density", ax_row, ax_col, trace_id=trace_id) - - df = pd.DataFrame({x_col: x, density_col: kde}) - - # Add sample count if available - if n is not None: - # If n is a scalar, create a list with the same length as x - if not hasattr(n, "__len__"): - n = [n] * len(x) - n_col = get_csv_column_name("kde_n", ax_row, ax_col, trace_id=trace_id) - df[n_col] = n - - return df diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py deleted file mode 100755 index 8e2a05715..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-10-03 02:47:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_scatter(id, tracked_dict, kwargs): - """Format data from a plot_scatter call. - - The plot_scatter method stores data as: - {"scatter_df": pd.DataFrame({"x": args[0], "y": args[1]})} - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Get the scatter_df from tracked_dict - scatter_df = tracked_dict.get("scatter_df") - - if scatter_df is not None and isinstance(scatter_df, pd.DataFrame): - # Parse tracking ID to extract axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Use standardized column naming - x_col = get_csv_column_name("scatter_x", ax_row, ax_col, trace_id=trace_id) - y_col = get_csv_column_name("scatter_y", ax_row, ax_col, trace_id=trace_id) - - # Rename columns to include the id - return scatter_df.rename(columns={"x": x_col, "y": y_col}) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py deleted file mode 100755 index 809ac75e0..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 12:20:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_quiver(id, tracked_dict, kwargs): - """Format data from a quiver (vector field) call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to quiver - - Returns: - pd.DataFrame: Formatted data from quiver (X, Y positions and U, V vectors) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple): - # quiver can be called as: - # quiver(U, V) - positions auto-generated - # quiver(X, Y, U, V) - explicit positions - if len(args) == 2: - U = np.asarray(args[0]) - V = np.asarray(args[1]) - X, Y = np.meshgrid(np.arange(U.shape[1]), np.arange(U.shape[0])) - elif len(args) >= 4: - X = np.asarray(args[0]) - Y = np.asarray(args[1]) - U = np.asarray(args[2]) - V = np.asarray(args[3]) - else: - return pd.DataFrame() - - df = pd.DataFrame( - { - get_csv_column_name( - "quiver-x", ax_row, ax_col, trace_id=trace_id - ): X.flatten(), - get_csv_column_name( - "quiver-y", ax_row, ax_col, trace_id=trace_id - ): Y.flatten(), - get_csv_column_name( - "quiver-u", ax_row, ax_col, trace_id=trace_id - ): U.flatten(), - get_csv_column_name( - "quiver-v", ax_row, ax_col, trace_id=trace_id - ): V.flatten(), - } - ) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py deleted file mode 100755 index 524f51eef..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_scatter(id, tracked_dict, kwargs): - """Format data from a scatter call (matplotlib ax.scatter or seaborn scatter). - - Note: For plot_scatter (wrapper method), use _format_plot_scatter instead. - This formatter expects data in args format: tracked_dict['args'] = (x, y). - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Extract x and y data if available - if len(args) >= 2: - x, y = args[0], args[1] - # Use structured column naming: ax-row-{row}-col-{col}_trace-id-{id}_variable-{var} - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - df = pd.DataFrame({col_x: x, col_y: y}) - return df - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py deleted file mode 100755 index 405cbf9f8..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py - -"""CSV formatter for sns.barplot() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_barplot(id, tracked_dict, kwargs): - """Format data from a sns_barplot call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to sns_barplot - - Returns: - pd.DataFrame: Formatted data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # If 'data' key is in tracked_dict, use it - if "data" in tracked_dict: - df = tracked_dict["data"] - if isinstance(df, pd.DataFrame): - result = pd.DataFrame() - for col in df.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = df[col] - return result - - # Legacy handling for args - if "args" in tracked_dict: - df = tracked_dict["args"] - if isinstance(df, pd.DataFrame): - try: - processed_df = pd.DataFrame( - pd.Series(np.array(df).diagonal(), index=df.columns) - ).T - result = pd.DataFrame() - for col in processed_df.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = processed_df[col] - return result - except (ValueError, TypeError, IndexError): - result = pd.DataFrame() - for col in df.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = df[col] - return result - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py deleted file mode 100755 index 58b1f307b..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py - -"""CSV formatter for sns.boxplot() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_boxplot(id, tracked_dict, kwargs): - """Format data from a sns_boxplot call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to sns_boxplot - - Returns: - pd.DataFrame: Formatted boxplot data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict: - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # If tracked_dict is a dictionary, try to extract the data from it - if isinstance(tracked_dict, dict): - # First try to get 'data' key which is used in seaborn functions - if "data" in tracked_dict: - data = tracked_dict["data"] - if isinstance(data, pd.DataFrame): - result = pd.DataFrame() - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - # If no 'data' key, try to get data from args - args = tracked_dict.get("args", []) - if len(args) > 0: - data = args[0] - if isinstance(data, pd.DataFrame): - result = pd.DataFrame() - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - # Handle list or array data - elif isinstance(data, (list, np.ndarray)): - try: - if all(isinstance(item, (list, np.ndarray)) for item in data): - result = pd.DataFrame() - for i, group_data in enumerate(data): - col_name = get_csv_column_name( - f"data-{i}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = pd.Series(group_data) - return result - else: - col_name = get_csv_column_name( - "data", ax_row, ax_col, trace_id=trace_id - ) - return pd.DataFrame({col_name: data}) - except Exception: - pass - - # If tracked_dict is a DataFrame already, use it directly - elif isinstance(tracked_dict, pd.DataFrame): - result = pd.DataFrame() - for col in tracked_dict.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = tracked_dict[col] - return result - - # If tracked_dict is list-like, try to convert it to a DataFrame - elif hasattr(tracked_dict, "__iter__") and not isinstance(tracked_dict, str): - try: - if all(isinstance(item, (list, np.ndarray)) for item in tracked_dict): - result = pd.DataFrame() - for i, group_data in enumerate(tracked_dict): - col_name = get_csv_column_name( - f"data-{i}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = pd.Series(group_data) - return result - else: - col_name = get_csv_column_name( - "data", ax_row, ax_col, trace_id=trace_id - ) - return pd.DataFrame({col_name: tracked_dict}) - except Exception: - pass - - # Return empty DataFrame if we couldn't extract useful data - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py deleted file mode 100755 index ab0e46a62..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py - -"""CSV formatter for sns.heatmap() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_heatmap(id, tracked_dict, kwargs): - """Format data from a sns_heatmap call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to sns_heatmap - - Returns: - pd.DataFrame: Formatted data with standard column names - """ - # Check if tracked_dict is empty - if not tracked_dict: - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - def _format_dataframe(df): - result = pd.DataFrame() - for col in df.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = df[col] - return result - - def _format_array(arr): - rows, cols = arr.shape if len(arr.shape) >= 2 else (arr.shape[0], 1) - result = pd.DataFrame() - for i in range(cols): - col_data = arr[:, i] if len(arr.shape) >= 2 else arr - col_name = get_csv_column_name( - f"data-col-{i}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = col_data - return result - - # If tracked_dict is a dictionary - if isinstance(tracked_dict, dict): - if "data" in tracked_dict: - data = tracked_dict["data"] - - if isinstance(data, pd.DataFrame): - return _format_dataframe(data) - elif isinstance(data, np.ndarray): - return _format_array(data) - - # Legacy handling for args - args = tracked_dict.get("args", []) - if len(args) > 0: - data = args[0] - - if isinstance(data, pd.DataFrame): - return _format_dataframe(data) - elif isinstance(data, np.ndarray): - return _format_array(data) - - # If tracked_dict is a DataFrame directly - elif isinstance(tracked_dict, pd.DataFrame): - return _format_dataframe(tracked_dict) - - # If tracked_dict is a numpy array directly - elif isinstance(tracked_dict, np.ndarray): - return _format_array(tracked_dict) - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py deleted file mode 100755 index f4d365531..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py - -"""CSV formatter for sns.histplot() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_histplot(id, tracked_dict, kwargs): - """Format data from a sns_histplot call as a bar plot representation. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to sns_histplot - - Returns: - pd.DataFrame: Formatted data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - columns = {} - - # Check if histogram result is available in tracked_dict - hist_result = tracked_dict.get("hist_result", None) - - # If we have histogram result (counts and bin edges) - if hist_result is not None: - counts, bin_edges = hist_result - - # Calculate bin centers for bar plot representation - bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) - bin_widths = bin_edges[1:] - bin_edges[:-1] - - # Add bin information with standard naming - columns[ - get_csv_column_name("bin-centers", ax_row, ax_col, trace_id=trace_id) - ] = bin_centers - columns[ - get_csv_column_name("bin-counts", ax_row, ax_col, trace_id=trace_id) - ] = counts - columns[ - get_csv_column_name("bin-widths", ax_row, ax_col, trace_id=trace_id) - ] = bin_widths - columns[ - get_csv_column_name("bin-edges-left", ax_row, ax_col, trace_id=trace_id) - ] = bin_edges[:-1] - columns[ - get_csv_column_name("bin-edges-right", ax_row, ax_col, trace_id=trace_id) - ] = bin_edges[1:] - - # Get raw data if available - if "data" in tracked_dict: - df = tracked_dict["data"] - if isinstance(df, pd.DataFrame): - x_col = kwargs.get("x") - if x_col and x_col in df.columns: - columns[ - get_csv_column_name("raw-data", ax_row, ax_col, trace_id=trace_id) - ] = df[x_col].values - - # Legacy handling for args - elif "args" in tracked_dict: - args = tracked_dict["args"] - if len(args) >= 1: - x = args[0] - if hasattr(x, "values"): - columns[ - get_csv_column_name("raw-data", ax_row, ax_col, trace_id=trace_id) - ] = x.values - else: - columns[ - get_csv_column_name("raw-data", ax_row, ax_col, trace_id=trace_id) - ] = x - - # If we have data to return - if columns: - # Ensure all arrays are the same length by padding with NaN - max_length = max( - len(value) for value in columns.values() if hasattr(value, "__len__") - ) - for key, value in list(columns.items()): - if hasattr(value, "__len__") and len(value) < max_length: - if isinstance(value, np.ndarray): - columns[key] = np.pad( - value, - (0, max_length - len(value)), - mode="constant", - constant_values=np.nan, - ) - else: - padded = list(value) + [np.nan] * (max_length - len(value)) - columns[key] = np.array(padded) - - return pd.DataFrame(columns) - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py deleted file mode 100755 index 636a049a1..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_jointplot(id, tracked_dict, kwargs): - """Format data from a sns_jointplot call.""" - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to extract axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Joint distribution plot in seaborn - if len(args) >= 1: - data = args[0] - - # Get x and y variables from kwargs - x_var = kwargs.get("x") - y_var = kwargs.get("y") - - # Handle DataFrame input - if isinstance(data, pd.DataFrame) and x_var and y_var: - # Extract the relevant columns - x_data = data[x_var] - y_data = data[y_var] - - result = pd.DataFrame( - { - get_csv_column_name( - f"joint_{x_var}", ax_row, ax_col, trace_id=trace_id - ): x_data, - get_csv_column_name( - f"joint_{y_var}", ax_row, ax_col, trace_id=trace_id - ): y_data, - } - ) - return result - - # Handle direct x, y data arrays - elif isinstance(data, pd.DataFrame): - # If no x, y specified, return the whole dataframe - result = data.copy() - if id is not None: - result.columns = [ - get_csv_column_name( - f"joint_{col}", ax_row, ax_col, trace_id=trace_id - ) - for col in result.columns - ] - return result - - # Handle numpy arrays directly - elif ( - all(arg in args for arg in range(2)) - and isinstance(args[0], (np.ndarray, list)) - and isinstance(args[1], (np.ndarray, list)) - ): - x_data, y_data = args[0], args[1] - return pd.DataFrame( - { - get_csv_column_name( - "joint_x", ax_row, ax_col, trace_id=trace_id - ): x_data, - get_csv_column_name( - "joint_y", ax_row, ax_col, trace_id=trace_id - ): y_data, - } - ) - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py deleted file mode 100755 index 055ce472d..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py - -"""CSV formatter for sns.kdeplot() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_kdeplot(id, tracked_dict, kwargs): - """Format data from a sns_kdeplot call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to sns_kdeplot - - Returns: - pd.DataFrame: Formatted data with standard column names - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get args from tracked_dict - args = tracked_dict.get("args", []) - x_var = kwargs.get("x") if kwargs else None - y_var = kwargs.get("y") if kwargs else None - - if len(args) >= 1: - data = args[0] - - # Handle DataFrame input with x, y variables - if isinstance(data, pd.DataFrame) and x_var: - if y_var and y_var in data.columns: # Bivariate KDE - return pd.DataFrame( - { - get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ): data[x_var], - get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ): data[y_var], - } - ) - elif x_var in data.columns: # Univariate KDE - return pd.DataFrame( - { - get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ): data[x_var] - } - ) - - # Handle direct data array input - elif isinstance(data, (np.ndarray, list)): - y_data = ( - args[1] - if len(args) > 1 and isinstance(args[1], (np.ndarray, list)) - else None - ) - - if y_data is not None: # Bivariate KDE - return pd.DataFrame( - { - get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ): data, - get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ): y_data, - } - ) - else: # Univariate KDE - return pd.DataFrame( - {get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): data} - ) - - # Handle DataFrame input without x, y specified - elif isinstance(data, pd.DataFrame): - result = pd.DataFrame() - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - # Also check for 'data' key directly - if "data" in tracked_dict: - data = tracked_dict["data"] - if isinstance(data, pd.DataFrame): - result = pd.DataFrame() - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py deleted file mode 100755 index 04a64e6ab..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_lineplot(id, tracked_dict, kwargs): - """Format data from a sns_lineplot call.""" - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get data from tracked_dict - can be in "data" (from _sns_base_xyhue) or "args" - data = tracked_dict.get("data") - args = tracked_dict.get("args", []) - - # If data is None, try to get it from args - if data is None and len(args) >= 1: - data = args[0] - - x_var = kwargs.get("x") - y_var = kwargs.get("y") - - # Handle DataFrame input with x, y variables - if isinstance(data, pd.DataFrame): - # If data has been pre-processed by _sns_prepare_xyhue, it may be pivoted - # Just export all columns with proper naming - if data.empty: - return pd.DataFrame() - - result = {} - for col in data.columns: - col_name = str(col) if not isinstance(col, str) else col - result[get_csv_column_name(col_name, ax_row, ax_col, trace_id=trace_id)] = ( - data[col].values - ) - return pd.DataFrame(result) - - # Handle direct x, y data arrays from args - elif ( - len(args) > 1 - and isinstance(args[0], (np.ndarray, list)) - and isinstance(args[1], (np.ndarray, list)) - ): - x_data, y_data = args[0], args[1] - return pd.DataFrame( - { - get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): x_data, - get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): y_data, - } - ) - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py deleted file mode 100755 index 9de01a00c..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_pairplot(id, tracked_dict, kwargs): - """Format data from a sns_pairplot call.""" - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to extract axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Grid of plots showing pairwise relationships - if len(args) >= 1: - data = args[0] - - # Handle DataFrame input - if isinstance(data, pd.DataFrame): - # For pairplot, just return the full DataFrame since it uses all variables - result = data.copy() - if id is not None: - result.columns = [ - get_csv_column_name( - f"pair_{col}", ax_row, ax_col, trace_id=trace_id - ) - for col in result.columns - ] - - # Add vars or hue columns if specified - vars_list = kwargs.get("vars") - if vars_list and all(var in data.columns for var in vars_list): - # Keep only the specified columns - result = pd.DataFrame( - { - get_csv_column_name( - f"pair_{col}", ax_row, ax_col, trace_id=trace_id - ): data[col] - for col in vars_list - } - ) - - return result - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py deleted file mode 100755 index f053b7006..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py - -"""CSV formatter for sns.scatterplot() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_scatterplot(id, tracked_dict, kwargs=None): - """Format data from a sns_scatterplot call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Tracked data dictionary - kwargs (dict): Keyword arguments from the record tuple - - Returns: - pd.DataFrame: Formatted data with standard column names - """ - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Look for the DataFrame in the kwargs dictionary if provided - if kwargs and isinstance(kwargs, dict) and "data" in kwargs: - data = kwargs["data"] - if isinstance(data, pd.DataFrame): - result = pd.DataFrame() - - # If x and y variables are specified in kwargs, use them - x_var = kwargs.get("x") - y_var = kwargs.get("y") - - if x_var and y_var and x_var in data.columns and y_var in data.columns: - result[get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id)] = ( - data[x_var] - ) - result[get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id)] = ( - data[y_var] - ) - - # Also extract hue, size, style if specified - for extra_var in ["hue", "size", "style"]: - var_name = kwargs.get(extra_var) - if var_name and var_name in data.columns: - result[ - get_csv_column_name( - extra_var, ax_row, ax_col, trace_id=trace_id - ) - ] = data[var_name] - - return result - else: - # If columns aren't specified, include all columns - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - # Alternative: try to find a DataFrame in tracked_dict - if tracked_dict and isinstance(tracked_dict, dict): - if "data" in tracked_dict and isinstance(tracked_dict["data"], pd.DataFrame): - data = tracked_dict["data"] - result = pd.DataFrame() - - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - - return result - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py deleted file mode 100755 index ab7625c30..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py - -"""CSV formatter for sns.stripplot() calls - uses standard column naming.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_stripplot(id, tracked_dict, kwargs): - """Format data from a sns_stripplot call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to sns_stripplot - - Returns: - pd.DataFrame: Formatted data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # If 'data' key is in tracked_dict, use it - if "data" in tracked_dict: - data = tracked_dict["data"] - - if isinstance(data, pd.DataFrame): - result = pd.DataFrame() - - # Extract variables from kwargs - x_var = kwargs.get("x") if kwargs else None - y_var = kwargs.get("y") if kwargs else None - hue_var = kwargs.get("hue") if kwargs else None - - # Add x variable if specified - if x_var and x_var in data.columns: - result[get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id)] = ( - data[x_var] - ) - - # Add y variable if specified - if y_var and y_var in data.columns: - result[get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id)] = ( - data[y_var] - ) - - # Add grouping variable if present - if hue_var and hue_var in data.columns: - result[ - get_csv_column_name("hue", ax_row, ax_col, trace_id=trace_id) - ] = data[hue_var] - - # If we've added columns, return the result - if not result.empty: - return result - - # If no columns were explicitly specified, return all columns - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - # Legacy handling for args - if "args" in tracked_dict and len(tracked_dict["args"]) >= 1: - data = tracked_dict["args"][0] - - if isinstance(data, pd.DataFrame): - result = pd.DataFrame() - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py deleted file mode 100755 index 4fddcf2e3..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py - -"""CSV formatter for sns.swarmplot() calls - uses standard column naming.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_swarmplot(id, tracked_dict, kwargs): - """Format data from a sns_swarmplot call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to sns_swarmplot - - Returns: - pd.DataFrame: Formatted data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # If 'data' key is in tracked_dict, use it - if "data" in tracked_dict: - data = tracked_dict["data"] - - if isinstance(data, pd.DataFrame): - result = pd.DataFrame() - - # Extract variables from kwargs - x_var = kwargs.get("x") if kwargs else None - y_var = kwargs.get("y") if kwargs else None - hue_var = kwargs.get("hue") if kwargs else None - - # Add x variable if specified - if x_var and x_var in data.columns: - result[get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id)] = ( - data[x_var] - ) - - # Add y variable if specified - if y_var and y_var in data.columns: - result[get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id)] = ( - data[y_var] - ) - - # Add grouping variable if present - if hue_var and hue_var in data.columns: - result[ - get_csv_column_name("hue", ax_row, ax_col, trace_id=trace_id) - ] = data[hue_var] - - # If we've added columns, return the result - if not result.empty: - return result - - # If no columns were explicitly specified, return all columns - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - # Legacy handling for args - if "args" in tracked_dict and len(tracked_dict["args"]) >= 1: - data = tracked_dict["args"][0] - - if isinstance(data, pd.DataFrame): - result = pd.DataFrame() - for col in data.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = data[col] - return result - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py deleted file mode 100755 index 4d071c05e..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py - -"""CSV formatter for sns.violinplot() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_sns_violinplot(id, tracked_dict, kwargs): - """Format data from a sns_violinplot call. - - Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to sns_violinplot - - Returns: - pd.DataFrame: Formatted data with standard column names - """ - # Check if tracked_dict is empty - if not tracked_dict: - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - def _format_dataframe(df): - result = pd.DataFrame() - for col in df.columns: - col_name = get_csv_column_name( - f"data-{col}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = df[col] - return result - - def _format_list_of_arrays(data): - result = pd.DataFrame() - for i, group_data in enumerate(data): - col_name = get_csv_column_name( - f"data-{i}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = pd.Series(group_data) - return result - - # If tracked_dict is a dictionary - if isinstance(tracked_dict, dict): - if "data" in tracked_dict: - data = tracked_dict["data"] - - if isinstance(data, pd.DataFrame): - try: - return _format_dataframe(data) - except Exception: - try: - x_var = kwargs.get("x") if kwargs else None - y_var = kwargs.get("y") if kwargs else None - - if ( - x_var - and y_var - and x_var in data.columns - and y_var in data.columns - ): - return pd.DataFrame( - { - get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ): data[x_var], - get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ): data[y_var], - } - ) - elif len(data.columns) > 0: - first_col = data.columns[0] - return pd.DataFrame( - { - get_csv_column_name( - "data", ax_row, ax_col, trace_id=trace_id - ): data[first_col] - } - ) - except Exception: - return pd.DataFrame() - - elif isinstance(data, (list, np.ndarray)): - try: - if ( - isinstance(data, list) - and len(data) > 0 - and all(isinstance(item, (list, np.ndarray)) for item in data) - ): - return _format_list_of_arrays(data) - else: - return pd.DataFrame( - { - get_csv_column_name( - "data", ax_row, ax_col, trace_id=trace_id - ): data - } - ) - except Exception: - return pd.DataFrame() - - # Legacy handling for args - args = tracked_dict.get("args", []) - if len(args) > 0: - data = args[0] - - if isinstance(data, pd.DataFrame): - return _format_dataframe(data) - - elif isinstance(data, (list, np.ndarray)): - try: - if all(isinstance(item, (list, np.ndarray)) for item in data): - return _format_list_of_arrays(data) - else: - return pd.DataFrame( - { - get_csv_column_name( - "data", ax_row, ax_col, trace_id=trace_id - ): data - } - ) - except Exception: - return pd.DataFrame() - - # If tracked_dict is a DataFrame directly - elif isinstance(tracked_dict, pd.DataFrame): - try: - return _format_dataframe(tracked_dict) - except Exception: - try: - if len(tracked_dict.columns) > 0: - first_col = tracked_dict.columns[0] - return pd.DataFrame( - { - get_csv_column_name( - "data", ax_row, ax_col, trace_id=trace_id - ): tracked_dict[first_col] - } - ) - except Exception: - return pd.DataFrame() - - # If tracked_dict is a list or numpy array directly - elif isinstance(tracked_dict, (list, np.ndarray)): - try: - if all(isinstance(item, (list, np.ndarray)) for item in tracked_dict): - return _format_list_of_arrays(tracked_dict) - else: - return pd.DataFrame( - { - get_csv_column_name( - "data", ax_row, ax_col, trace_id=trace_id - ): tracked_dict - } - ) - except Exception: - return pd.DataFrame() - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stackplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stackplot.py deleted file mode 100755 index a775ff866..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stackplot.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-21 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stackplot.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_stackplot(id, tracked_dict, kwargs): - """Format data from a stackplot call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stackplot - - Returns: - pd.DataFrame: Formatted data from stackplot (x and multiple y columns) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", ()) - - # stackplot(x, y1, y2, y3, ...) or stackplot(x, [y1, y2, y3], ...) - if len(args) < 2: - return pd.DataFrame() - - x = np.asarray(args[0]) - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - data = {col_x: x} - - # Get labels from kwargs if available - labels = kwargs.get("labels", []) - - # Handle remaining args as y arrays - y_arrays = args[1:] - - # If first y arg is a 2D array, treat rows as separate series - if len(y_arrays) == 1 and hasattr(y_arrays[0], "ndim"): - y_data = np.asarray(y_arrays[0]) - if y_data.ndim == 2: - y_arrays = [y_data[i] for i in range(y_data.shape[0])] - - for i, y in enumerate(y_arrays): - y = np.asarray(y) - # Use label if available, otherwise use index - label = labels[i] if i < len(labels) else f"y{i:02d}" - col_y = get_csv_column_name(label, ax_row, ax_col, trace_id=trace_id) - data[col_y] = y - - return pd.DataFrame(data) - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py deleted file mode 100755 index ad1e64a50..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 12:20:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_stem(id, tracked_dict, kwargs): - """Format data from a stem plot call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stem - - Returns: - pd.DataFrame: Formatted data from stem plot - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) > 0: - if len(args) == 1: - y = np.asarray(args[0]) - x = np.arange(len(y)) - elif len(args) >= 2: - x = np.asarray(args[0]) - y = np.asarray(args[1]) - else: - return pd.DataFrame() - - # Use structured column naming: ax-row-{row}-col-{col}_trace-id-{id}_variable-{var} - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - df = pd.DataFrame({col_x: x, col_y: y}) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py deleted file mode 100755 index d357366b9..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 12:20:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_step(id, tracked_dict, kwargs): - """Format data from a step plot call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to step - - Returns: - pd.DataFrame: Formatted data from step plot - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) > 0: - if len(args) == 1: - y = np.asarray(args[0]) - x = np.arange(len(y)) - elif len(args) >= 2: - x = np.asarray(args[0]) - y = np.asarray(args[1]) - else: - return pd.DataFrame() - - # Use structured column naming: ax-row-{row}-col-{col}_trace-id-{id}_variable-{var} - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - df = pd.DataFrame({col_x: x, col_y: y}) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py deleted file mode 100755 index b84f7a93c..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_streamplot(id, tracked_dict, kwargs): - """Format data from a streamplot call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to streamplot - - Returns: - pd.DataFrame: Formatted data from streamplot (X, Y positions and U, V vectors) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - if "args" in tracked_dict: - args = tracked_dict["args"] - if isinstance(args, tuple) and len(args) >= 4: - # streamplot(X, Y, U, V) - X, Y are 1D, U, V are 2D - X = np.asarray(args[0]) - Y = np.asarray(args[1]) - U = np.asarray(args[2]) - V = np.asarray(args[3]) - - # Create meshgrid if X, Y are 1D - if X.ndim == 1 and Y.ndim == 1: - X, Y = np.meshgrid(X, Y) - - df = pd.DataFrame( - { - get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ): X.flatten(), - get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ): Y.flatten(), - get_csv_column_name( - "u", ax_row, ax_col, trace_id=trace_id - ): U.flatten(), - get_csv_column_name( - "v", ax_row, ax_col, trace_id=trace_id - ): V.flatten(), - } - ) - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_bar.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_bar.py deleted file mode 100755 index ff47be9a0..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_bar.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -"""CSV formatter for stx_bar() calls.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_stx_bar(id, tracked_dict, kwargs): - """Format data from stx_bar call for CSV export. - - Parameters - ---------- - id : str - Tracking identifier - tracked_dict : dict - Dictionary containing tracked data with 'bar_df' key - kwargs : dict - Additional keyword arguments (may contain yerr) - - Returns - ------- - pd.DataFrame - Formatted bar data with standardized column names - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get bar_df from tracked data - bar_df = tracked_dict.get("bar_df") - if bar_df is not None and isinstance(bar_df, pd.DataFrame): - result = bar_df.copy() - renamed = {} - # Map 'x' and 'height' to standardized column names - for col in result.columns: - if col == "x": - renamed[col] = get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ) - elif col == "height": - renamed[col] = get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ) - else: - renamed[col] = get_csv_column_name( - col, ax_row, ax_col, trace_id=trace_id - ) - - result = result.rename(columns=renamed) - - # Add yerr if present in kwargs - yerr = kwargs.get("yerr") if kwargs else None - if yerr is not None: - try: - yerr_array = np.asarray(yerr) - if len(yerr_array) == len(result): - col_yerr = get_csv_column_name( - "yerr", ax_row, ax_col, trace_id=trace_id - ) - result[col_yerr] = yerr_array - except (TypeError, ValueError): - pass - - return result - - # Fallback to args if bar_df not found - args = tracked_dict.get("args", []) - if len(args) >= 2: - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - data = {col_x: args[0], col_y: args[1]} - - # Add yerr if present - yerr = kwargs.get("yerr") if kwargs else None - if yerr is not None: - try: - yerr_array = np.asarray(yerr) - col_yerr = get_csv_column_name( - "yerr", ax_row, ax_col, trace_id=trace_id - ) - data[col_yerr] = yerr_array - except (TypeError, ValueError): - pass - - return pd.DataFrame(data) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_barh.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_barh.py deleted file mode 100755 index 5940f9536..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_barh.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -"""CSV formatter for stx_barh() calls.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_stx_barh(id, tracked_dict, kwargs): - """Format data from stx_barh call for CSV export. - - Parameters - ---------- - id : str - Tracking identifier - tracked_dict : dict - Dictionary containing tracked data with 'barh_df' key - kwargs : dict - Additional keyword arguments (may contain xerr) - - Returns - ------- - pd.DataFrame - Formatted barh data with standardized column names - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get barh_df from tracked data - barh_df = tracked_dict.get("barh_df") - if barh_df is not None and isinstance(barh_df, pd.DataFrame): - result = barh_df.copy() - renamed = {} - # Map 'y' and 'width' to standardized column names - for col in result.columns: - if col == "y": - renamed[col] = get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ) - elif col == "width": - renamed[col] = get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ) - else: - renamed[col] = get_csv_column_name( - col, ax_row, ax_col, trace_id=trace_id - ) - - result = result.rename(columns=renamed) - - # Add xerr if present in kwargs - xerr = kwargs.get("xerr") if kwargs else None - if xerr is not None: - try: - xerr_array = np.asarray(xerr) - if len(xerr_array) == len(result): - col_xerr = get_csv_column_name( - "xerr", ax_row, ax_col, trace_id=trace_id - ) - result[col_xerr] = xerr_array - except (TypeError, ValueError): - pass - - return result - - # Fallback to args if barh_df not found - args = tracked_dict.get("args", []) - if len(args) >= 2: - # Note: in barh, first arg is y positions, second is widths (x values) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - data = {col_y: args[0], col_x: args[1]} - - # Add xerr if present - xerr = kwargs.get("xerr") if kwargs else None - if xerr is not None: - try: - xerr_array = np.asarray(xerr) - col_xerr = get_csv_column_name( - "xerr", ax_row, ax_col, trace_id=trace_id - ) - data[col_xerr] = xerr_array - except (TypeError, ValueError): - pass - - return pd.DataFrame(data) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_conf_mat.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_conf_mat.py deleted file mode 100755 index df762e3af..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_conf_mat.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_conf_mat(id, tracked_dict, kwargs): - """Format data from a stx_conf_mat call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_conf_mat - - Returns: - pd.DataFrame: Formatted confusion matrix data - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Extract confusion matrix if available in args - if len(args) >= 1 and isinstance(args[0], (np.ndarray, list)): - conf_mat = np.array(args[0]) - - # Convert to DataFrame - if conf_mat.ndim == 2: - # Create column and index names - n_classes = conf_mat.shape[0] - columns = [f"Predicted_{i}" for i in range(n_classes)] - index = [f"True_{i}" for i in range(n_classes)] - - # Create DataFrame with proper labels - df = pd.DataFrame(conf_mat, columns=columns, index=index) - - # Reset index to make it a regular column - df = df.reset_index().rename(columns={"index": "True_Class"}) - - # Add prefix to all columns using single source of truth - df.columns = [ - get_csv_column_name( - f"conf-mat-{col}", ax_row, ax_col, trace_id=trace_id - ) - for col in df.columns - ] - - return df - - # Extract balanced accuracy if available as fallback - bacc = tracked_dict.get("balanced_accuracy") - - # Create DataFrame with the balanced accuracy - if bacc is not None: - col_name = get_csv_column_name( - "conf-mat-balanced-accuracy", ax_row, ax_col, trace_id=trace_id - ) - df = pd.DataFrame({col_name: [bacc]}) - return df - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_contour.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_contour.py deleted file mode 100755 index 614af02b8..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_contour.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -"""CSV formatter for stx_contour() calls.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_stx_contour(id, tracked_dict, kwargs): - """Format data from stx_contour call for CSV export. - - Parameters - ---------- - id : str - Identifier for the plot - tracked_dict : dict - Dictionary containing tracked data with 'contour_df' - kwargs : dict - Keyword arguments passed to stx_contour - - Returns - ------- - pd.DataFrame - Formatted contour data with X, Y, Z columns - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get contour_df from tracked_dict - contour_df = tracked_dict.get("contour_df") - if contour_df is not None and isinstance(contour_df, pd.DataFrame): - result = contour_df.copy() - - # Rename columns using single source of truth - renamed = {} - for col in result.columns: - if col == "X": - renamed[col] = get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ) - elif col == "Y": - renamed[col] = get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ) - elif col == "Z": - renamed[col] = get_csv_column_name( - "z", ax_row, ax_col, trace_id=trace_id - ) - else: - renamed[col] = get_csv_column_name( - col.lower(), ax_row, ax_col, trace_id=trace_id - ) - - return result.rename(columns=renamed) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_ecdf.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_ecdf.py deleted file mode 100755 index 5f67c488e..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_ecdf.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py -# ---------------------------------------- -import os - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - - -def _format_plot_ecdf(id, tracked_dict, kwargs): - """Format data from a stx_ecdf call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing 'ecdf_df' key with ECDF data - kwargs (dict): Keyword arguments passed to stx_ecdf - - Returns: - pd.DataFrame: Formatted ECDF data - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Get the ecdf_df from tracked_dict - ecdf_df = tracked_dict.get("ecdf_df") - - if ecdf_df is None or not isinstance(ecdf_df, pd.DataFrame): - return pd.DataFrame() - - # Create a copy to avoid modifying the original - result = ecdf_df.copy() - - # Add prefix to column names if ID is provided - if id is not None: - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Rename columns using single source of truth - renamed = {} - for col in result.columns: - # Use the original column name as the variable (e.g., "ecdf_value", "ecdf_prob") - renamed[col] = get_csv_column_name( - f"ecdf_{col}", ax_row, ax_col, trace_id=trace_id - ) - result = result.rename(columns=renamed) - - return result diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_errorbar.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_errorbar.py deleted file mode 100755 index 262241232..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_errorbar.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -"""CSV formatter for stx_errorbar() calls.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_stx_errorbar(id, tracked_dict, kwargs): - """Format data from stx_errorbar call for CSV export. - - Parameters - ---------- - id : str - Tracking identifier - tracked_dict : dict - Dictionary containing tracked data with 'errorbar_df' key - kwargs : dict - Additional keyword arguments (may contain yerr, xerr) - - Returns - ------- - pd.DataFrame - Formatted errorbar data with standardized column names - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get errorbar_df from tracked data - errorbar_df = tracked_dict.get("errorbar_df") - if errorbar_df is not None and isinstance(errorbar_df, pd.DataFrame): - result = errorbar_df.copy() - renamed = {} - - # Map columns to standardized names - for col in result.columns: - if col == "x": - renamed[col] = get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ) - elif col == "y": - renamed[col] = get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ) - elif col == "yerr": - # Check if yerr is asymmetric (tuple/list of 2) - yerr_value = result[col].iloc[0] if len(result) > 0 else None - if isinstance(yerr_value, (list, tuple)) and len(yerr_value) == 2: - # Handle asymmetric yerr separately below - continue - else: - renamed[col] = get_csv_column_name( - "yerr", ax_row, ax_col, trace_id=trace_id - ) - elif col == "xerr": - # Check if xerr is asymmetric (tuple/list of 2) - xerr_value = result[col].iloc[0] if len(result) > 0 else None - if isinstance(xerr_value, (list, tuple)) and len(xerr_value) == 2: - # Handle asymmetric xerr separately below - continue - else: - renamed[col] = get_csv_column_name( - "xerr", ax_row, ax_col, trace_id=trace_id - ) - else: - renamed[col] = get_csv_column_name( - col, ax_row, ax_col, trace_id=trace_id - ) - - result = result.rename(columns=renamed) - - # Handle asymmetric error bars if needed from kwargs - yerr = kwargs.get("yerr") if kwargs else None - xerr = kwargs.get("xerr") if kwargs else None - - if yerr is not None and isinstance(yerr, (list, tuple)) and len(yerr) == 2: - col_yerr_neg = get_csv_column_name( - "yerr-neg", ax_row, ax_col, trace_id=trace_id - ) - col_yerr_pos = get_csv_column_name( - "yerr-pos", ax_row, ax_col, trace_id=trace_id - ) - result[col_yerr_neg] = yerr[0] - result[col_yerr_pos] = yerr[1] - - if xerr is not None and isinstance(xerr, (list, tuple)) and len(xerr) == 2: - col_xerr_neg = get_csv_column_name( - "xerr-neg", ax_row, ax_col, trace_id=trace_id - ) - col_xerr_pos = get_csv_column_name( - "xerr-pos", ax_row, ax_col, trace_id=trace_id - ) - result[col_xerr_neg] = xerr[0] - result[col_xerr_pos] = xerr[1] - - return result - - # Fallback to args if errorbar_df not found - args = tracked_dict.get("args", []) - if len(args) >= 2: - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - data = {col_x: args[0], col_y: args[1]} - - # Add error bars if present - yerr = kwargs.get("yerr") if kwargs else None - xerr = kwargs.get("xerr") if kwargs else None - - if yerr is not None: - if isinstance(yerr, (list, tuple)) and len(yerr) == 2: - col_yerr_neg = get_csv_column_name( - "yerr-neg", ax_row, ax_col, trace_id=trace_id - ) - col_yerr_pos = get_csv_column_name( - "yerr-pos", ax_row, ax_col, trace_id=trace_id - ) - data[col_yerr_neg] = yerr[0] - data[col_yerr_pos] = yerr[1] - else: - col_yerr = get_csv_column_name( - "yerr", ax_row, ax_col, trace_id=trace_id - ) - data[col_yerr] = yerr - - if xerr is not None: - if isinstance(xerr, (list, tuple)) and len(xerr) == 2: - col_xerr_neg = get_csv_column_name( - "xerr-neg", ax_row, ax_col, trace_id=trace_id - ) - col_xerr_pos = get_csv_column_name( - "xerr-pos", ax_row, ax_col, trace_id=trace_id - ) - data[col_xerr_neg] = xerr[0] - data[col_xerr_pos] = xerr[1] - else: - col_xerr = get_csv_column_name( - "xerr", ax_row, ax_col, trace_id=trace_id - ) - data[col_xerr] = xerr - - return pd.DataFrame(data) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_fillv.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_fillv.py deleted file mode 100755 index 7141fa8d3..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_fillv.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 12:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_fillv.py - -"""CSV formatter for stx_fillv() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_fillv(id, tracked_dict, kwargs): - """Format data from a stx_fillv call. - - Formats data similar to line plot format for better compatibility. - Uses standard column naming convention: - (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_fillv - - Returns: - pd.DataFrame: Formatted fillv data in a long-format dataframe - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Try to get starts/ends directly from tracked_dict first - starts = tracked_dict.get("starts") - ends = tracked_dict.get("ends") - - # If not found, get from args - if starts is None or ends is None: - args = tracked_dict.get("args", []) - - # Extract data if available from args - if len(args) >= 2: - starts, ends = args[0], args[1] - - # If we have valid starts and ends, create a DataFrame in a format similar to line plot - if starts is not None and ends is not None: - # Convert to numpy arrays if they're lists for better handling - if isinstance(starts, list): - starts = np.array(starts) - if isinstance(ends, list): - ends = np.array(ends) - - # Get standard column names - x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - type_col = get_csv_column_name("type", ax_row, ax_col, trace_id=trace_id) - - # Create a DataFrame with x, y pairs for each fill span - rows = [] - for start, end in zip(starts, ends): - rows.append({x_col: start, y_col: 0, type_col: "start"}) - rows.append({x_col: end, y_col: 0, type_col: "end"}) - - if rows: - return pd.DataFrame(rows) - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_heatmap.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_heatmap.py deleted file mode 100755 index 7082145b0..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_heatmap.py +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_heatmap(id, tracked_dict, kwargs): - """Format data from a stx_heatmap call. - - Exports heatmap data in xyz format (x, y, value) for better compatibility. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_heatmap - - Returns: - pd.DataFrame: Formatted heatmap data in xyz format - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Extract data from tracked_dict - data = tracked_dict.get("data") - x_labels = tracked_dict.get("x_labels") - y_labels = tracked_dict.get("y_labels") - - if data is not None and hasattr(data, "shape") and len(data.shape) == 2: - rows, cols = data.shape - row_indices, col_indices = np.meshgrid(range(rows), range(cols), indexing="ij") - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Format data in xyz format (x, y, value) using single source of truth - df = pd.DataFrame( - { - get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ): col_indices.flatten(), # x is column - get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ): row_indices.flatten(), # y is row - get_csv_column_name( - "value", ax_row, ax_col, trace_id=trace_id - ): data.flatten(), # z is intensity/value - } - ) - - # Add label information if available - if x_labels is not None and len(x_labels) == cols: - # Map column indices to x labels (columns are x) - x_label_map = {i: label for i, label in enumerate(x_labels)} - x_col_name = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - x_label_col_name = get_csv_column_name( - "x_label", ax_row, ax_col, trace_id=trace_id - ) - df[x_label_col_name] = df[x_col_name].map(x_label_map) - - if y_labels is not None and len(y_labels) == rows: - # Map row indices to y labels (rows are y) - y_label_map = {i: label for i, label in enumerate(y_labels)} - y_col_name = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - y_label_col_name = get_csv_column_name( - "y_label", ax_row, ax_col, trace_id=trace_id - ) - df[y_label_col_name] = df[y_col_name].map(y_label_map) - - return df - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_image.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_image.py deleted file mode 100755 index fb88190c2..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_image.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py -# ---------------------------------------- -import os - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - - -def _format_plot_image(id, tracked_dict, kwargs): - """Format data from a stx_image call. - - Exports image data in long-format xyz format for better compatibility. - Also saves channel data for RGB/RGBA images. - - Args: - id (str or int): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_image - - Returns: - pd.DataFrame: Formatted image data in xyz format - """ - # Check if tracked_dict is not a dictionary or is empty - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse the tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Check if image_df is available and use it if present - if "image_df" in tracked_dict: - image_df = tracked_dict.get("image_df") - if isinstance(image_df, pd.DataFrame): - # Add prefix if ID is provided - if id is not None: - image_df = image_df.copy() - # Rename columns using single source of truth - renamed = {} - for col in image_df.columns: - # Convert to string to handle integer column names - col_str = str(col) - renamed[col] = get_csv_column_name( - col_str, ax_row, ax_col, trace_id=trace_id - ) - image_df = image_df.rename(columns=renamed) - return image_df - - # If we have image data - if "image" in tracked_dict: - img = tracked_dict["image"] - - # Handle 2D grayscale images - create xyz format (x, y, value) - if isinstance(img, np.ndarray) and img.ndim == 2: - rows, cols = img.shape - row_indices, col_indices = np.meshgrid( - range(rows), range(cols), indexing="ij" - ) - - # Create xyz format using single source of truth - df = pd.DataFrame( - { - get_csv_column_name( - "x", ax_row, ax_col, trace_id=trace_id - ): col_indices.flatten(), # x is column - get_csv_column_name( - "y", ax_row, ax_col, trace_id=trace_id - ): row_indices.flatten(), # y is row - get_csv_column_name( - "value", ax_row, ax_col, trace_id=trace_id - ): img.flatten(), # z is intensity - } - ) - return df - - # Handle RGB/RGBA images - create xyz format with additional channel information - elif isinstance(img, np.ndarray) and img.ndim == 3: - rows, cols, channels = img.shape - - # Create a list to hold rows for a long-format DataFrame - data_rows = [] - channel_names = ["r", "g", "b", "a"] - - # Get column names using single source of truth - x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - channel_col = get_csv_column_name( - "channel", ax_row, ax_col, trace_id=trace_id - ) - value_col = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) - - # Create long-format data (x, y, channel, value) - for r in range(rows): - for c in range(cols): - for ch in range(min(channels, len(channel_names))): - data_rows.append( - { - x_col: c, # x is column - y_col: r, # y is row - channel_col: channel_names[ch], # channel name - value_col: img[r, c, ch], # channel value - } - ) - - # Return long-format DataFrame - return pd.DataFrame(data_rows) - - # Skip CSV export if no suitable data format found - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_imshow.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_imshow.py deleted file mode 100755 index d43bbadd6..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_imshow.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -"""CSV formatter for stx_imshow() calls.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_stx_imshow(id, tracked_dict, kwargs): - """Format data from stx_imshow call for CSV export. - - Parameters - ---------- - id : str - Identifier for the plot - tracked_dict : dict - Dictionary containing tracked data with 'imshow_df' - kwargs : dict - Keyword arguments passed to stx_imshow - - Returns - ------- - pd.DataFrame - Formatted imshow data in row, col, value format (or row, col, R, G, B for RGB) - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get imshow_df from tracked_dict - imshow_df = tracked_dict.get("imshow_df") - if imshow_df is not None and isinstance(imshow_df, pd.DataFrame): - # Convert from 2D DataFrame format (with col_0, col_1, ... columns) - # to row, col, value format for easier analysis - n_rows, n_cols = imshow_df.shape - - # Create row and column indices - row_indices = np.repeat(np.arange(n_rows), n_cols) - col_indices = np.tile(np.arange(n_cols), n_rows) - - # Get column names from single source of truth - col_row = get_csv_column_name("row", ax_row, ax_col, trace_id=trace_id) - col_col = get_csv_column_name("col", ax_row, ax_col, trace_id=trace_id) - col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) - - # Flatten the DataFrame values - values = imshow_df.values.flatten() - - result = pd.DataFrame( - {col_row: row_indices, col_col: col_indices, col_value: values} - ) - - return result - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_joyplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_joyplot.py deleted file mode 100755 index 9cbe83928..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_joyplot.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex.pd import force_df -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_joyplot(id, tracked_dict, kwargs): - """Format data from a stx_joyplot call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing 'joyplot_data' key with joyplot data - kwargs (dict): Keyword arguments passed to stx_joyplot - - Returns: - pd.DataFrame: Formatted joyplot data - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get joyplot_data from tracked_dict - data = tracked_dict.get("joyplot_data") - - if data is None: - return pd.DataFrame() - - # Handle different data types - if isinstance(data, pd.DataFrame): - # Make a copy to avoid modifying original - result = data.copy() - # Add prefix to column names using single source of truth - if id is not None: - result.columns = [ - get_csv_column_name(f"joyplot-{col}", ax_row, ax_col, trace_id=trace_id) - for col in result.columns - ] - return result - - elif isinstance(data, dict): - # Convert dictionary to DataFrame - result = pd.DataFrame() - for group, values in data.items(): - col_name = get_csv_column_name( - f"joyplot-{group}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = pd.Series(values) - return result - - elif isinstance(data, (list, tuple)) and all( - isinstance(x, (np.ndarray, list)) for x in data - ): - # Convert list of arrays to DataFrame - result = pd.DataFrame() - for i, values in enumerate(data): - col_name = get_csv_column_name( - f"joyplot-group{i:02d}", ax_row, ax_col, trace_id=trace_id - ) - result[col_name] = pd.Series(values) - return result - - # Try to force to DataFrame as a last resort - try: - col_name = get_csv_column_name( - "joyplot-data", ax_row, ax_col, trace_id=trace_id - ) - return force_df({col_name: data}) - except: - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_line.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_line.py deleted file mode 100755 index d9cf5b8ce..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_line.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 02:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_line.py - -"""CSV formatter for stx_line() calls - uses standard column naming.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_line(id, tracked_dict, kwargs): - """Format data from a stx_line call. - - Processes stx_line data for CSV export using standard column naming - (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing 'plot_df' key with plot data - kwargs (dict): Keyword arguments passed to stx_line - - Returns: - pd.DataFrame: Formatted line plot data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Get the plot_df from tracked_dict - plot_df = tracked_dict.get("plot_df") - - if plot_df is None or not isinstance(plot_df, pd.DataFrame): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Create a copy to avoid modifying the original - result = plot_df.copy() - - # Rename columns using standard naming convention - renamed = {} - for col in result.columns: - if col == "x": - renamed[col] = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - elif col == "y": - renamed[col] = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - else: - renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) - - return result.rename(columns=renamed) diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_ci.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_ci.py deleted file mode 100755 index 097ac684f..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_ci.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 02:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_ci.py - -"""CSV formatter for stx_mean_ci() calls - uses standard column naming.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_mean_ci(id, tracked_dict, kwargs): - """Format data from a stx_mean_ci call. - - Processes mean with confidence interval band plot data for CSV export using - standard column naming (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Contains 'plot_df' (pandas DataFrame with mean and CI data) - kwargs (dict): Keyword arguments passed to stx_mean_ci - - Returns: - pd.DataFrame: Formatted mean and CI data with standard column names - """ - # Mean-CI plot data is passed in the tracked_dict - if not tracked_dict: - return pd.DataFrame() - - # Get the plot_df from tracked_dict - plot_df = tracked_dict.get("plot_df") - - if plot_df is None or not isinstance(plot_df, pd.DataFrame): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Create a copy to avoid modifying the original - result = plot_df.copy() - - # Rename columns using standard naming convention - renamed = {} - for col in result.columns: - renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) - - return result.rename(columns=renamed) diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_std.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_std.py deleted file mode 100755 index 4e1380866..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_std.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 02:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_std.py - -"""CSV formatter for stx_mean_std() calls - uses standard column naming.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_mean_std(id, tracked_dict, kwargs): - """Format data from a stx_mean_std call. - - Processes mean with standard deviation band plot data for CSV export using - standard column naming (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing 'plot_df' key with mean and std data - kwargs (dict): Keyword arguments passed to stx_mean_std - - Returns: - pd.DataFrame: Formatted mean and std data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Get the plot_df from tracked_dict - plot_df = tracked_dict.get("plot_df") - - if plot_df is None or not isinstance(plot_df, pd.DataFrame): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Create a copy to avoid modifying the original - result = plot_df.copy() - - # Rename columns using standard naming convention - renamed = {} - for col in result.columns: - renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) - - return result.rename(columns=renamed) diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_median_iqr.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_median_iqr.py deleted file mode 100755 index a160b6dc4..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_median_iqr.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-13 02:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_median_iqr.py - -"""CSV formatter for stx_median_iqr() calls - uses standard column naming.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_median_iqr(id, tracked_dict, kwargs): - """Format data from a stx_median_iqr call. - - Processes median with interquartile range band plot data for CSV export using - standard column naming (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Contains 'plot_df' (pandas DataFrame with median and IQR data) - kwargs (dict): Keyword arguments passed to stx_median_iqr - - Returns: - pd.DataFrame: Formatted median and IQR data with standard column names - """ - # Median-IQR plot data is passed in the tracked_dict - if not tracked_dict: - return pd.DataFrame() - - # Get the plot_df from tracked_dict - plot_df = tracked_dict.get("plot_df") - - if plot_df is None or not isinstance(plot_df, pd.DataFrame): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Create a copy to avoid modifying the original - result = plot_df.copy() - - # Rename columns using standard naming convention - renamed = {} - for col in result.columns: - renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) - - return result.rename(columns=renamed) diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_raster.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_raster.py deleted file mode 100755 index 903b822d0..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_raster.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_raster(id, tracked_dict, kwargs): - """Format data from a stx_raster call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing 'raster_digit_df' key with raster plot data - kwargs (dict): Keyword arguments passed to stx_raster - - Returns: - pd.DataFrame: Formatted raster plot data - """ - # Check if args is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get the raster_digit_df from args - raster_df = tracked_dict.get("raster_digit_df") - - if raster_df is None or not isinstance(raster_df, pd.DataFrame): - return pd.DataFrame() - - # Create a copy to avoid modifying the original - result = raster_df.copy() - - # Add prefix to column names using single source of truth - if id is not None: - # Rename columns with ID prefix - result.columns = [ - get_csv_column_name(f"raster-{col}", ax_row, ax_col, trace_id=trace_id) - for col in result.columns - ] - - return result diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_rectangle.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_rectangle.py deleted file mode 100755 index 0fe1fbac2..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_rectangle.py +++ /dev/null @@ -1,129 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 12:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_rectangle.py - -"""CSV formatter for stx_rectangle() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_rectangle(id, tracked_dict, kwargs): - """Format data from a stx_rectangle call. - - Uses standard column naming convention: - (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_rectangle - - Returns: - pd.DataFrame: Formatted rectangle data - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get standard column names - x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - width_col = get_csv_column_name("width", ax_row, ax_col, trace_id=trace_id) - height_col = get_csv_column_name("height", ax_row, ax_col, trace_id=trace_id) - - # Try to get rectangle parameters directly from tracked_dict - x = tracked_dict.get("x") - y = tracked_dict.get("y") - width = tracked_dict.get("width") - height = tracked_dict.get("height") - - # If direct parameters aren't available, try the args - if any(param is None for param in [x, y, width, height]): - args = tracked_dict.get("args", []) - - # Rectangles defined by [x, y, width, height] - if len(args) >= 4: - x, y, width, height = args[0], args[1], args[2], args[3] - - # If we have all required parameters, create the DataFrame - if all(param is not None for param in [x, y, width, height]): - try: - # Handle single rectangle - if all( - isinstance(val, (int, float, np.number)) - for val in [x, y, width, height] - ): - return pd.DataFrame( - { - x_col: [x], - y_col: [y], - width_col: [width], - height_col: [height], - } - ) - - # Handle multiple rectangles (arrays) - elif all( - isinstance(val, (np.ndarray, list)) for val in [x, y, width, height] - ): - try: - return pd.DataFrame( - { - x_col: x, - y_col: y, - width_col: width, - height_col: height, - } - ) - except ValueError: - # Handle case where arrays might be different lengths - result = pd.DataFrame() - result[x_col] = pd.Series(x) - result[y_col] = pd.Series(y) - result[width_col] = pd.Series(width) - result[height_col] = pd.Series(height) - return result - except Exception: - # Fallback for rectangle in case of any errors - try: - return pd.DataFrame( - { - x_col: [float(x) if x is not None else 0], - y_col: [float(y) if y is not None else 0], - width_col: [float(width) if width is not None else 0], - height_col: [float(height) if height is not None else 0], - } - ) - except (TypeError, ValueError): - pass - - # Check directly in the kwargs for the parameters - rect_x = kwargs.get("x") - rect_y = kwargs.get("y") - rect_w = kwargs.get("width") - rect_h = kwargs.get("height") - - if all(param is not None for param in [rect_x, rect_y, rect_w, rect_h]): - try: - return pd.DataFrame( - { - x_col: [float(rect_x)], - y_col: [float(rect_y)], - width_col: [float(rect_w)], - height_col: [float(rect_h)], - } - ) - except (TypeError, ValueError): - pass - - # Default empty DataFrame if nothing could be processed - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter.py deleted file mode 100755 index 5fd6abc9c..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -"""CSV formatter for stx_scatter() calls.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_stx_scatter(id, tracked_dict, kwargs): - """Format data from stx_scatter call for CSV export. - - Parameters - ---------- - id : str - Tracking identifier - tracked_dict : dict - Dictionary containing tracked data with 'scatter_df' key - kwargs : dict - Additional keyword arguments (unused) - - Returns - ------- - pd.DataFrame - Formatted scatter data with standardized column names - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get scatter_df from tracked data - scatter_df = tracked_dict.get("scatter_df") - if scatter_df is not None and isinstance(scatter_df, pd.DataFrame): - result = scatter_df.copy() - renamed = {} - for col in result.columns: - renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) - return result.rename(columns=renamed) - - # Fallback to args if scatter_df not found - args = tracked_dict.get("args", []) - if len(args) >= 2: - col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - return pd.DataFrame({col_x: args[0], col_y: args[1]}) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter_hist.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter_hist.py deleted file mode 100755 index f5768126e..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter_hist.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_scatter_hist(id, tracked_dict, kwargs): - """Format data from a stx_scatter_hist call. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_scatter_hist - - Returns: - pd.DataFrame: Formatted scatter histogram data - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to extract axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Extract data from tracked_dict - x = tracked_dict.get("x") - y = tracked_dict.get("y") - - if x is not None and y is not None: - # Create base DataFrame with x and y values - df = pd.DataFrame( - { - get_csv_column_name( - "scatter_hist_x", ax_row, ax_col, trace_id=trace_id - ): x, - get_csv_column_name( - "scatter_hist_y", ax_row, ax_col, trace_id=trace_id - ): y, - } - ) - - # Add histogram data if available - hist_x = tracked_dict.get("hist_x") - hist_y = tracked_dict.get("hist_y") - bin_edges_x = tracked_dict.get("bin_edges_x") - bin_edges_y = tracked_dict.get("bin_edges_y") - - # If we have histogram data - if hist_x is not None and bin_edges_x is not None: - # Calculate bin centers for x-axis histogram - bin_centers_x = 0.5 * (bin_edges_x[1:] + bin_edges_x[:-1]) - - # Create a DataFrame for x histogram data - hist_x_df = pd.DataFrame( - { - get_csv_column_name( - "hist_x_bin_centers", ax_row, ax_col, trace_id=trace_id - ): bin_centers_x, - get_csv_column_name( - "hist_x_counts", ax_row, ax_col, trace_id=trace_id - ): hist_x, - } - ) - - # Add it to the main DataFrame using a MultiIndex - for i, (center, count) in enumerate(zip(bin_centers_x, hist_x)): - df.loc[ - f"hist_x_{i}", - get_csv_column_name( - "hist_x_bin", ax_row, ax_col, trace_id=trace_id - ), - ] = center - df.loc[ - f"hist_x_{i}", - get_csv_column_name( - "hist_x_count", ax_row, ax_col, trace_id=trace_id - ), - ] = count - - # If we have y histogram data - if hist_y is not None and bin_edges_y is not None: - # Calculate bin centers for y-axis histogram - bin_centers_y = 0.5 * (bin_edges_y[1:] + bin_edges_y[:-1]) - - # Create a DataFrame for y histogram data - hist_y_df = pd.DataFrame( - { - get_csv_column_name( - "hist_y_bin_centers", ax_row, ax_col, trace_id=trace_id - ): bin_centers_y, - get_csv_column_name( - "hist_y_counts", ax_row, ax_col, trace_id=trace_id - ): hist_y, - } - ) - - # Add it to the main DataFrame using a MultiIndex - for i, (center, count) in enumerate(zip(bin_centers_y, hist_y)): - df.loc[ - f"hist_y_{i}", - get_csv_column_name( - "hist_y_bin", ax_row, ax_col, trace_id=trace_id - ), - ] = center - df.loc[ - f"hist_y_{i}", - get_csv_column_name( - "hist_y_count", ax_row, ax_col, trace_id=trace_id - ), - ] = count - - return df - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_shaded_line.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_shaded_line.py deleted file mode 100755 index ec26664ac..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_shaded_line.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 03:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_shaded_line.py - -"""CSV formatter for stx_shaded_line() calls - uses standard column naming.""" - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_shaded_line(id, tracked_dict, kwargs): - """Format data from a stx_shaded_line call. - - Processes stx_shaded_line data for CSV export using standard column naming - (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_shaded_line - - Returns: - pd.DataFrame: Formatted shaded line data with standard column names - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # If we have a plot_df from plotting methods, use that directly - if "plot_df" in tracked_dict and isinstance(tracked_dict["plot_df"], pd.DataFrame): - plot_df = tracked_dict["plot_df"] - # Rename columns using standard naming convention - renamed = {} - for col in plot_df.columns: - renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) - return plot_df.rename(columns=renamed) - - # Try getting the individual components - x = tracked_dict.get("x") - y_middle = tracked_dict.get("y_middle") - y_lower = tracked_dict.get("y_lower") - y_upper = tracked_dict.get("y_upper") - - # If we have all necessary components - if x is not None and y_middle is not None and y_lower is not None: - x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - y_col = get_csv_column_name("y-middle", ax_row, ax_col, trace_id=trace_id) - lower_col = get_csv_column_name("y-lower", ax_row, ax_col, trace_id=trace_id) - upper_col = get_csv_column_name("y-upper", ax_row, ax_col, trace_id=trace_id) - - data = { - x_col: x, - y_col: y_middle, - lower_col: y_lower, - } - - if y_upper is not None: - data[upper_col] = y_upper - else: - # If only y_lower is provided, assume it's symmetric around y_middle - data[upper_col] = y_middle + (y_middle - y_lower) - - return pd.DataFrame(data) - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_violin.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_violin.py deleted file mode 100755 index 4a6b11dcb..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_violin.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 12:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_violin.py - -"""CSV formatter for stx_violin() calls - uses standard column naming.""" - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_plot_violin(id, tracked_dict, kwargs): - """Format data from a stx_violin call. - - Formats data in a long-format for better compatibility. - Uses standard column naming convention: - (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to stx_violin - - Returns: - pd.DataFrame: Formatted violin plot data in long format - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get standard column names - group_col = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) - value_col = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) - - # Extract data from tracked_dict - data = tracked_dict.get("data") - - if data is not None: - # If data is a simple array or list - if isinstance(data, (np.ndarray, list)) and not isinstance( - data[0], (list, np.ndarray, dict) - ): - # Convert to long format with group and value columns - rows = [{group_col: "0", value_col: val} for val in data] - return pd.DataFrame(rows) - - # If data is a list of arrays (multiple violin plots) - elif isinstance(data, (list, tuple)) and all( - isinstance(x, (list, np.ndarray)) for x in data - ): - # Get labels if available - labels = tracked_dict.get("labels") - - # Convert to long format - rows = [] - for i, values in enumerate(data): - # Use label if available, otherwise use index - group = labels[i] if labels and i < len(labels) else f"group{i:02d}" - for val in values: - rows.append({group_col: str(group), value_col: val}) - - if rows: - return pd.DataFrame(rows) - - # If data is a dictionary - elif isinstance(data, dict): - # Convert to long format - rows = [] - for group, values in data.items(): - for val in values: - rows.append({group_col: str(group), value_col: val}) - - if rows: - return pd.DataFrame(rows) - - # If data is a DataFrame - elif isinstance(data, pd.DataFrame): - # For DataFrame data with x and y columns - x = tracked_dict.get("x") - y = tracked_dict.get("y") - - if ( - x is not None - and y is not None - and x in data.columns - and y in data.columns - ): - # Convert to long format - rows = [] - for group_name, group_data in data.groupby(x): - for val in group_data[y]: - rows.append({group_col: str(group_name), value_col: val}) - - if rows: - return pd.DataFrame(rows) - else: - # For other dataframes, melt to long format - try: - # Try to melt to long format - result = pd.melt(data) - # Rename columns using standard naming - result.columns = [group_col, value_col] - return result - except Exception: - # If melt fails, just return empty - pass - - return pd.DataFrame() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py deleted file mode 100755 index 3eaecddd1..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-10 12:00:00 (ywatanabe)" -# File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py - -"""CSV formatter for text() calls - uses standard column naming.""" - -from __future__ import annotations - -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_text(id, tracked_dict, kwargs): - """Format data from a text call. - - Uses standard column naming convention: - (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to text - - Returns: - pd.DataFrame: Formatted text position data - """ - # Check if tracked_dict is empty or not a dictionary - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - # Get standard column names - x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) - y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) - content_col = get_csv_column_name("content", ax_row, ax_col, trace_id=trace_id) - - # Get the args from tracked_dict - args = tracked_dict.get("args", []) - - # Extract x, y, and text content if available - if len(args) >= 2: - x, y = args[0], args[1] - text_content = args[2] if len(args) >= 3 else None - - data = {x_col: [x], y_col: [y]} - - if text_content is not None: - data[content_col] = [text_content] - - return pd.DataFrame(data) - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py deleted file mode 100755 index bbf0dc28c..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_violin(id, tracked_dict, kwargs): - """Format data from a violin call. - - Formats data in a long-format for better compatibility. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to violin plot - - Returns: - pd.DataFrame: Formatted violin data in long format - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", []) - - if len(args) >= 1: - data = args[0] - - # Handle case when data is a simple array or list - if isinstance(data, (list, np.ndarray)) and not isinstance( - data[0], (list, np.ndarray, dict) - ): - rows = [{"group": "0", "value": val} for val in data] - df = pd.DataFrame(rows) - col_group = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) - col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) - df.columns = [col_group, col_value] - return df - - # Handle case when data is a dictionary - elif isinstance(data, dict): - rows = [] - for group, values in data.items(): - for val in values: - rows.append({"group": str(group), "value": val}) - - if rows: - df = pd.DataFrame(rows) - col_group = get_csv_column_name( - "group", ax_row, ax_col, trace_id=trace_id - ) - col_value = get_csv_column_name( - "value", ax_row, ax_col, trace_id=trace_id - ) - df.columns = [col_group, col_value] - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py deleted file mode 100755 index 5732728f7..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py - -import numpy as np -import pandas as pd - -from scitex.plt.utils._csv_column_naming import get_csv_column_name - -from ._format_plot import _parse_tracking_id - - -def _format_violinplot(id, tracked_dict, kwargs): - """Format data from a violinplot call. - - Formats data in a long-format for better compatibility. - - Args: - id (str): Identifier for the plot - tracked_dict (dict): Dictionary containing tracked data - kwargs (dict): Keyword arguments passed to violinplot - - Returns: - pd.DataFrame: Formatted violinplot data in long format - """ - if not tracked_dict or not isinstance(tracked_dict, dict): - return pd.DataFrame() - - # Parse tracking ID to get axes position and trace ID - ax_row, ax_col, trace_id = _parse_tracking_id(id) - - args = tracked_dict.get("args", []) - - if len(args) >= 1: - data = args[0] - - # Handle case when data is a simple array or list - if isinstance(data, (list, np.ndarray)) and not isinstance( - data[0], (list, np.ndarray, dict) - ): - rows = [{"group": "0", "value": val} for val in data] - df = pd.DataFrame(rows) - # Use structured column naming - col_group = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) - col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) - df.columns = [col_group, col_value] - return df - - # Handle case when data is a dictionary - elif isinstance(data, dict): - rows = [] - for group, values in data.items(): - for val in values: - rows.append({"group": str(group), "value": val}) - - if rows: - df = pd.DataFrame(rows) - col_group = get_csv_column_name( - "group", ax_row, ax_col, trace_id=trace_id - ) - col_value = get_csv_column_name( - "value", ax_row, ax_col, trace_id=trace_id - ) - df.columns = [col_group, col_value] - return df - - # Handle case when data is a list of arrays - elif isinstance(data, (list, tuple)) and all( - isinstance(x, (list, np.ndarray)) for x in data - ): - rows = [] - for i, values in enumerate(data): - for val in values: - rows.append({"group": str(i), "value": val}) - - if rows: - df = pd.DataFrame(rows) - col_group = get_csv_column_name( - "group", ax_row, ax_col, trace_id=trace_id - ) - col_value = get_csv_column_name( - "value", ax_row, ax_col, trace_id=trace_id - ) - df.columns = [col_group, col_value] - return df - - return pd.DataFrame() - - -# EOF diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py deleted file mode 100755 index 55a45a600..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py +++ /dev/null @@ -1,207 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 22:05:10 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import unittest - -import numpy as np -import pandas as pd - -# Import formatters directly -from ._format_plot import _format_plot -from ._format_plot_ecdf import _format_plot_ecdf -from ._format_plot_heatmap import _format_plot_heatmap -from ._format_plot_kde import _format_plot_kde -from ._format_plot_scatter_hist import _format_plot_scatter_hist -from ._format_plot_shaded_line import _format_plot_shaded_line -from ._format_plot_violin import _format_plot_violin - - -class FormattersTest(unittest.TestCase): - """Test the formatter functions.""" - - def test_format_plot_kde(self): - """Test _format_plot_kde function.""" - # Test case 1: Normal input - tracked_dict = { - "x": np.linspace(-3, 3, 100), - "kde": np.exp(-np.linspace(-3, 3, 100) ** 2 / 2), - "n": 500, - } - id = "test_kde" - df = _format_plot_kde(id, tracked_dict, {}) - - # Verify columns - self.assertIn(f"{id}_kde_x", df.columns) - self.assertIn(f"{id}_kde_density", df.columns) - self.assertIn(f"{id}_kde_n", df.columns) - - # Test case 2: Empty tracked_dict - df = _format_plot_kde(id, {}, {}) - self.assertTrue(df.empty) - - # Test case 3: Missing 'x' key - tracked_dict = {"kde": np.exp(-np.linspace(-3, 3, 100) ** 2 / 2)} - df = _format_plot_kde(id, tracked_dict, {}) - self.assertTrue(df.empty) - - def test_format_plot(self): - """Test _format_plot function.""" - # Test case 1: Normal input - tracked_dict = { - "plot_df": pd.DataFrame( - {"x": np.linspace(0, 10, 100), "y": np.sin(np.linspace(0, 10, 100))} - ) - } - id = "test_plot" - df = _format_plot(id, tracked_dict, {}) - - # Verify it returned the DataFrame with added prefix - self.assertFalse(df.empty) - - # Test case 2: Empty tracked_dict - df = _format_plot(id, {}, {}) - self.assertTrue(df.empty) - - def test_format_plot_ecdf(self): - """Test _format_plot_ecdf function.""" - # Test case 1: Normal input - tracked_dict = { - "ecdf_df": pd.DataFrame( - {"x": np.linspace(-3, 3, 100), "ecdf": np.linspace(0, 1, 100)} - ) - } - id = "test_ecdf" - df = _format_plot_ecdf(id, tracked_dict, {}) - - # Verify it returned the DataFrame - self.assertFalse(df.empty) - - # Test case 2: Empty tracked_dict - df = _format_plot_ecdf(id, {}, {}) - self.assertTrue(df.empty) - - def test_format_plot_heatmap(self): - """Test _format_plot_heatmap function.""" - # Test case 1: Normal input with labels - data = np.random.rand(3, 4) - x_labels = ["A", "B", "C"] - y_labels = ["W", "X", "Y", "Z"] - - tracked_dict = {"data": data, "x_labels": x_labels, "y_labels": y_labels} - id = "test_heatmap" - df = _format_plot_heatmap(id, tracked_dict, {}) - - # Verify it returned the DataFrame with the expected shape - self.assertFalse(df.empty) - self.assertEqual(df.shape[0], 12) # 3 rows * 4 columns = 12 cells - # We should have 5 columns: row, col, value, row_label, col_label - self.assertEqual(df.shape[1], 5) - - # Test case 2: No labels - tracked_dict = {"data": data} - df = _format_plot_heatmap(id, tracked_dict, {}) - self.assertFalse(df.empty) - - # Test case 3: Empty tracked_dict - df = _format_plot_heatmap(id, {}, {}) - self.assertTrue(df.empty) - - def test_format_plot_violin(self): - """Test _format_plot_violin function.""" - # Test case 1: List data - data = [np.random.normal(0, 1, 100), np.random.normal(2, 1, 100)] - labels = ["Group A", "Group B"] - - tracked_dict = {"data": data, "labels": labels} - id = "test_violin" - df = _format_plot_violin(id, tracked_dict, {}) - - # Verify it returned the DataFrame - self.assertFalse(df.empty) - - # Test case 2: DataFrame data - data_df = pd.DataFrame( - { - "values": np.concatenate( - [np.random.normal(0, 1, 100), np.random.normal(2, 1, 100)] - ), - "group": ["A"] * 100 + ["B"] * 100, - } - ) - tracked_dict = {"data": data_df, "x": "group", "y": "values"} - df = _format_plot_violin(id, tracked_dict, {}) - self.assertFalse(df.empty) - - # Test case 3: Empty tracked_dict - df = _format_plot_violin(id, {}, {}) - self.assertTrue(df.empty) - - def test_format_plot_shaded_line(self): - """Test _format_plot_shaded_line function.""" - # Test case 1: Normal input - tracked_dict = { - "plot_df": pd.DataFrame( - { - "x": np.linspace(0, 10, 100), - "y_lower": np.sin(np.linspace(0, 10, 100)) - 0.2, - "y_middle": np.sin(np.linspace(0, 10, 100)), - "y_upper": np.sin(np.linspace(0, 10, 100)) + 0.2, - } - ) - } - id = "test_shaded" - df = _format_plot_shaded_line(id, tracked_dict, {}) - - # Verify it returned the DataFrame - self.assertFalse(df.empty) - - # Test case 2: Empty tracked_dict - df = _format_plot_shaded_line(id, {}, {}) - self.assertTrue(df.empty) - - def test_format_plot_scatter_hist(self): - """Test _format_plot_scatter_hist function.""" - # Test case 1: Normal input - tracked_dict = { - "x": np.random.normal(0, 1, 100), - "y": np.random.normal(0, 1, 100), - "hist_x": np.random.rand(10), - "hist_y": np.random.rand(10), - "bin_edges_x": np.linspace(-3, 3, 11), - "bin_edges_y": np.linspace(-3, 3, 11), - } - id = "test_scatter_hist" - df = _format_plot_scatter_hist(id, tracked_dict, {}) - - # Verify it returned the DataFrame with expected columns - self.assertFalse(df.empty) - self.assertTrue( - any(col.startswith(f"{id}_scatter_hist_x") for col in df.columns) - ) - self.assertTrue( - any(col.startswith(f"{id}_scatter_hist_y") for col in df.columns) - ) - - # Test case 2: Missing keys - tracked_dict = { - "x": np.random.normal(0, 1, 100), - "y": np.random.normal(0, 1, 100), - } - df = _format_plot_scatter_hist(id, tracked_dict, {}) - self.assertFalse(df.empty) # Should still work with just x,y - - # Test case 3: Empty tracked_dict - df = _format_plot_scatter_hist(id, {}, {}) - self.assertTrue(df.empty) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py b/src/scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py deleted file mode 100755 index 8b9807c23..000000000 --- a/src/scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py +++ /dev/null @@ -1,360 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 23:14:10 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py -# ---------------------------------------- -import os -import sys - -import matplotlib -import numpy as np -import pandas as pd - -matplotlib.use("Agg") # Non-interactive backend - -# Add src to path if needed -src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../../")) -if src_path not in sys.path: - sys.path.insert(0, src_path) - -import scitex - -# Create output directory -OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "formatter_test_output") -os.makedirs(OUTPUT_DIR, exist_ok=True) - - -def test_all_formatters(): - """ - Test all formatters by creating actual plots and saving both image and CSV files. - Each function will create a different type of plot, save it, and verify the CSV export. - """ - # Test each formatter with a real plot - test_plot_kde() - test_plot_image() - test_plot_shaded_line() - test_plot_scatter_hist() - test_plot_violin() - test_plot_heatmap() - test_plot_ecdf() - test_multiple_plots() - - -def test_plot_kde(): - """Test KDE plotting and CSV export.""" - print("Testing stx_kde...") - - # Create figure - fig, ax = scitex.plt.subplots() - - # Generate data - np.random.seed(42) # For reproducibility - data = np.concatenate([np.random.normal(0, 1, 500), np.random.normal(5, 1, 300)]) - - # Plot with ID for tracking - ax.stx_kde(data, label="Bimodal Distribution", id="kde_test") - - # Style the plot - ax.set_xyt("Value", "Density", "KDE Test") - ax.legend() - - # Save both image and data - save_path = os.path.join(OUTPUT_DIR, "kde_test.png") - scitex.io.save(fig, save_path) - - # Verify CSV was created - csv_path = save_path.replace(".png", ".csv") - assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" - - # Read CSV and verify contents - df = pd.read_csv(csv_path) - assert "kde_test_kde_x" in df.columns, "Expected column 'kde_test_kde_x' not found" - assert ( - "kde_test_kde_density" in df.columns - ), "Expected column 'kde_test_kde_density' not found" - - # Close figure - scitex.plt.close(fig) - print("✓ stx_kde test successful") - - -def test_plot_image(): - """Test image plotting and CSV export.""" - print("Testing stx_image...") - - # Create figure - fig, ax = scitex.plt.subplots() - - # Generate data - np.random.seed(42) # For reproducibility - data = np.random.rand(20, 20) - - # Plot with ID for tracking - ax.stx_image(data, cmap="viridis", id="image_test") - - # Style the plot - ax.set_xyt("X", "Y", "Image Test") - - # Save both image and data - save_path = os.path.join(OUTPUT_DIR, "image_test.png") - scitex.io.save(fig, save_path) - - # Verify CSV was created - csv_path = save_path.replace(".png", ".csv") - assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" - - # Read CSV and verify contents - df = pd.read_csv(csv_path) - # The formatter should have converted the 2D array to a DataFrame - assert not df.empty, "CSV file is empty" - - # Close figure - scitex.plt.close(fig) - print("✓ stx_image test successful") - - -def test_plot_shaded_line(): - """Test shaded line plotting and CSV export.""" - print("Testing stx_shaded_line...") - - # Create figure - fig, ax = scitex.plt.subplots() - - # Generate data - np.random.seed(42) # For reproducibility - x = np.linspace(0, 10, 100) - y_middle = np.sin(x) - y_lower = y_middle - 0.2 - y_upper = y_middle + 0.2 - - # Plot with ID for tracking - ax.stx_shaded_line( - x, y_lower, y_middle, y_upper, label="Sine with error", id="shaded_line_test" - ) - - # Style the plot - ax.set_xyt("X", "Y", "Shaded Line Test") - ax.legend() - - # Save both image and data - save_path = os.path.join(OUTPUT_DIR, "shaded_line_test.png") - scitex.io.save(fig, save_path) - - # Verify CSV was created - csv_path = save_path.replace(".png", ".csv") - assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" - - # Read CSV and verify contents - df = pd.read_csv(csv_path) - assert not df.empty, "CSV file is empty" - - # Close figure - scitex.plt.close(fig) - print("✓ stx_shaded_line test successful") - - -def test_plot_scatter_hist(): - """Test scatter histogram plotting and CSV export.""" - print("Testing stx_scatter_hist...") - - # Create figure - fig, ax = scitex.plt.subplots(figsize=(8, 8)) - - # Generate data - np.random.seed(42) # For reproducibility - x = np.random.normal(0, 1, 500) - y = x + np.random.normal(0, 0.5, 500) - - # Plot with ID for tracking - ax.stx_scatter_hist(x, y, hist_bins=30, scatter_alpha=0.7, id="scatter_hist_test") - - # Style the plot - ax.set_xyt("X Values", "Y Values", "Scatter Histogram Test") - - # Save both image and data - save_path = os.path.join(OUTPUT_DIR, "scatter_hist_test.png") - scitex.io.save(fig, save_path) - - # Verify CSV was created - csv_path = save_path.replace(".png", ".csv") - assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" - - # Read CSV and verify contents - df = pd.read_csv(csv_path) - assert not df.empty, "CSV file is empty" - - # Close figure - scitex.plt.close(fig) - print("✓ stx_scatter_hist test successful") - - -def test_plot_violin(): - """Test violin plotting and CSV export.""" - print("Testing stx_violin...") - - # Create figure - fig, ax = scitex.plt.subplots() - - # Generate data - np.random.seed(42) # For reproducibility - data = [ - np.random.normal(0, 1, 100), - np.random.normal(2, 1.5, 100), - np.random.normal(5, 0.8, 100), - ] - labels = ["Group A", "Group B", "Group C"] - - # Plot with ID for tracking - ax.stx_violin( - data, labels=labels, colors=["red", "blue", "green"], id="violin_test" - ) - - # Style the plot - ax.set_xyt("Groups", "Values", "Violin Plot Test") - - # Save both image and data - save_path = os.path.join(OUTPUT_DIR, "violin_test.png") - scitex.io.save(fig, save_path) - - # Verify CSV was created - csv_path = save_path.replace(".png", ".csv") - assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" - - # Read CSV and verify contents - df = pd.read_csv(csv_path) - assert not df.empty, "CSV file is empty" - - # Close figure - scitex.plt.close(fig) - print("✓ stx_violin test successful") - - -def test_plot_heatmap(): - """Test heatmap plotting and CSV export.""" - print("Testing stx_heatmap...") - - # Create figure - fig, ax = scitex.plt.subplots() - - # Generate data - np.random.seed(42) # For reproducibility - data = np.random.rand(5, 10) - x_labels = [f"X{ii + 1}" for ii in range(5)] - y_labels = [f"Y{ii + 1}" for ii in range(10)] - - # Plot with ID for tracking - ax.stx_heatmap( - data, - x_labels=x_labels, - y_labels=y_labels, - cbar_label="Values", - show_annot=True, - value_format="{x:.2f}", - cmap="viridis", - id="heatmap_test", - ) - - # Style the plot - ax.set_title("Heatmap Test") - - # Save both image and data - save_path = os.path.join(OUTPUT_DIR, "heatmap_test.png") - scitex.io.save(fig, save_path) - - # Verify CSV was created - csv_path = save_path.replace(".png", ".csv") - assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" - - # Read CSV and verify contents - df = pd.read_csv(csv_path) - assert not df.empty, "CSV file is empty" - - # Close figure - scitex.plt.close(fig) - print("✓ stx_heatmap test successful") - - -def test_plot_ecdf(): - """Test ECDF plotting and CSV export.""" - print("Testing stx_ecdf...") - - # Create figure - fig, ax = scitex.plt.subplots() - - # Generate data - np.random.seed(42) # For reproducibility - data = np.random.normal(0, 1, 1000) - - # Plot with ID for tracking - ax.stx_ecdf(data, label="Normal Distribution", id="ecdf_test") - - # Style the plot - ax.set_xyt("Value", "Cumulative Probability", "ECDF Test") - ax.legend() - - # Save both image and data - save_path = os.path.join(OUTPUT_DIR, "ecdf_test.png") - scitex.io.save(fig, save_path) - - # Verify CSV was created - csv_path = save_path.replace(".png", ".csv") - assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" - - # Read CSV and verify contents - df = pd.read_csv(csv_path) - assert not df.empty, "CSV file is empty" - - # Close figure - scitex.plt.close(fig) - print("✓ stx_ecdf test successful") - - -def test_multiple_plots(): - """Test multiple plots on the same axis.""" - print("Testing multiple plots on the same axis...") - - # Create figure - fig, ax = scitex.plt.subplots() - - # Generate data - np.random.seed(42) # For reproducibility - x = np.linspace(0, 10, 100) - y1 = np.sin(x) - y2 = np.cos(x) - - # Create multiple plots with different IDs - ax.stx_line(y1, label="Sine", id="multi_test_sine") - ax.stx_line(y2, label="Cosine", id="multi_test_cosine") - - # Style the plot - ax.set_xyt("X", "Y", "Multiple Plots Test") - ax.legend() - - # Save both image and data - save_path = os.path.join(OUTPUT_DIR, "multiple_plots_test.png") - scitex.io.save(fig, save_path) - - # Verify CSV was created - csv_path = save_path.replace(".png", ".csv") - assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" - - # Read CSV and verify contents - df = pd.read_csv(csv_path) - assert not df.empty, "CSV file is empty" - - # Check that both plots are in the CSV - sine_cols = [col for col in df.columns if col.startswith("multi_test_sine")] - cosine_cols = [col for col in df.columns if col.startswith("multi_test_cosine")] - assert len(sine_cols) > 0, "Sine plot data not found in CSV" - assert len(cosine_cols) > 0, "Cosine plot data not found in CSV" - - # Close figure - scitex.plt.close(fig) - print("✓ Multiple plots test successful") - - -if __name__ == "__main__": - print("Starting formatter verification tests...") - test_all_formatters() - print("\nAll formatter tests completed successfully!") - print(f"Output files are in: {OUTPUT_DIR}") diff --git a/src/scitex/plt/_subplots/_fonts.py b/src/scitex/plt/_subplots/_fonts.py deleted file mode 100755 index 6dcf1732b..000000000 --- a/src/scitex/plt/_subplots/_fonts.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 -"""Font configuration for matplotlib figures.""" - -import os - -import matplotlib as mpl -import matplotlib.font_manager as fm - - -def configure_arial_font(): - """Configure Arial font for matplotlib if available. - - Returns - ------- - bool - True if Arial was successfully configured, False otherwise. - """ - arial_enabled = False - - # Try to find Arial - try: - fm.findfont("Arial", fallback_to_default=False) - arial_enabled = True - except Exception: - # Search for Arial font files and register them - arial_paths = [ - f - for f in fm.findSystemFonts() - if os.path.basename(f).lower().startswith("arial") - ] - - if arial_paths: - for path in arial_paths: - try: - fm.fontManager.addfont(path) - except Exception: - pass - - # Verify Arial is now available - try: - fm.findfont("Arial", fallback_to_default=False) - arial_enabled = True - except Exception: - pass - - # Configure matplotlib to use Arial if available - if arial_enabled: - mpl.rcParams["font.family"] = "Arial" - mpl.rcParams["font.sans-serif"] = [ - "Arial", - "Helvetica", - "DejaVu Sans", - "Liberation Sans", - ] - else: - # Warn about missing Arial - from scitex import logging as _logging - - _logger = _logging.getLogger(__name__) - _logger.warning( - "Arial font not found. Using fallback fonts (Helvetica/DejaVu Sans). " - "For publication figures with Arial: sudo apt-get install ttf-mscorefonts-installer && fc-cache -fv" - ) - - return arial_enabled - - -# Configure fonts at module import -_arial_enabled = configure_arial_font() - -# EOF diff --git a/src/scitex/plt/_subplots/_mm_layout.py b/src/scitex/plt/_subplots/_mm_layout.py deleted file mode 100755 index 76f5f934e..000000000 --- a/src/scitex/plt/_subplots/_mm_layout.py +++ /dev/null @@ -1,282 +0,0 @@ -#!/usr/bin/env python3 -"""Millimeter-based layout control for matplotlib figures.""" - -import matplotlib.pyplot as plt -import numpy as np - -from ._AxesWrapper import AxesWrapper -from ._AxisWrapper import AxisWrapper -from ._FigWrapper import FigWrapper - - -def create_with_mm_control( - *args, - track=True, - sharex=False, - sharey=False, - axes_width_mm=None, - axes_height_mm=None, - margin_left_mm=None, - margin_right_mm=None, - margin_bottom_mm=None, - margin_top_mm=None, - space_w_mm=None, - space_h_mm=None, - axes_thickness_mm=None, - tick_length_mm=None, - tick_thickness_mm=None, - trace_thickness_mm=None, - marker_size_mm=None, - axis_font_size_pt=None, - tick_font_size_pt=None, - title_font_size_pt=None, - legend_font_size_pt=None, - suptitle_font_size_pt=None, - label_pad_pt=None, - tick_pad_pt=None, - title_pad_pt=None, - font_family=None, - n_ticks=None, - mode=None, - dpi=None, - styles=None, - transparent=None, - theme=None, - **kwargs, -): - """Create figure with mm-based control over axes dimensions. - - Returns - ------- - tuple - (FigWrapper, AxisWrapper or AxesWrapper) - """ - from scitex.plt.utils import apply_style_mm, mm_to_inch - - # Parse nrows, ncols from args or kwargs - nrows, ncols = 1, 1 - if len(args) >= 1: - nrows = args[0] - elif "nrows" in kwargs: - nrows = kwargs.pop("nrows") - if len(args) >= 2: - ncols = args[1] - elif "ncols" in kwargs: - ncols = kwargs.pop("ncols") - - n_axes = nrows * ncols - - # Apply mode-specific defaults - if mode == "display": - scale_factor = 3.0 - dpi = dpi or 100 - else: - scale_factor = 1.0 - dpi = dpi or 300 - - # Set defaults with scaling - if axes_width_mm is None: - axes_width_mm = 30.0 * scale_factor - elif mode == "display": - axes_width_mm = axes_width_mm * scale_factor - - if axes_height_mm is None: - axes_height_mm = 21.0 * scale_factor - elif mode == "display": - axes_height_mm = axes_height_mm * scale_factor - - margin_left_mm = ( - margin_left_mm if margin_left_mm is not None else (5.0 * scale_factor) - ) - margin_right_mm = ( - margin_right_mm if margin_right_mm is not None else (2.0 * scale_factor) - ) - margin_bottom_mm = ( - margin_bottom_mm if margin_bottom_mm is not None else (5.0 * scale_factor) - ) - margin_top_mm = margin_top_mm if margin_top_mm is not None else (2.0 * scale_factor) - space_w_mm = space_w_mm if space_w_mm is not None else (3.0 * scale_factor) - space_h_mm = space_h_mm if space_h_mm is not None else (3.0 * scale_factor) - - # Handle list vs scalar for axes dimensions - if isinstance(axes_width_mm, (list, tuple)): - ax_widths_mm = list(axes_width_mm) - if len(ax_widths_mm) != n_axes: - raise ValueError( - f"axes_width_mm list length ({len(ax_widths_mm)}) " - f"must match nrows*ncols ({n_axes})" - ) - else: - ax_widths_mm = [axes_width_mm] * n_axes - - if isinstance(axes_height_mm, (list, tuple)): - ax_heights_mm = list(axes_height_mm) - if len(ax_heights_mm) != n_axes: - raise ValueError( - f"axes_height_mm list length ({len(ax_heights_mm)}) " - f"must match nrows*ncols ({n_axes})" - ) - else: - ax_heights_mm = [axes_height_mm] * n_axes - - # Calculate figure size from axes grid - ax_widths_2d = np.array(ax_widths_mm).reshape(nrows, ncols) - ax_heights_2d = np.array(ax_heights_mm).reshape(nrows, ncols) - - max_widths_per_col = ax_widths_2d.max(axis=0) - max_heights_per_row = ax_heights_2d.max(axis=1) - - total_width_mm = ( - margin_left_mm - + max_widths_per_col.sum() - + (ncols - 1) * space_w_mm - + margin_right_mm - ) - total_height_mm = ( - margin_bottom_mm - + max_heights_per_row.sum() - + (nrows - 1) * space_h_mm - + margin_top_mm - ) - - # Create figure - figsize_inch = (mm_to_inch(total_width_mm), mm_to_inch(total_height_mm)) - if transparent: - fig_mpl = plt.figure(figsize=figsize_inch, dpi=dpi, facecolor="none") - else: - fig_mpl = plt.figure(figsize=figsize_inch, dpi=dpi) - - # Store theme on figure - if theme is not None: - fig_mpl._scitex_theme = theme - - # Create axes array and position each one manually - axes_mpl_list = [] - ax_idx = 0 - - for row in range(nrows): - for col in range(ncols): - # Calculate position - left_mm = margin_left_mm + max_widths_per_col[:col].sum() + col * space_w_mm - rows_below = nrows - row - 1 - bottom_mm = ( - margin_bottom_mm - + max_heights_per_row[row + 1 :].sum() - + rows_below * space_h_mm - ) - - # Convert to figure coordinates [0-1] - left = left_mm / total_width_mm - bottom = bottom_mm / total_height_mm - width = ax_widths_mm[ax_idx] / total_width_mm - height = ax_heights_mm[ax_idx] / total_height_mm - - # Create axes - ax_mpl = fig_mpl.add_axes([left, bottom, width, height]) - if transparent: - ax_mpl.patch.set_alpha(0.0) - axes_mpl_list.append(ax_mpl) - - # Tag with metadata - ax_mpl._scitex_metadata = { - "created_with": "scitex.plt.subplots", - "mode": mode or "publication", - "axes_size_mm": (ax_widths_mm[ax_idx], ax_heights_mm[ax_idx]), - "position_in_grid": (row, col), - } - ax_idx += 1 - - # Apply styling to each axes - suptitle_font_size_pt_value = None - for i, ax_mpl in enumerate(axes_mpl_list): - # Determine which style dict to use - if styles is not None: - if isinstance(styles, list): - if len(styles) != n_axes: - raise ValueError( - f"styles list length ({len(styles)}) " - f"must match nrows*ncols ({n_axes})" - ) - style_dict = styles[i] - else: - style_dict = styles - else: - # Build style dict from individual parameters - style_dict = {} - if axes_thickness_mm is not None: - style_dict["axis_thickness_mm"] = axes_thickness_mm - if tick_length_mm is not None: - style_dict["tick_length_mm"] = tick_length_mm - if tick_thickness_mm is not None: - style_dict["tick_thickness_mm"] = tick_thickness_mm - if trace_thickness_mm is not None: - style_dict["trace_thickness_mm"] = trace_thickness_mm - if marker_size_mm is not None: - style_dict["marker_size_mm"] = marker_size_mm - if axis_font_size_pt is not None: - style_dict["axis_font_size_pt"] = axis_font_size_pt - if tick_font_size_pt is not None: - style_dict["tick_font_size_pt"] = tick_font_size_pt - if title_font_size_pt is not None: - style_dict["title_font_size_pt"] = title_font_size_pt - if legend_font_size_pt is not None: - style_dict["legend_font_size_pt"] = legend_font_size_pt - if suptitle_font_size_pt is not None: - style_dict["suptitle_font_size_pt"] = suptitle_font_size_pt - if label_pad_pt is not None: - style_dict["label_pad_pt"] = label_pad_pt - if tick_pad_pt is not None: - style_dict["tick_pad_pt"] = tick_pad_pt - if title_pad_pt is not None: - style_dict["title_pad_pt"] = title_pad_pt - if font_family is not None: - style_dict["font_family"] = font_family - if n_ticks is not None: - style_dict["n_ticks"] = n_ticks - - # Always add theme to style_dict - if theme is not None: - style_dict["theme"] = theme - - # Extract suptitle font size if available - if "suptitle_font_size_pt" in style_dict: - suptitle_font_size_pt_value = style_dict["suptitle_font_size_pt"] - - # Apply style if not empty - if style_dict: - apply_style_mm(ax_mpl, style_dict) - ax_mpl._scitex_metadata["style_mm"] = style_dict - - # Store suptitle font size in figure metadata - if suptitle_font_size_pt_value is not None: - fig_mpl._scitex_suptitle_font_size_pt = suptitle_font_size_pt_value - - # Wrap the figure - fig_scitex = FigWrapper(fig_mpl) - - # Reshape axes list - axes_array_mpl = np.array(axes_mpl_list).reshape(nrows, ncols) - - # Handle single axis case - if n_axes == 1: - ax_mpl_scalar = axes_array_mpl.item() - axis_scitex = AxisWrapper(fig_scitex, ax_mpl_scalar, track) - fig_scitex.axes = [axis_scitex] - ax_mpl_scalar._scitex_wrapper = axis_scitex - return fig_scitex, axis_scitex - - # Handle multiple axes case - axes_flat_scitex_list = [] - for ax_mpl in axes_mpl_list: - ax_scitex = AxisWrapper(fig_scitex, ax_mpl, track) - ax_mpl._scitex_wrapper = ax_scitex - axes_flat_scitex_list.append(ax_scitex) - - axes_array_scitex = np.array(axes_flat_scitex_list).reshape(nrows, ncols) - axes_scitex = AxesWrapper(fig_scitex, axes_array_scitex) - fig_scitex.axes = axes_scitex - - return fig_scitex, axes_scitex - - -# EOF diff --git a/src/scitex/plt/ax/__init__.py b/src/scitex/plt/ax/__init__.py deleted file mode 100755 index 44b3a98ff..000000000 --- a/src/scitex/plt/ax/__init__.py +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 20:12:46 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/__init__.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/__init__.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from ._plot._add_fitted_line import add_fitted_line -from ._plot._plot_circular_hist import plot_circular_hist -from ._plot._plot_cube import plot_cube -from ._plot._plot_statistical_shaded_line import ( - stx_line, - stx_mean_ci, - stx_mean_std, - stx_median_iqr, -) -from ._plot._stx_conf_mat import stx_conf_mat -from ._plot._stx_ecdf import stx_ecdf -from ._plot._stx_fillv import stx_fillv - -# Plot -from ._plot._stx_heatmap import stx_heatmap -from ._plot._stx_image import stx_image -from ._plot._stx_joyplot import stx_joyplot -from ._plot._stx_raster import stx_raster -from ._plot._stx_rectangle import stx_rectangle -from ._plot._stx_scatter_hist import stx_scatter_hist -from ._plot._stx_shaded_line import stx_shaded_line -from ._plot._stx_violin import stx_violin - -# Adjust -from ._style._add_marginal_ax import add_marginal_ax -from ._style._add_panel import add_panel, panel -from ._style._auto_scale_axis import auto_scale_axis -from ._style._extend import extend -from ._style._force_aspect import force_aspect -from ._style._format_label import format_label as format_label_old -from ._style._format_units import format_label, format_label_auto -from ._style._hide_spines import hide_spines -from ._style._map_ticks import map_ticks -from ._style._rotate_labels import rotate_labels -from ._style._sci_note import sci_note -from ._style._set_log_scale import ( - add_log_scale_indicator, - set_log_scale, - smart_log_limits, -) -from ._style._set_n_ticks import set_n_ticks -from ._style._set_size import set_size -from ._style._set_supxyt import set_supxyt -from ._style._set_ticks import set_ticks -from ._style._set_xyt import set_xyt -from ._style._share_axes import ( - get_global_xlim, - get_global_ylim, - set_xlims, - set_ylims, - sharex, - sharexy, - sharey, -) -from ._style._shift import shift -from ._style._show_spines import ( - clean_spines, - scientific_spines, - show_all_spines, - show_box_spines, - show_classic_spines, - show_spines, - toggle_spines, -) -from ._style._style_barplot import style_barplot -from ._style._style_boxplot import style_boxplot -from ._style._style_errorbar import style_errorbar -from ._style._style_scatter import style_scatter -from ._style._style_suptitles import style_suptitles -from ._style._style_violinplot import style_violinplot - -# ################################################################################ -# # For Matplotlib Compatibility -# ################################################################################ -# import matplotlib.pyplot.axis as counter_part -# _local_module_attributes = list(globals().keys()) -# print(_local_module_attributes) - -# def __getattr__(name): -# """ -# Fallback to fetch attributes from matplotlib.pyplot -# if they are not defined directly in this module. -# """ -# try: -# # Get the attribute from matplotlib.pyplot -# return getattr(counter_part, name) -# except AttributeError: -# # Raise the standard error if not found in pyplot either -# raise AttributeError( -# f"module '{__name__}' nor matplotlib.pyplot has attribute '{name}'" -# ) from None - -# def __dir__(): -# """ -# Provide combined directory for tab completion, including -# attributes from this module and matplotlib.pyplot. -# """ -# # Get attributes defined explicitly in this module -# local_attrs = set(_local_module_attributes) -# # Get attributes from matplotlib.pyplot -# pyplot_attrs = set(dir(counter_part)) -# # Return the sorted union -# return sorted(local_attrs.union(pyplot_attrs)) - -# """ -# import matplotlib.pyplot as plt -# import scitex.plt as mplt - -# print(set(dir(mplt.ax)) - set(dir(plt.axis))) -# """ - -# EOF diff --git a/src/scitex/plt/ax/_plot/__init__.py b/src/scitex/plt/ax/_plot/__init__.py deleted file mode 100755 index 5c5e45b00..000000000 --- a/src/scitex/plt/ax/_plot/__init__.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 20:12:19 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/__init__.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/__init__.py" -__DIR__ = os.path.dirname(__FILE__) - -from scitex.decorators import deprecated - -from ._add_fitted_line import add_fitted_line -from ._plot_circular_hist import plot_circular_hist -from ._plot_cube import plot_cube -from ._plot_statistical_shaded_line import ( - stx_line, - stx_mean_ci, - stx_mean_std, - stx_median_iqr, -) -from ._stx_conf_mat import stx_conf_mat -from ._stx_ecdf import stx_ecdf -from ._stx_fillv import stx_fillv -from ._stx_heatmap import stx_heatmap -from ._stx_image import stx_image -from ._stx_joyplot import stx_joyplot -from ._stx_raster import stx_raster -from ._stx_rectangle import stx_rectangle -from ._stx_scatter_hist import stx_scatter_hist -from ._stx_shaded_line import _plot_single_shaded_line, stx_shaded_line -from ._stx_violin import sns_plot_violin, stx_violin - - -# Backward-compatible aliases for renamed functions with deprecation warnings -@deprecated(reason="Use stx_line instead", forward_to="scitex.plt.ax._plot.stx_line") -def plot_line(*args, **kwargs): - pass - - -@deprecated( - reason="Use stx_mean_std instead", forward_to="scitex.plt.ax._plot.stx_mean_std" -) -def plot_mean_std(*args, **kwargs): - pass - - -@deprecated( - reason="Use stx_mean_ci instead", forward_to="scitex.plt.ax._plot.stx_mean_ci" -) -def plot_mean_ci(*args, **kwargs): - pass - - -@deprecated( - reason="Use stx_median_iqr instead", forward_to="scitex.plt.ax._plot.stx_median_iqr" -) -def plot_median_iqr(*args, **kwargs): - pass - - -__all__ = [ - "stx_scatter_hist", - "stx_heatmap", - "plot_circular_hist", - "stx_conf_mat", - "plot_cube", - "stx_ecdf", - "stx_fillv", - "stx_violin", - "sns_plot_violin", - "stx_image", - "stx_joyplot", - "stx_raster", - "stx_rectangle", - "stx_shaded_line", - "_plot_single_shaded_line", - "stx_line", - "stx_mean_std", - "stx_mean_ci", - "stx_median_iqr", - "add_fitted_line", - # Backward-compatible aliases - "plot_line", - "plot_mean_std", - "plot_mean_ci", - "plot_median_iqr", -] - -# EOF diff --git a/src/scitex/plt/ax/_plot/_add_fitted_line.py b/src/scitex/plt/ax/_plot/_add_fitted_line.py deleted file mode 100755 index d9fc262eb..000000000 --- a/src/scitex/plt/ax/_plot/_add_fitted_line.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-11-19 15:52:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_plot/_add_fitted_line.py - -""" -Add fitted regression line to scatter plots. -""" - -from typing import Dict, Optional, Tuple - -import numpy as np - - -def add_fitted_line( - ax, - x, - y, - color: str = "black", - linestyle: str = "--", - linewidth_mm: float = 0.2, - label: Optional[str] = None, - degree: int = 1, - show_stats: bool = True, - stats_position: float = 0.75, - stats_fontsize: int = 6, -) -> Tuple: - """ - Add a fitted polynomial line to a scatter plot with optional R² and p-value. - - Parameters - ---------- - ax : matplotlib.axes.Axes - Axes to plot on - x : array-like - X data - y : array-like - Y data - color : str, optional - Line color (default: 'black') - linestyle : str, optional - Line style (default: '--' for dashed) - linewidth_mm : float, optional - Line thickness in millimeters (default: 0.2mm) - label : str, optional - Label for the fitted line (default: None) - degree : int, optional - Polynomial degree for fitting (default: 1 for linear) - show_stats : bool, optional - Whether to display R² and p-value near the line (default: True) - Only applicable for linear fits (degree=1) - stats_position : float, optional - Position along x-axis (0-1 scale) for stats text (default: 0.75) - stats_fontsize : int, optional - Font size for statistics text in points (default: 6) - - Returns - ------- - line : Line2D - The fitted line object - coeffs : np.ndarray - Polynomial coefficients from np.polyfit - stats : StatResult or None - StatResult instance with correlation statistics (only for degree=1). - Use .to_dict() for dictionary format. - - Examples - -------- - >>> fig, ax = stx.plt.subplots(**stx.plt.presets.SCITEX_STYLE) - >>> scatter = ax.scatter(x, y) - >>> stx.plt.ax.add_fitted_line(ax, x, y) # Auto-shows R² and p - - >>> # Without statistics - >>> line, coeffs, stats = stx.plt.ax.add_fitted_line( - ... ax, x, y, show_stats=False - ... ) - - >>> # Custom position for stats - >>> line, coeffs, stats = stx.plt.ax.add_fitted_line( - ... ax, x, y, stats_position=0.5 - ... ) - """ - from scitex.plt.utils import mm_to_pt - - # Convert data to numpy arrays - x = np.asarray(x) - y = np.asarray(y) - - # Fit polynomial - coeffs = np.polyfit(x, y, degree) - poly_fn = np.poly1d(coeffs) - - # Generate fitted line points - x_fit = np.linspace(x.min(), x.max(), 100) - y_fit = poly_fn(x_fit) - - # Convert linewidth to points - lw_pt = mm_to_pt(linewidth_mm) - - # Plot fitted line - line = ax.plot( - x_fit, - y_fit, - color=color, - linestyle=linestyle, - linewidth=lw_pt, - label=label, - )[0] - - # Calculate and display statistics for linear regression (degree=1) - stats_result = None - if degree == 1 and show_stats: - # Import scitex.stats correlation test - from scitex.stats.tests.correlation import test_pearson - - # Calculate correlation statistics using scitex.stats - stats_result = test_pearson(x, y) - - # Position for text annotation - x_pos = x.min() + stats_position * (x.max() - x.min()) - y_pos = poly_fn(x_pos) - - # Format statistics text with R² and significance stars - r_squared = stats_result.effect_size["value"] # r_squared from effect_size - stars = stats_result.stars - - if stars and stars != "ns": # Only show if significant - stats_text = f"$R^2$ = {r_squared:.3f}{stars}" - else: # Not significant - stats_text = f"$R^2$ = {r_squared:.3f} (ns)" - - # Add text annotation near the line - ax.text( - x_pos, - y_pos, - stats_text, - verticalalignment="bottom", - fontsize=stats_fontsize, - ) - - # Store stats in axes metadata for embedding in saved figures - if not hasattr(ax, "_scitex_metadata"): - ax._scitex_metadata = {} - if "stats" not in ax._scitex_metadata: - ax._scitex_metadata["stats"] = [] - - # Add this StatResult to the stats list - ax._scitex_metadata["stats"].append(stats_result.to_dict()) - - return line, coeffs, stats_result - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_plot_circular_hist.py b/src/scitex/plt/ax/_plot/_plot_circular_hist.py deleted file mode 100755 index c93b71979..000000000 --- a/src/scitex/plt/ax/_plot/_plot_circular_hist.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-01 15:21:48 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_circular_hist.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_circular_hist.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -# Time-stamp: "2024-02-03 13:10:50 (ywatanabe)" -import matplotlib -import numpy as np - -from ....plt.utils import assert_valid_axis - - -def plot_circular_hist( - axis, - radians, - bins=16, - density=True, - offset=0, - gaps=True, - color=None, - range_bias=0, -): - """ - Example: - fig, ax = plt.subplots(subplot_kw=dict(projection="polar")) - ax = scitex.plt.plot_circular_hist(ax, radians) - Produce a circular histogram of angles on ax. - - Parameters - ---------- - ax : matplotlib.axes._subplots.PolarAxesSubplot or scitex.plt._subplots.AxisWrapper - axis instance created with subplot_kw=dict(projection='polar'). - - radians : array - Angles to plot, expected in units of radians. - - bins : int, optional - Defines the number of equal-width bins in the range. The default is 16. - - density : bool, optional - If True plot frequency proportional to area. If False plot frequency - proportional to radius. The default is True. - - offset : float, optional - Sets the offset for the location of the 0 direction in units of - radians. The default is 0. - - gaps : bool, optional - Whether to allow gaps between bins. When gaps = False the bins are - forced to partition the entire [-pi, pi] range. The default is True. - - Returns - ------- - n : array or list of arrays - The number of values in each bin. - - bins : array - The edges of the bins. - - patches : `.BarContainer` or list of a single `.Polygon` - Container of individual artists used to create the histogram - or list of such containers if there are multiple input datasets. - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - # Wrap angles to [-pi, pi) - radians = (radians + np.pi) % (2 * np.pi) - np.pi - - # Force bins to partition entire circle - if not gaps: - bins = np.linspace(-np.pi, np.pi, num=bins + 1) - - # Bin data and record counts - n, bins = np.histogram( - radians, bins=bins, range=(-np.pi + range_bias, np.pi + range_bias) - ) - - # Compute width of each bin - widths = np.diff(bins) - - # By default plot frequency proportional to area - if density: - # Area to assign each bin - area = n / radians.size - # Calculate corresponding bin radius - radius = (area / np.pi) ** 0.5 - # Otherwise plot frequency proportional to radius - else: - radius = n - - mean_val = np.nanmean(radians) - std_val = np.nanstd(radians) - axis.axvline(mean_val, color=color) - axis.text(mean_val, 1, std_val) - - # Plot data on ax - patches = axis.bar( - bins[:-1], - radius, - zorder=1, - align="edge", - width=widths, - edgecolor=color, - alpha=0.9, - fill=False, - linewidth=1, - ) - - # Set the direction of the zero angle - axis.set_theta_offset(offset) - - # Remove ylabels for area plots (they are mostly obstructive) - if density: - axis.set_yticks([]) - - return n, bins, patches - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_plot_cube.py b/src/scitex/plt/ax/_plot/_plot_cube.py deleted file mode 100755 index 3ad46b88d..000000000 --- a/src/scitex/plt/ax/_plot/_plot_cube.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-01 15:21:37 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_cube.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_cube.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from itertools import combinations, product - -import numpy as np - - -def plot_cube(ax, xlim, ylim, zlim, c="blue", alpha=1.0): - """ - Plot a 3D cube on the given axis. - - Args: - ax: Matplotlib 3D axis - xlim: Range for x-axis as a tuple (min, max) - ylim: Range for y-axis as a tuple (min, max) - zlim: Range for z-axis as a tuple (min, max) - c: Color of the cube edges (default: 'blue') - alpha: Transparency of the cube edges (default: 1.0) - - Returns: - Matplotlib axis with the cube plotted - """ - # Validate inputs - assert hasattr(ax, "plot3D"), "The axis must be a 3D axis with plot3D method" - assert len(xlim) == 2, "xlim must be a tuple of (min, max)" - assert len(ylim) == 2, "ylim must be a tuple of (min, max)" - assert len(zlim) == 2, "zlim must be a tuple of (min, max)" - assert xlim[0] < xlim[1], "xlim[0] must be less than xlim[1]" - assert ylim[0] < ylim[1], "ylim[0] must be less than ylim[1]" - assert zlim[0] < zlim[1], "zlim[0] must be less than zlim[1]" - - # Get all corners of the cube - corners = np.array(list(product(xlim, ylim, zlim))) - - # Draw edges between corners - for start, end in combinations(corners, 2): - # Check if the points form an edge (differ in exactly one dimension) - if np.sum(np.abs(start - end)) == xlim[1] - xlim[0]: - ax.plot3D(*zip(start, end), c=c, linewidth=3, alpha=alpha) - if np.sum(np.abs(start - end)) == ylim[1] - ylim[0]: - ax.plot3D(*zip(start, end), c=c, linewidth=3, alpha=alpha) - if np.sum(np.abs(start - end)) == zlim[1] - zlim[0]: - ax.plot3D(*zip(start, end), c=c, linewidth=3, alpha=alpha) - - return ax - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_plot_statistical_shaded_line.py b/src/scitex/plt/ax/_plot/_plot_statistical_shaded_line.py deleted file mode 100755 index 67fd0637b..000000000 --- a/src/scitex/plt/ax/_plot/_plot_statistical_shaded_line.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-04-30 20:50:45 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot_statistical_shaded_line.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot_statistical_shaded_line.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import numpy as np -import pandas as pd - -from ....plt.utils import assert_valid_axis -from ._stx_shaded_line import stx_shaded_line as scitex_plt_plot_shaded_line - - -def _format_sample_size(values_2d): - """Format sample size string, showing range if variable due to NaN. - - Parameters - ---------- - values_2d : np.ndarray, shape (n_samples, n_points) - 2D array where sample count may vary per column due to NaN. - - Returns - ------- - str - Formatted sample size string, e.g., "20" or "18-20". - """ - if values_2d.ndim == 1: - return "1" - - # Count non-NaN values per column (timepoint) - n_per_point = np.sum(~np.isnan(values_2d), axis=0) - n_min, n_max = int(n_per_point.min()), int(n_per_point.max()) - - if n_min == n_max: - return str(n_min) - else: - return f"{n_min}-{n_max}" - - -def stx_line(axis, values_1d, xx=None, **kwargs): - """ - Plot a simple line. - - Parameters - ---------- - axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis to plot on - values_1d : array-like, shape (n_points,) - 1D array of y-values to plot - xx : array-like, shape (n_points,), optional - X coordinates for the data. If None, will use np.arange(len(values_1d)) - **kwargs - Additional keyword arguments passed to axis.plot() - - Returns - ------- - axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis with the plot - df : pandas.DataFrame - DataFrame with x and y values - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - values_1d = np.asarray(values_1d) - assert values_1d.ndim <= 2, f"Data must be 1D or 2D, got {values_1d.ndim}D" - if xx is None: - xx = np.arange(len(values_1d)) - else: - xx = np.asarray(xx) - assert len(xx) == len( - values_1d - ), f"xx length ({len(xx)}) must match values_1d length ({len(values_1d)})" - - axis.plot(xx, values_1d, **kwargs) - return axis, pd.DataFrame({"x": xx, "y": values_1d}) - - -def stx_mean_std(axis, values_2d, xx=None, sd=1, **kwargs): - """ - Plot mean line with standard deviation shading. - - Parameters - ---------- - axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis to plot on - values_2d : array-like, shape (n_samples, n_points) or (n_points,) - 2D array where mean and std are calculated across axis=0 (samples). - Can also be 1D for a single line without shading. - xx : array-like, shape (n_points,), optional - X coordinates for the data. If None, will use np.arange(n_points) - sd : float, optional - Number of standard deviations for the shaded region. Default is 1 - **kwargs - Additional keyword arguments passed to stx_shaded_line() - - Returns - ------- - axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis with the plot - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - assert isinstance(sd, (int, float)), f"sd must be a number, got {type(sd)}" - assert sd >= 0, f"sd must be non-negative, got {sd}" - values_2d = np.asarray(values_2d) - assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D" - if xx is None: - xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)) - else: - xx = np.asarray(xx) - expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d) - assert ( - len(xx) == expected_len - ), f"xx length ({len(xx)}) must match values_2d length ({expected_len})" - - if values_2d.ndim == 1: - central = values_2d - error = np.zeros_like(central) - else: - central = np.nanmean(values_2d, axis=0) - error = np.nanstd(values_2d, axis=0) * sd - - y_lower = central - error - y_upper = central + error - - if "label" in kwargs and kwargs["label"]: - n_str = _format_sample_size(values_2d) - kwargs["label"] = f"{kwargs['label']} ($n$={n_str})" - - return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs) - - -def stx_mean_ci(axis, values_2d, xx=None, perc=95, **kwargs): - """ - Plot mean line with confidence interval shading. - - Parameters - ---------- - axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis to plot on - values_2d : array-like, shape (n_samples, n_points) or (n_points,) - 2D array where mean and percentiles are calculated across axis=0 (samples). - Can also be 1D for a single line without shading. - xx : array-like, shape (n_points,), optional - X coordinates for the data. If None, will use np.arange(n_points) - perc : float, optional - Confidence interval percentage (0-100). Default is 95 - **kwargs - Additional keyword arguments passed to stx_shaded_line() - - Returns - ------- - axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis with the plot - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - assert isinstance(perc, (int, float)), f"perc must be a number, got {type(perc)}" - assert 0 <= perc <= 100, f"perc must be between 0 and 100, got {perc}" - values_2d = np.asarray(values_2d) - assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D" - - if xx is None: - xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)) - else: - xx = np.asarray(xx) - - expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d) - assert ( - len(xx) == expected_len - ), f"xx length ({len(xx)}) must match values_2d length ({expected_len})" - - if values_2d.ndim == 1: - central = values_2d - y_lower = central - y_upper = central - else: - central = np.nanmean(values_2d, axis=0) - # Calculate CI bounds - alpha = 1 - perc / 100 - y_lower_perc = alpha / 2 * 100 - y_upper_perc = (1 - alpha / 2) * 100 - y_lower = np.nanpercentile(values_2d, y_lower_perc, axis=0) - y_upper = np.nanpercentile(values_2d, y_upper_perc, axis=0) - - if "label" in kwargs and kwargs["label"]: - n_str = _format_sample_size(values_2d) - kwargs["label"] = f"{kwargs['label']} ($n$={n_str}, CI={perc}%)" - - return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs) - - -def stx_median_iqr(axis, values_2d, xx=None, **kwargs): - """ - Plot median line with interquartile range shading. - - Parameters - ---------- - axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis to plot on - values_2d : array-like, shape (n_samples, n_points) or (n_points,) - 2D array where median and IQR are calculated across axis=0 (samples). - Can also be 1D for a single line without shading. - xx : array-like, shape (n_points,), optional - X coordinates for the data. If None, will use np.arange(n_points) - **kwargs - Additional keyword arguments passed to stx_shaded_line() - - Returns - ------- - axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis with the plot - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - values_2d = np.asarray(values_2d) - assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D" - - if xx is None: - xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)) - else: - xx = np.asarray(xx) - - expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d) - assert ( - len(xx) == expected_len - ), f"xx length ({len(xx)}) must match values_2d length ({expected_len})" - - if values_2d.ndim == 1: - central = values_2d - y_lower = central - y_upper = central - else: - central = np.nanmedian(values_2d, axis=0) - y_lower = np.nanpercentile(values_2d, 25, axis=0) - y_upper = np.nanpercentile(values_2d, 75, axis=0) - - if "label" in kwargs and kwargs["label"]: - n_str = _format_sample_size(values_2d) - kwargs["label"] = f"{kwargs['label']} ($n$={n_str}, IQR)" - - return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs) - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_conf_mat.py b/src/scitex/plt/ax/_plot/_stx_conf_mat.py deleted file mode 100755 index 6209f9948..000000000 --- a/src/scitex/plt/ax/_plot/_stx_conf_mat.py +++ /dev/null @@ -1,140 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 15:08:16 (ywatanabe)" -# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_conf_mat.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_conf_mat.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from typing import List, Optional, Tuple, Union - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns - -from scitex.plt.utils import assert_valid_axis -from scitex.plt.utils._calc_bacc_from_conf_mat import calc_bacc_from_conf_mat - -from .._style._extend import extend as scitex_plt_extend - - -def stx_conf_mat( - axis: plt.Axes, - conf_mat_2d: Union[np.ndarray, pd.DataFrame], - x_labels: Optional[List[str]] = None, - y_labels: Optional[List[str]] = None, - title: str = "Confusion Matrix", - cmap: str = "Blues", - cbar: bool = True, - cbar_kw: dict = {}, - label_rotation_xy: Tuple[float, float] = (15, 15), - x_extend_ratio: float = 1.0, - y_extend_ratio: float = 1.0, - calc_bacc: bool = False, - **kwargs, -) -> Union[plt.Axes, Tuple[plt.Axes, float]]: - """Creates a confusion matrix heatmap with optional balanced accuracy. - - Parameters - ---------- - axis : plt.Axes or scitex.plt._subplots.AxisWrapper - Matplotlib axes or scitex axis wrapper to plot on - conf_mat_2d : Union[np.ndarray, pd.DataFrame], shape (n_classes, n_classes) - 2D confusion matrix data (true labels × predicted labels) - x_labels : Optional[List[str]], optional - Labels for predicted classes - y_labels : Optional[List[str]], optional - Labels for true classes - title : str, optional - Plot title - cmap : str, optional - Colormap name - cbar : bool, optional - Whether to show colorbar - cbar_kw : dict, optional - Colorbar parameters - label_rotation_xy : Tuple[float, float], optional - (x,y) label rotation angles - x_extend_ratio : float, optional - X-axis extension ratio - y_extend_ratio : float, optional - Y-axis extension ratio - calc_bacc : bool, optional - Calculate Balanced Accuracy from Confusion Matrix - - Returns - ------- - Union[plt.Axes, Tuple[plt.Axes, float]] or Union[scitex.plt._subplots.AxisWrapper, Tuple[scitex.plt._subplots.AxisWrapper, float]] - Axes object and optionally balanced accuracy - - Example - ------- - >>> data = np.array([[10, 2, 0], [1, 15, 3], [0, 2, 20]]) - >>> fig, ax = plt.subplots() - >>> ax, bacc = stx_conf_mat(ax, data, x_labels=['A','B','C'], - ... y_labels=['X','Y','Z'], calc_bacc=True) - >>> print(f"Balanced Accuracy: {bacc:.3f}") - Balanced Accuracy: 0.889 - """ - - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - if not isinstance(conf_mat_2d, pd.DataFrame): - conf_mat_2d = pd.DataFrame(conf_mat_2d) - - bacc_val = calc_bacc_from_conf_mat(conf_mat_2d.values) - title = f"{title} (bACC = {bacc_val:.3f})" - - res = sns.heatmap( - conf_mat_2d, - ax=axis, - cmap=cmap, - annot=True, - fmt=",d", - cbar=False, - vmin=0, - **kwargs, - ) - - res.invert_yaxis() - - for _, spine in res.spines.items(): - spine.set_visible(False) - - axis.set_xlabel("Predicted label") - axis.set_ylabel("True label") - axis.set_title(title) - - if x_labels is not None: - axis.set_xticklabels(x_labels) - if y_labels is not None: - axis.set_yticklabels(y_labels) - - axis = scitex_plt_extend(axis, x_extend_ratio, y_extend_ratio) - if conf_mat_2d.shape[0] == conf_mat_2d.shape[1]: - axis.set_box_aspect(1) - axis.set_xticklabels( - axis.get_xticklabels(), - rotation=label_rotation_xy[0], - fontdict={"verticalalignment": "top"}, - ) - axis.set_yticklabels( - axis.get_yticklabels(), - rotation=label_rotation_xy[1], - fontdict={"horizontalalignment": "right"}, - ) - - if calc_bacc: - return axis, bacc_val - else: - return axis, None - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_ecdf.py b/src/scitex/plt/ax/_plot/_stx_ecdf.py deleted file mode 100755 index 108013d8e..000000000 --- a/src/scitex/plt/ax/_plot/_stx_ecdf.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 14:00:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_plot/_plot_ecdf.py - -"""Empirical Cumulative Distribution Function (ECDF) plotting.""" - -from typing import Any, Tuple, Union - -import numpy as np -import pandas as pd -from matplotlib.axes import Axes - -from scitex import logging -from scitex.pd._force_df import force_df as scitex_pd_force_df - -from ....plt.utils import assert_valid_axis, mm_to_pt - -logger = logging.getLogger(__name__) - - -# Default line width (0.2mm for publication) -DEFAULT_LINE_WIDTH_MM = 0.2 - - -def stx_ecdf( - axis: Union[Axes, "AxisWrapper"], - values_1d: np.ndarray, - **kwargs: Any, -) -> Tuple[Union[Axes, "AxisWrapper"], pd.DataFrame]: - """Plot Empirical Cumulative Distribution Function (ECDF). - - The ECDF shows the proportion of data points less than or equal to each - value, representing the empirical estimate of the cumulative distribution - function. - - Parameters - ---------- - axis : matplotlib.axes.Axes or AxisWrapper - Matplotlib axis or scitex axis wrapper to plot on. - values_1d : array-like, shape (n_samples,) - 1D array of values to compute and plot ECDF for. NaN values are automatically ignored. - **kwargs : dict - Additional arguments passed to plot function. - - Returns - ------- - axis : matplotlib.axes.Axes or AxisWrapper - The axes with the ECDF plot. - df : pd.DataFrame - DataFrame containing ECDF data with columns: - - x: sorted data values - - y: cumulative percentages (0-100) - - n: total number of data points - - x_step, y_step: step plot coordinates - - Examples - -------- - >>> import numpy as np - >>> import scitex as stx - >>> data = np.random.randn(100) - >>> fig, ax = stx.plt.subplots() - >>> ax, df = stx.plt.ax.stx_ecdf(ax, data) - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - # Flatten and remove NaN values - values_1d = np.hstack(values_1d) - - # Warnings - if np.isnan(values_1d).any(): - logger.warning("NaN values are ignored for ECDF plot.") - values_1d = values_1d[~np.isnan(values_1d)] - nn = len(values_1d) - - # Sort the data and compute the ECDF values - data_sorted = np.sort(values_1d) - ecdf_perc = 100 * np.arange(1, len(data_sorted) + 1) / len(data_sorted) - - # Create the pseudo x-axis for step plotting - x_step = np.repeat(data_sorted, 2)[1:] - y_step = np.repeat(ecdf_perc, 2)[:-1] - - # Apply default linewidth if not specified - if "linewidth" not in kwargs and "lw" not in kwargs: - kwargs["linewidth"] = mm_to_pt(DEFAULT_LINE_WIDTH_MM) - - # Add sample size to label if provided - if "label" in kwargs and kwargs["label"]: - kwargs["label"] = f"{kwargs['label']} ($n$={nn})" - - # Plot the ECDF using steps (no markers - clean line only) - axis.plot(x_step, y_step, drawstyle="steps-post", **kwargs) - - # Set ylim (xlim is auto-scaled based on data) - axis.set_ylim(0, 100) - - # Create a DataFrame to hold the ECDF data - df = scitex_pd_force_df( - { - "x": data_sorted, - "y": ecdf_perc, - "n": nn, - "x_step": x_step, - "y_step": y_step, - } - ) - - return axis, df - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_fillv.py b/src/scitex/plt/ax/_plot/_stx_fillv.py deleted file mode 100755 index dd5a8cf76..000000000 --- a/src/scitex/plt/ax/_plot/_stx_fillv.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-04-30 21:26:45 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot_fillv.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot_fillv.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import numpy as np - -from ....plt.utils import assert_valid_axis - - -def stx_fillv(axes, starts_1d, ends_1d, color="red", alpha=0.2): - """ - Fill between specified start and end intervals on an axis or array of axes. - - Parameters - ---------- - axes : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper or numpy.ndarray of axes - The axis object(s) to fill intervals on. - starts_1d : array-like, shape (n_regions,) - 1D array of start x-positions for vertical fill regions. - ends_1d : array-like, shape (n_regions,) - 1D array of end x-positions for vertical fill regions. - color : str, optional - The color to use for the filled regions. Default is "red". - alpha : float, optional - The alpha blending value, between 0 (transparent) and 1 (opaque). Default is 0.2. - - Returns - ------- - list - List of axes with filled intervals. - """ - - is_axes = isinstance(axes, np.ndarray) - - axes = axes if isinstance(axes, np.ndarray) else [axes] - - for ax in axes: - assert_valid_axis( - ax, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - for start, end in zip(starts_1d, ends_1d): - ax.axvspan(start, end, facecolor=color, edgecolor="none", alpha=alpha) - - if not is_axes: - return axes[0] - else: - return axes - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_heatmap.py b/src/scitex/plt/ax/_plot/_stx_heatmap.py deleted file mode 100755 index bc296e3e7..000000000 --- a/src/scitex/plt/ax/_plot/_stx_heatmap.py +++ /dev/null @@ -1,369 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 13:00:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_plot/_plot_heatmap.py - -"""Heatmap plotting with automatic annotation color switching.""" - -from typing import Any, List, Optional, Tuple, Union - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.axes import Axes -from matplotlib.colorbar import Colorbar -from matplotlib.image import AxesImage - - -def stx_heatmap( - ax: Union[Axes, "AxisWrapper"], - values_2d: np.ndarray, - x_labels: Optional[List[str]] = None, - y_labels: Optional[List[str]] = None, - cmap: str = "viridis", - cbar_label: str = "ColorBar Label", - annot_format: str = "{x:.1f}", - show_annot: bool = True, - annot_color_lighter: str = "black", - annot_color_darker: str = "white", - **kwargs: Any, -) -> Tuple[Union[Axes, "AxisWrapper"], AxesImage, Colorbar]: - """Plot a heatmap on the given axes with automatic annotation colors. - - Creates a heatmap visualization with optional cell annotations. Annotation - text colors are automatically switched based on background brightness for - optimal readability. - - Parameters - ---------- - ax : matplotlib.axes.Axes or AxisWrapper - The axes to plot on. - values_2d : np.ndarray, shape (n_rows, n_cols) - 2D array of data to display as heatmap. - x_labels : list of str, optional - Labels for the x-axis (columns). - y_labels : list of str, optional - Labels for the y-axis (rows). - cmap : str, default "viridis" - Colormap name to use. - cbar_label : str, default "ColorBar Label" - Label for the colorbar. - annot_format : str, default "{x:.1f}" - Format string for cell annotations. - show_annot : bool, default True - Whether to annotate the heatmap with values. - annot_color_lighter : str, default "black" - Text color for annotations on lighter backgrounds. - annot_color_darker : str, default "white" - Text color for annotations on darker backgrounds. - **kwargs : dict - Additional keyword arguments passed to imshow(). - - Returns - ------- - ax : matplotlib.axes.Axes or AxisWrapper - The axes with the heatmap. - im : matplotlib.image.AxesImage - The image object created by imshow. - cbar : matplotlib.colorbar.Colorbar - The colorbar object. - - Examples - -------- - >>> import numpy as np - >>> import scitex as stx - >>> data = np.random.rand(5, 10) - >>> fig, ax = stx.plt.subplots() - >>> ax, im, cbar = stx.plt.ax.stx_heatmap( - ... ax, data, - ... x_labels=[f"X{i}" for i in range(10)], - ... y_labels=[f"Y{i}" for i in range(5)], - ... cmap="Blues" - ... ) - """ - - im, cbar = _mpl_heatmap( - values_2d, - x_labels, - y_labels, - ax=ax, - cmap=cmap, - cbarlabel=cbar_label, - ) - - if show_annot: - textcolors = _switch_annot_colors(cmap, annot_color_lighter, annot_color_darker) - texts = _mpl_annotate_heatmap( - im, - valfmt=annot_format, - textcolors=textcolors, - ) - - return ax, im, cbar - - -def _switch_annot_colors( - cmap: str, - annot_color_lighter: str, - annot_color_darker: str, -) -> Tuple[str, str]: - """Determine annotation text colors based on colormap brightness. - - Uses perceived brightness (ITU-R BT.709) to select appropriate text - colors for light vs dark backgrounds in the colormap. - - Parameters - ---------- - cmap : str - Colormap name. - annot_color_lighter : str - Color to use on lighter backgrounds. - annot_color_darker : str - Color to use on darker backgrounds. - - Returns - ------- - tuple of str - (color_for_dark_bg, color_for_light_bg) text colors. - """ - cmap_obj = plt.cm.get_cmap(cmap) - - # Sample colormap at extremes (avoiding edge effects) - dark_color = cmap_obj(0.1) - light_color = cmap_obj(0.9) - - # Calculate perceived brightness using ITU-R BT.709 coefficients - dark_brightness = ( - 0.2126 * dark_color[0] + 0.7152 * dark_color[1] + 0.0722 * dark_color[2] - ) - - # Choose text colors based on background brightness - if dark_brightness < 0.5: - return (annot_color_lighter, annot_color_darker) - else: - return (annot_color_darker, annot_color_lighter) - - -def _mpl_heatmap( - data: np.ndarray, - row_labels: Optional[List[str]], - col_labels: Optional[List[str]], - ax: Optional[Axes] = None, - cbar_kw: Optional[dict] = None, - cbarlabel: str = "", - **kwargs: Any, -) -> Tuple[AxesImage, Colorbar]: - """Create a heatmap with imshow and add a colorbar. - - Parameters - ---------- - data : np.ndarray - 2D array of data to display. - row_labels : list of str or None - Labels for the rows (y-axis). - col_labels : list of str or None - Labels for the columns (x-axis). - ax : matplotlib.axes.Axes, optional - Axes to plot on. If None, uses current axes. - cbar_kw : dict, optional - Keyword arguments for colorbar creation. - cbarlabel : str, default "" - Label for the colorbar. - **kwargs : dict - Additional keyword arguments passed to imshow(). - - Returns - ------- - im : matplotlib.image.AxesImage - The image object. - cbar : matplotlib.colorbar.Colorbar - The colorbar object. - """ - - if ax is None: - ax = plt.gca() - - if cbar_kw is None: - cbar_kw = {} - - # Plot the heatmap - im = ax.imshow(data, **kwargs) - - # Create colorbar with proper formatting - cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) - cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") - - # Set colorbar border width to match axes spines - cbar.outline.set_linewidth(0.2 * 2.83465) # 0.2mm in points - - # Format colorbar ticks - from matplotlib.ticker import MaxNLocator - - cbar.ax.yaxis.set_major_locator(MaxNLocator(nbins=4, min_n_ticks=3)) - cbar.ax.tick_params(width=0.2 * 2.83465, length=0.8 * 2.83465) # Match tick styling - - # Show all ticks and label them with the respective list entries. - ax.set_xticks( - range(data.shape[1]), - labels=col_labels, - # rotation=45, - # ha="right", - # rotation_mode="anchor", - ) - ax.set_yticks(range(data.shape[0]), labels=row_labels) - - # Let the horizontal axes labeling appear on top. - ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True) - - # Show all 4 spines for heatmap - ax.spines[:].set_visible(True) - - # Set aspect ratio to 'equal' for square cells (1:1) - ax.set_aspect("equal", adjustable="box") - - ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True) - ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True) - ax.tick_params(which="minor", bottom=False, left=False) - - return im, cbar - - -def _calc_annot_fontsize(n_rows: int, n_cols: int) -> float: - """Calculate dynamic annotation font size based on cell count. - - Uses a base size of 6pt for small heatmaps and scales down for larger ones. - - Parameters - ---------- - n_rows : int - Number of rows in the heatmap. - n_cols : int - Number of columns in the heatmap. - - Returns - ------- - float - Font size in points. - """ - # Base font size for small heatmaps (e.g., 5x5) - BASE_FONTSIZE = 6.0 - BASE_CELLS = 5 # Reference dimension - - # Use the larger dimension to scale - max_dim = max(n_rows, n_cols) - - if max_dim <= BASE_CELLS: - return BASE_FONTSIZE - elif max_dim <= 10: - # Linear interpolation: 6pt at 5 cells, 5pt at 10 cells - return BASE_FONTSIZE - (max_dim - BASE_CELLS) * 0.2 - elif max_dim <= 20: - # 5pt at 10 cells, 4pt at 20 cells - return 5.0 - (max_dim - 10) * 0.1 - else: - # Minimum 3pt for very large heatmaps - return max(3.0, 4.0 - (max_dim - 20) * 0.05) - - -def _mpl_annotate_heatmap( - im: AxesImage, - data: Optional[np.ndarray] = None, - valfmt: str = "{x:.2f}", - textcolors: Tuple[str, str] = ("lightgray", "black"), - threshold: Optional[float] = None, - fontsize: Optional[float] = None, - **textkw: Any, -) -> List: - """Annotate a heatmap with cell values. - - Parameters - ---------- - im : matplotlib.image.AxesImage - The image to be annotated. - data : np.ndarray, optional - Data used to annotate. If None, uses the image's array. - valfmt : str, default "{x:.2f}" - Format string for the annotations. - textcolors : tuple of str, default ("lightgray", "black") - Colors for annotations. First color for values below threshold, - second for values above. - threshold : float, optional - Value in normalized colormap space (0 to 1) above which the - second color is used. If None, uses 0.7 * max(data). - fontsize : float, optional - Font size in points. If None, dynamically calculated based on - cell count (6pt base, scaling down for larger heatmaps). - **textkw : dict - Additional keyword arguments passed to ax.text(). - - Returns - ------- - texts : list of matplotlib.text.Text - The annotation text objects. - """ - - if not isinstance(data, (list, np.ndarray)): - data = im.get_array() - - # Calculate dynamic font size if not specified - if fontsize is None: - fontsize = _calc_annot_fontsize(data.shape[0], data.shape[1]) - - # Normalize the threshold to the images color range. - if threshold is not None: - threshold = im.norm(threshold) - else: - # Use 0.7 instead of 0.5 for better visibility with most colormaps - threshold = im.norm(data.max()) * 0.7 - - # Set default alignment to center, but allow it to be - # overwritten by textkw. - kw = dict( - horizontalalignment="center", verticalalignment="center", fontsize=fontsize - ) - kw.update(textkw) - - # Get the formatter in case a string is supplied - if isinstance(valfmt, str): - valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) - - # Loop over the data and create a `Text` for each "pixel". - # Change the text's color depending on the data. - texts = [] - for ii in range(data.shape[0]): - for jj in range(data.shape[1]): - kw.update(color=textcolors[int(im.norm(data[ii, jj]) > threshold)]) - text = im.axes.text(jj, ii, valfmt(data[ii, jj], None), **kw) - texts.append(text) - - return texts - - -if __name__ == "__main__": - import matplotlib - import matplotlib as mpl - import matplotlib.pyplot as plt - import numpy as np - - data = np.random.rand(5, 10) - x_labels = [f"X{ii + 1}" for ii in range(5)] - y_labels = [f"Y{ii + 1}" for ii in range(10)] - - fig, ax = plt.subplots() - - im, cbar = stx_heatmap( - ax, - data, - x_labels=x_labels, - y_labels=y_labels, - show_annot=True, - annot_color_lighter="white", - annot_color_darker="black", - cmap="Blues", - ) - - fig.tight_layout() - plt.show() - # EOF - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_image.py b/src/scitex/plt/ax/_plot/_stx_image.py deleted file mode 100755 index 75a0240a7..000000000 --- a/src/scitex/plt/ax/_plot/_stx_image.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-01 08:39:46 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_image2d.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_image2d.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib - -from scitex.plt.utils import assert_valid_axis - - -def stx_image( - ax, - arr_2d, - cbar=True, - cbar_label=None, - cbar_shrink=1.0, - cbar_fraction=0.046, - cbar_pad=0.04, - cmap="viridis", - aspect="auto", - vmin=None, - vmax=None, - **kwargs, -): - """ - Imshows an two-dimensional array with theese two conditions: - 1) The first dimension represents the x dim, from left to right. - 2) The second dimension represents the y dim, from bottom to top - - Parameters - ---------- - ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis to plot on - arr_2d : numpy.ndarray - The 2D array to display - cbar : bool, optional - Whether to show colorbar, by default True - cbar_label : str, optional - Label for the colorbar, by default None - cbar_shrink : float, optional - Shrink factor for the colorbar, by default 1.0 - cbar_fraction : float, optional - Fraction of original axes to use for colorbar, by default 0.046 - cbar_pad : float, optional - Padding between the image axes and colorbar axes, by default 0.04 - cmap : str, optional - Colormap name, by default "viridis" - aspect : str, optional - Aspect ratio adjustment, by default "auto" - vmin : float, optional - Minimum data value for colormap scaling, by default None - vmax : float, optional - Maximum data value for colormap scaling, by default None - **kwargs - Additional keyword arguments passed to ax.imshow() - - Returns - ------- - matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axis with the image plotted - """ - assert_valid_axis( - ax, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - assert arr_2d.ndim == 2, "Input array must be 2-dimensional" - - if kwargs.get("xyz"): - kwargs.pop("xyz") - - # Transposes arr_2d for correct orientation - arr_2d = arr_2d.T - - # Cals the original ax.imshow() method on the transposed array - im = ax.imshow(arr_2d, cmap=cmap, vmin=vmin, vmax=vmax, aspect=aspect, **kwargs) - - # Color bar - if cbar: - fig = ax.get_figure() - _cbar = fig.colorbar( - im, ax=ax, shrink=cbar_shrink, fraction=cbar_fraction, pad=cbar_pad - ) - if cbar_label: - _cbar.set_label(cbar_label) - - # Invert y-axis to match typical image orientation - ax.invert_yaxis() - - return ax - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_joyplot.py b/src/scitex/plt/ax/_plot/_stx_joyplot.py deleted file mode 100755 index 7513d1f54..000000000 --- a/src/scitex/plt/ax/_plot/_stx_joyplot.py +++ /dev/null @@ -1,134 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-05-02 09:03:23 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_joyplot.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_joyplot.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np -from scipy import stats - -from ....plt.utils import assert_valid_axis - - -def stx_joyplot( - ax, arrays, overlap=0.5, fill_alpha=0.7, line_alpha=1.0, colors=None, **kwargs -): - """ - Create a joyplot (ridgeline plot) on the provided axes. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The axes to plot on - arrays : list of array-like - List of 1D arrays for each ridge - overlap : float, default 0.5 - Amount of overlap between ridges (0 = no overlap, 1 = full overlap) - fill_alpha : float, default 0.7 - Alpha for the filled KDE area - line_alpha : float, default 1.0 - Alpha for the KDE line - colors : list, optional - Colors for each ridge. If None, uses scitex palette. - **kwargs - Additional keyword arguments - - Returns - ------- - matplotlib.axes.Axes - The axes with the joyplot - """ - assert_valid_axis( - ax, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - # Convert dict to list of arrays (values only) - if isinstance(arrays, dict): - arrays = list(arrays.values()) - - # Add sample size per distribution to label if provided (show range if variable) - if kwargs.get("label"): - n_per_dist = [len(arr) for arr in arrays] - n_min, n_max = min(n_per_dist), max(n_per_dist) - n_str = str(n_min) if n_min == n_max else f"{n_min}-{n_max}" - kwargs["label"] = f"{kwargs['label']} ($n$={n_str})" - - # Import scitex colors - from scitex.plt.color import HEX - - # Default colors from scitex palette - if colors is None: - colors = [ - HEX["blue"], - HEX["red"], - HEX["green"], - HEX["yellow"], - HEX["purple"], - HEX["orange"], - HEX["lightblue"], - HEX["pink"], - ] - - n_ridges = len(arrays) - - # Calculate global x range - all_data = np.concatenate([np.asarray(arr) for arr in arrays]) - x_min, x_max = np.min(all_data), np.max(all_data) - x_range = x_max - x_min - x_padding = x_range * 0.1 - x = np.linspace(x_min - x_padding, x_max + x_padding, 200) - - # Calculate KDEs and find max density for scaling - kdes = [] - max_density = 0 - for arr in arrays: - arr = np.asarray(arr) - if len(arr) > 1: - kde = stats.gaussian_kde(arr) - density = kde(x) - kdes.append(density) - max_density = max(max_density, np.max(density)) - else: - kdes.append(np.zeros_like(x)) - - # Scale factor for ridge height - ridge_height = 1.0 / (1.0 - overlap * 0.5) if overlap < 1 else 2.0 - - # Plot each ridge from back to front - for i in range(n_ridges - 1, -1, -1): - color = colors[i % len(colors)] - baseline = i * (1.0 - overlap) - - # Scale density to fit nicely - scaled_density = ( - kdes[i] / max_density * ridge_height if max_density > 0 else kdes[i] - ) - - # Fill - ax.fill_between( - x, - baseline, - baseline + scaled_density, - facecolor=color, - edgecolor="none", - alpha=fill_alpha, - ) - # Line on top - ax.plot( - x, baseline + scaled_density, color=color, alpha=line_alpha, linewidth=1.0 - ) - - # Set y limits - ax.set_ylim(-0.1, n_ridges * (1.0 - overlap) + ridge_height) - - # Hide y-axis ticks for cleaner look (joyplots typically don't show y values) - ax.set_yticks([]) - - return ax - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_raster.py b/src/scitex/plt/ax/_plot/_stx_raster.py deleted file mode 100755 index 21539f520..000000000 --- a/src/scitex/plt/ax/_plot/_stx_raster.py +++ /dev/null @@ -1,200 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-01 15:23:01 (ywatanabe)" -# File: /home/ywatanabe/proj/_scitex_repo/src/scitex/plt/ax/_plot/_plot_raster.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_raster.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from bisect import bisect_left - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from ....plt.utils import assert_valid_axis - - -def stx_raster( - ax, - spike_times_list, - time=None, - labels=None, - colors=None, - orientation="horizontal", - y_offset=None, - lineoffsets=None, - linelengths=None, - apply_set_n_ticks=True, - n_xticks=4, - n_yticks=None, - **kwargs, -): - """ - Create a raster plot using eventplot with custom labels and colors. - - Parameters - ---------- - ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axes on which to draw the raster plot. - spike_times_list : list of array-like, shape (n_trials,) where each element is (n_spikes,) - List of spike/event time arrays, one per trial/channel - time : array-like, optional - The time indices for the events (default: np.linspace(0, max(event_times))). - labels : list, optional - Labels for each channel/trial. - colors : list, optional - Colors for each channel/trial. - orientation: str, optional - Orientation of raster plot (default: horizontal). - y_offset : float, optional - Vertical spacing between trials/channels (default: 1.0). - lineoffsets : array-like, optional - Y-positions for each trial/channel (overrides automatic positioning). - linelengths : float, optional - Height of each spike mark (default: 0.8, slightly less than y_offset to prevent overlap). - apply_set_n_ticks : bool, optional - Whether to apply set_n_ticks for cleaner axis (default: True). - n_xticks : int, optional - Number of x-axis ticks (default: 4). - n_yticks : int or None, optional - Number of y-axis ticks (default: None, auto-determined). - **kwargs : dict - Additional keyword arguments for eventplot. - - Returns - ------- - ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axes with the raster plot. - df : pandas.DataFrame - DataFrame with time indices and channel events. - """ - assert_valid_axis( - ax, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - # Format spike_times_list data - spike_times_list = _ensure_list(spike_times_list) - - # Add sample size (number of trials) to label if provided - if kwargs.get("label"): - n_trials = len(spike_times_list) - kwargs["label"] = f"{kwargs['label']} ($n$={n_trials})" - - # Handle colors and labels - colors = _handle_colors(colors, spike_times_list) - - # Handle lineoffsets for positioning between trials/channels - if y_offset is None: - y_offset = 1.0 # Default spacing - if lineoffsets is None: - lineoffsets = np.arange(len(spike_times_list)) * y_offset - - # Set linelengths to prevent overlap (80% of y_offset by default) - if linelengths is None: - linelengths = y_offset * 0.8 - - # Ensure lineoffsets is iterable and matches spike_times_list length - if np.isscalar(lineoffsets): - lineoffsets = [lineoffsets] - if len(lineoffsets) < len(spike_times_list): - lineoffsets = list(lineoffsets) + list( - range(len(lineoffsets), len(spike_times_list)) - ) - - # Plotting as eventplot using spike_times_list with proper positioning - for ii, (pos, color, offset) in enumerate( - zip(spike_times_list, colors, lineoffsets) - ): - label = _define_label(labels, ii) - ax.eventplot( - pos, - lineoffsets=offset, - linelengths=linelengths, - orientation=orientation, - colors=color, - label=label, - **kwargs, - ) - - # Apply set_n_ticks for cleaner axes if requested - if apply_set_n_ticks: - from scitex.plt.ax._style._set_n_ticks import set_n_ticks - - # For categorical y-axis (trials/channels), use appropriate tick count - if n_yticks is None: - n_yticks = min(len(spike_times_list), 8) # Max 8 ticks for readability - - # Only apply if we have reasonable numeric ranges - try: - x_range = ax.get_xlim() - y_range = ax.get_ylim() - - # Apply x-ticks if we have a reasonable numeric range - if x_range[1] - x_range[0] > 0: - set_n_ticks(ax, n_xticks=n_xticks, n_yticks=None) - - # Apply y-ticks only if we don't have categorical labels - if labels is None and y_range[1] - y_range[0] > 0: - set_n_ticks(ax, n_xticks=None, n_yticks=n_yticks) - - except Exception: - # Skip set_n_ticks if there are issues (e.g., categorical data) - pass - - # Legend - if labels is not None: - ax.legend() - - # Return spike_times in a useful format - spike_times_digital_df = _event_times_to_digital_df( - spike_times_list, time, lineoffsets - ) - - return ax, spike_times_digital_df - - -def _ensure_list(event_times): - return [[pos] if isinstance(pos, (int, float)) else pos for pos in event_times] - - -def _define_label(labels, ii): - if (labels is not None) and (ii < len(labels)): - return labels[ii] - else: - return None - - -def _handle_colors(colors, event_times_list): - if colors is None: - colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] - if len(colors) < len(event_times_list): - colors = colors * (len(event_times_list) // len(colors) + 1) - return colors - - -def _event_times_to_digital_df(event_times_list, time, lineoffsets=None): - if time is None: - time = np.linspace(0, np.max([np.max(pos) for pos in event_times_list]), 1000) - - digi = np.full((len(event_times_list), len(time)), np.nan, dtype=float) - - for i_ch, posis_ch in enumerate(event_times_list): - for posi_ch in posis_ch: - i_insert = bisect_left(time, posi_ch) - if i_insert == len(time): - i_insert -= 1 - # Use lineoffset position if available, otherwise use channel index - if lineoffsets is not None and i_ch < len(lineoffsets): - digi[i_ch, i_insert] = lineoffsets[i_ch] - else: - digi[i_ch, i_insert] = i_ch - - return pd.DataFrame(digi.T, index=time) - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_rectangle.py b/src/scitex/plt/ax/_plot/_stx_rectangle.py deleted file mode 100755 index d8db20a63..000000000 --- a/src/scitex/plt/ax/_plot/_stx_rectangle.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-01 08:45:44 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_rectangle.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_rectangle.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from matplotlib.patches import Rectangle - - -def stx_rectangle(ax, xx, yy, ww, hh, **kwargs): - """Add a rectangle patch to an axes. - - Convenience function for adding rectangular patches to plots, useful for - highlighting regions, creating box annotations, or drawing geometric shapes. - By default, rectangles have no edge (border) for cleaner publication figures. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The axes to add the rectangle to. - xx : float - X-coordinate of the rectangle's bottom-left corner. - yy : float - Y-coordinate of the rectangle's bottom-left corner. - ww : float - Width of the rectangle. - hh : float - Height of the rectangle. - **kwargs : dict - Additional keyword arguments passed to matplotlib.patches.Rectangle. - Common options include: - - facecolor/fc : fill color - - edgecolor/ec : edge color (default: 'none') - - linewidth/lw : edge line width - - alpha : transparency (0-1) - - linestyle/ls : edge line style - - Returns - ------- - matplotlib.axes.Axes - The axes with the rectangle added. - - Examples - -------- - >>> fig, ax = plt.subplots() - >>> ax.plot([0, 10], [0, 10]) - >>> # Highlight a region (no border by default) - >>> stx_rectangle(ax, 2, 3, 4, 3, facecolor='yellow', alpha=0.3) - - >>> # Draw a box with explicit edge - >>> stx_rectangle(ax, 5, 5, 2, 2, facecolor='none', edgecolor='red', linewidth=2) - - See Also - -------- - matplotlib.patches.Rectangle : The underlying Rectangle class - matplotlib.axes.Axes.add_patch : Method used to add the patch - """ - # Default to no edge for cleaner publication figures - if "edgecolor" not in kwargs and "ec" not in kwargs: - kwargs["edgecolor"] = "none" - ax.add_patch(Rectangle((xx, yy), ww, hh, **kwargs)) - return ax - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_scatter_hist.py b/src/scitex/plt/ax/_plot/_stx_scatter_hist.py deleted file mode 100755 index e220d57a2..000000000 --- a/src/scitex/plt/ax/_plot/_stx_scatter_hist.py +++ /dev/null @@ -1,133 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 18:14:56 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_scatter_hist.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_scatter_hist.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import numpy as np - - -def stx_scatter_hist( - ax, - x, - y, - fig=None, - hist_bins: int = 20, - scatter_alpha: float = 0.6, - scatter_size: float = 20, - scatter_color: str = "blue", - hist_color_x: str = "blue", - hist_color_y: str = "red", - hist_alpha: float = 0.5, - scatter_ratio: float = 0.8, - **kwargs, -): - """ - Plot a scatter plot with histograms on the x and y axes. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The main scatter plot axes - x : array-like - x data for scatter plot and histogram - y : array-like - y data for scatter plot and histogram - fig : matplotlib.figure.Figure, optional - Figure to create axes in. If None, uses ax.figure - hist_bins : int, optional - Number of bins for histograms, default 20 - scatter_alpha : float, optional - Alpha value for scatter points, default 0.6 - scatter_size : float, optional - Size of scatter points, default 20 - scatter_color : str, optional - Color of scatter points, default "blue" - hist_color_x : str, optional - Color of x-axis histogram, default "blue" - hist_color_y : str, optional - Color of y-axis histogram, default "red" - hist_alpha : float, optional - Alpha value for histograms, default 0.5 - scatter_ratio : float, optional - Ratio of main plot to histograms, default 0.8 - **kwargs - Additional keyword arguments passed to scatter and hist functions - - Returns - ------- - tuple - (ax, ax_histx, ax_histy, hist_data) - All axes objects and histogram data - hist_data is a dictionary containing histogram counts and bin edges - """ - # Get the current figure if not provided - if fig is None: - fig = ax.figure - - # Calculate the positions based on scatter_ratio - margin = 0.1 * (1 - scatter_ratio) - hist_size = 0.2 * scatter_ratio - - # Create the histogram axes - ax_histx = fig.add_axes( - [ - ax.get_position().x0, - ax.get_position().y1 + margin, - ax.get_position().width * scatter_ratio, - hist_size, - ] - ) - ax_histy = fig.add_axes( - [ - ax.get_position().x1 + margin, - ax.get_position().y0, - hist_size, - ax.get_position().height * scatter_ratio, - ] - ) - - # No labels for histograms - ax_histx.tick_params(axis="x", labelbottom=False) - ax_histy.tick_params(axis="y", labelleft=False) - - # The scatter plot - ax.scatter( - x, - y, - alpha=scatter_alpha, - s=scatter_size, - color=scatter_color, - **kwargs, - ) - - # Calculate histogram data - hist_x, bin_edges_x = np.histogram(x, bins=hist_bins) - hist_y, bin_edges_y = np.histogram(y, bins=hist_bins) - - # Plot histograms - ax_histx.hist(x, bins=hist_bins, color=hist_color_x, alpha=hist_alpha) - ax_histy.hist( - y, - bins=hist_bins, - orientation="horizontal", - color=hist_color_y, - alpha=hist_alpha, - ) - - # Create return data structure - hist_data = { - "hist_x": hist_x, - "hist_y": hist_y, - "bin_edges_x": bin_edges_x, - "bin_edges_y": bin_edges_y, - } - - return ax, ax_histx, ax_histy, hist_data - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_shaded_line.py b/src/scitex/plt/ax/_plot/_stx_shaded_line.py deleted file mode 100755 index 56346df60..000000000 --- a/src/scitex/plt/ax/_plot/_stx_shaded_line.py +++ /dev/null @@ -1,220 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 13:15:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_plot/_plot_shaded_line.py - -"""Line plots with shaded uncertainty regions (e.g., confidence intervals).""" - -from typing import Any, List, Optional, Tuple, Union - -import numpy as np -import pandas as pd -from matplotlib.axes import Axes - -from scitex.types import ColorLike - -from ....plt.utils import assert_valid_axis - - -def _plot_single_shaded_line( - axis: Union[Axes, "AxisWrapper"], - xx: np.ndarray, - y_lower: np.ndarray, - y_middle: np.ndarray, - y_upper: np.ndarray, - color: Optional[ColorLike] = None, - alpha: float = 0.3, - **kwargs: Any, -) -> Tuple[Union[Axes, "AxisWrapper"], pd.DataFrame]: - """Plot a single line with shaded area between y_lower and y_upper bounds. - - Parameters - ---------- - axis : matplotlib.axes.Axes or AxisWrapper - Axes to plot on. - xx : np.ndarray - X values. - y_lower : np.ndarray - Lower bound y values. - y_middle : np.ndarray - Middle (mean/median) y values. - y_upper : np.ndarray - Upper bound y values. - color : ColorLike, optional - Color for line and fill. - alpha : float, default 0.3 - Transparency for shaded region. - **kwargs : dict - Additional keyword arguments passed to plot(). - - Returns - ------- - axis : matplotlib.axes.Axes or AxisWrapper - The axes with the plot. - df : pd.DataFrame - DataFrame with x, y_lower, y_middle, y_upper columns. - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - assert ( - len(xx) == len(y_middle) == len(y_lower) == len(y_upper) - ), "All arrays must have the same length" - - label = kwargs.pop("label", None) - axis.plot(xx, y_middle, color=color, alpha=alpha, label=label, **kwargs) - kwargs["linewidth"] = 0 - kwargs["edgecolor"] = "none" # Remove edge line - axis.fill_between(xx, y_lower, y_upper, alpha=alpha, color=color, **kwargs) - - return axis, pd.DataFrame( - {"x": xx, "y_lower": y_lower, "y_middle": y_middle, "y_upper": y_upper} - ) - - -def _plot_shaded_line( - axis: Union[Axes, "AxisWrapper"], - xs: List[np.ndarray], - ys_lower: List[np.ndarray], - ys_middle: List[np.ndarray], - ys_upper: List[np.ndarray], - color: Optional[Union[List[ColorLike], ColorLike]] = None, - **kwargs: Any, -) -> Tuple[Union[Axes, "AxisWrapper"], List[pd.DataFrame]]: - """Plot multiple lines with shaded areas between ys_lower and ys_upper bounds. - - Parameters - ---------- - axis : matplotlib.axes.Axes or AxisWrapper - Axes to plot on. - xs : list of np.ndarray - List of x value arrays. - ys_lower : list of np.ndarray - List of lower bound y value arrays. - ys_middle : list of np.ndarray - List of middle y value arrays. - ys_upper : list of np.ndarray - List of upper bound y value arrays. - color : ColorLike or list of ColorLike, optional - Color(s) for lines and fills. - **kwargs : dict - Additional keyword arguments passed to plot(). - - Returns - ------- - axis : matplotlib.axes.Axes or AxisWrapper - The axes with the plots. - results : list of pd.DataFrame - List of DataFrames with plot data. - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - assert ( - len(xs) == len(ys_lower) == len(ys_middle) == len(ys_upper) - ), "All input lists must have the same length" - - results = [] - colors = color - color_list = colors - - if colors is not None: - if not isinstance(colors, list): - color_list = [colors] * len(xs) - else: - assert len(colors) == len(xs), "Number of colors must match number of lines" - color_list = colors - - for idx, (xx, y_lower, y_middle, y_upper) in enumerate( - zip(xs, ys_lower, ys_middle, ys_upper) - ): - this_kwargs = kwargs.copy() - this_kwargs["color"] = color_list[idx] - _, result_df = _plot_single_shaded_line( - axis, xx, y_lower, y_middle, y_upper, **this_kwargs - ) - results.append(result_df) - else: - for xx, y_lower, y_middle, y_upper in zip(xs, ys_lower, ys_middle, ys_upper): - _, result_df = _plot_single_shaded_line( - axis, xx, y_lower, y_middle, y_upper, **kwargs - ) - results.append(result_df) - - return axis, results - - -def stx_shaded_line( - axis: Union[Axes, "AxisWrapper"], - xs: Union[np.ndarray, List[np.ndarray]], - ys_lower: Union[np.ndarray, List[np.ndarray]], - ys_middle: Union[np.ndarray, List[np.ndarray]], - ys_upper: Union[np.ndarray, List[np.ndarray]], - color: Optional[Union[ColorLike, List[ColorLike]]] = None, - **kwargs: Any, -) -> Tuple[Union[Axes, "AxisWrapper"], Union[pd.DataFrame, List[pd.DataFrame]]]: - """Plot line(s) with shaded uncertainty regions. - - Automatically handles both single and multiple line cases. Useful for - plotting mean/median with confidence intervals or standard deviation bands. - - Parameters - ---------- - axis : matplotlib.axes.Axes or AxisWrapper - Axes to plot on. - xs : np.ndarray or list of np.ndarray - X values (single array or list of arrays for multiple lines). - ys_lower : np.ndarray or list of np.ndarray - Lower bound y values. - ys_middle : np.ndarray or list of np.ndarray - Middle (mean/median) y values. - ys_upper : np.ndarray or list of np.ndarray - Upper bound y values. - color : ColorLike or list of ColorLike, optional - Color(s) for lines and shaded regions. - **kwargs : dict - Additional keyword arguments passed to plot(). - - Returns - ------- - axis : matplotlib.axes.Axes or AxisWrapper - The axes with the plot(s). - data : pd.DataFrame or list of pd.DataFrame - DataFrame(s) containing plot data with columns: - x, y_lower, y_middle, y_upper. - - Examples - -------- - >>> import numpy as np - >>> import scitex as stx - >>> x = np.linspace(0, 10, 100) - >>> y_mean = np.sin(x) - >>> y_std = 0.2 - >>> fig, ax = stx.plt.subplots() - >>> ax, df = stx.plt.ax.stx_shaded_line( - ... ax, x, y_mean - y_std, y_mean, y_mean + y_std, - ... color='blue', alpha=0.3 - ... ) - """ - is_single = not ( - isinstance(xs, list) - and isinstance(ys_lower, list) - and isinstance(ys_middle, list) - and isinstance(ys_upper, list) - ) - - if is_single: - assert ( - len(xs) == len(ys_lower) == len(ys_middle) == len(ys_upper) - ), "All arrays must have the same length for single line plot" - - return _plot_single_shaded_line( - axis, xs, ys_lower, ys_middle, ys_upper, color=color, **kwargs - ) - else: - return _plot_shaded_line( - axis, xs, ys_lower, ys_middle, ys_upper, color=color, **kwargs - ) - - -# EOF diff --git a/src/scitex/plt/ax/_plot/_stx_violin.py b/src/scitex/plt/ax/_plot/_stx_violin.py deleted file mode 100755 index 4e27cf98a..000000000 --- a/src/scitex/plt/ax/_plot/_stx_violin.py +++ /dev/null @@ -1,353 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 22:01:54 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_violin.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_plot/_plot_violin.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns - -from ....plt.utils import assert_valid_axis - - -def stx_violin( - ax, - values_list, - labels=None, - colors=None, - half=False, - **kwargs, -): - """ - Plot a violin plot using seaborn. - - Parameters - ---------- - ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axes to plot on - values_list : list of array-like, shape (n_groups,) where each element is (n_samples,) - List of 1D arrays to plot as violins, one per group - labels : list, optional - Labels for each array in values_list - colors : list, optional - Colors for each violin - half : bool, optional - If True, plots only the left half of the violins, default False - **kwargs - Additional keyword arguments passed to seaborn.violinplot - - Returns - ------- - ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axes object with the plot - """ - # Add sample size to label if provided (show range if variable) - if kwargs.get("label"): - n_per_group = [len(g) for g in values_list] - n_min, n_max = min(n_per_group), max(n_per_group) - n_str = str(n_min) if n_min == n_max else f"{n_min}-{n_max}" - kwargs["label"] = f"{kwargs['label']} ($n$={n_str})" - - # Convert list-style data to DataFrame - all_values = [] - all_groups = [] - - for idx, values in enumerate(values_list): - all_values.extend(values) - group_label = labels[idx] if labels and idx < len(labels) else f"x {idx}" - all_groups.extend([group_label] * len(values)) - - # Create DataFrame - df = pd.DataFrame({"x": all_groups, "y": all_values}) - - # Setup colors if provided - if colors: - if isinstance(colors, list): - kwargs["palette"] = { - group: color - for group, color in zip(set(all_groups), colors[: len(set(all_groups))]) - } - else: - kwargs["palette"] = colors - - # Call seaborn-based function - return sns_plot_violin(ax, data=df, x="x", y="y", hue="x", half=half, **kwargs) - - -def sns_plot_violin(ax, data=None, x=None, y=None, hue=None, half=False, **kwargs): - """ - Plot a violin plot with option for half violins. - Parameters - ---------- - ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axes to plot on - data : DataFrame - The dataframe containing the data - x : str - Column name for x-axis variable - y : str - Column name for y-axis variable - hue : str, optional - Column name for hue variable - half : bool, optional - If True, plots only the left half of the violins, default False - **kwargs - Additional keyword arguments passed to seaborn.violinplot - Returns - ------- - ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper - The axes object with the plot - """ - assert_valid_axis( - ax, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - if not half: - # Standard violin plot - return sns.violinplot(data=data, x=x, y=y, hue=hue, ax=ax, **kwargs) - - # Create a copy of the dataframe to avoid modifying the original - df = data.copy() - - # If no hue provided, create default hue - if hue is None: - df["_hue"] = "default" - hue = "_hue" - - # Add fake hue for the right side - df["_fake_hue"] = df[hue] + "_right" - - # Adjust hue_order and palette if provided - if "hue_order" in kwargs: - kwargs["hue_order"] = kwargs["hue_order"] + [ - h + "_right" for h in kwargs["hue_order"] - ] - else: - kwargs["hue_order"] = [] - for group in df[x].unique().tolist(): - kwargs["hue_order"].append(group) - kwargs["hue_order"].append(group + "_right") - - if "palette" in kwargs: - palette = kwargs["palette"] - if isinstance(palette, dict): - kwargs["palette"] = { - **palette, - **{k + "_right": v for k, v in palette.items()}, - } - elif isinstance(palette, list): - kwargs["palette"] = palette + palette - - # Conc left and right - df_left = df[[x, y]] - df_right = df[["_fake_hue", y]].rename(columns={"_fake_hue": x}) - df_right[y] = [np.nan for _ in range(len(df_right))] - df_conc = pd.concat([df_left, df_right], axis=0, ignore_index=True) - df_conc = df_conc.sort_values(x) - - # Plot - sns.violinplot(data=df_conc, x=x, y=y, hue="x", split=True, ax=ax, **kwargs) - - # Remove right half of violins - for collection in ax.collections: - if isinstance(collection, plt.matplotlib.collections.PolyCollection): - collection.set_clip_path(None) - - # Adjust legend - if ax.legend_ is not None: - handles, labels = ax.get_legend_handles_labels() - ax.legend(handles[: len(handles) // 2], labels[: len(labels) // 2]) - - return ax - - -# def _plot_half_violin(ax, data=None, x=None, y=None, hue=None, **kwargs): - -# assert isinstance( -# ax, matplotlib.axes._axes.Axes -# ), "First argument must be a matplotlib axis" - -# # Prepare data -# df = data.copy() -# if hue is None: -# df["_hue"] = "default" -# hue = "_hue" - -# # Add fake hue for the right side -# df["_fake_hue"] = df[hue] + "_right" - -# # Adjust hue_order and palette if provided -# if "hue_order" in kwargs: -# kwargs["hue_order"] = kwargs["hue_order"] + [ -# h + "_right" for h in kwargs["hue_order"] -# ] - -# if "palette" in kwargs: -# palette = kwargs["palette"] -# if isinstance(palette, dict): -# kwargs["palette"] = { -# **palette, -# **{k + "_right": v for k, v in palette.items()}, -# } -# elif isinstance(palette, list): -# kwargs["palette"] = palette + palette - -# # Plot -# sns.violinplot( -# data=df, x=x, y=y, hue="_fake_hue", split=True, ax=ax, **kwargs -# ) - -# # Remove right half of violins -# for collection in ax.collections: -# if isinstance(collection, plt.matplotlib.collections.PolyCollection): -# collection.set_clip_path(None) - -# # Adjust legend -# if ax.legend_ is not None: -# handles, labels = ax.get_legend_handles_labels() -# ax.legend(handles[: len(handles) // 2], labels[: len(labels) // 2]) - -# return ax - -# import matplotlib -# import matplotlib.pyplot as plt -# import seaborn as sns - -# def plot_violin_half(ax, data=None, x=None, y=None, hue=None, **kwargs): -# """ -# Plot a half violin plot (showing only the left side of violins). - -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The axes to plot on -# data : DataFrame -# The dataframe containing the data -# x : str -# Column name for x-axis variable -# y : str -# Column name for y-axis variable -# hue : str, optional -# Column name for hue variable -# **kwargs -# Additional keyword arguments passed to seaborn.violinplot - -# Returns -# ------- -# ax : matplotlib.axes.Axes -# The axes object with the plot -# """ -# assert isinstance( -# ax, matplotlib.axes._axes.Axes -# ), "First argument must be a matplotlib axis" - -# # Prepare data -# df = data.copy() -# if hue is None: -# df["_hue"] = "default" -# hue = "_hue" - -# # Add fake hue for the right side -# df["_fake_hue"] = df[hue] + "_right" - -# # Adjust hue_order and palette if provided -# if "hue_order" in kwargs: -# kwargs["hue_order"] = kwargs["hue_order"] + [ -# h + "_right" for h in kwargs["hue_order"] -# ] -# if "palette" in kwargs: -# palette = kwargs["palette"] -# if isinstance(palette, dict): -# kwargs["palette"] = { -# **palette, -# **{k + "_right": v for k, v in palette.items()}, -# } -# elif isinstance(palette, list): -# kwargs["palette"] = palette + palette - -# # Plot -# sns.violinplot( -# data=df, x=x, y=y, hue="_fake_hue", split=True, ax=ax, **kwargs -# ) - -# # Remove right half of violins -# for collection in ax.collections: -# if isinstance(collection, matplotlib.collections.PolyCollection): -# collection.set_clip_path(None) - -# # Adjust legend -# if ax.legend_ is not None: -# handles, labels = ax.get_legend_handles_labels() -# ax.legend(handles[: len(handles) // 2], labels[: len(labels) // 2]) - -# return ax - - -## Probably working -def half_violin(ax, data=None, x=None, y=None, hue=None, **kwargs): - # Prepare data - df = data.copy() - if hue is None: - df["_hue"] = "default" - hue = "_hue" - - # Add fake hue for the right side - df["_fake_hue"] = df[hue] + "_right" - - # Adjust hue_order and palette if provided - if "hue_order" in kwargs: - kwargs["hue_order"] = kwargs["hue_order"] + [ - h + "_right" for h in kwargs["hue_order"] - ] - - if "palette" in kwargs: - palette = kwargs["palette"] - if isinstance(palette, dict): - kwargs["palette"] = { - **palette, - **{k + "_right": v for k, v in palette.items()}, - } - elif isinstance(palette, list): - kwargs["palette"] = palette + palette - - # Plot - sns.violinplot(data=df, x=x, y=y, hue="_fake_hue", split=True, ax=ax, **kwargs) - - # Remove right half of violins - for collection in ax.collections: - if isinstance(collection, plt.matplotlib.collections.PolyCollection): - collection.set_clip_path(None) - - # Adjust legend - if ax.legend_ is not None: - handles, labels = ax.get_legend_handles_labels() - ax.legend(handles[: len(handles) // 2], labels[: len(labels) // 2]) - - return ax - - -# import scitex -# import numpy as np -# fig, ax = scitex.plt.subplots() -# # Test with list data -# data_list = [ -# np.random.normal(0, 1, 100), -# np.random.normal(2, 1.5, 100), -# np.random.normal(5, 0.8, 100), -# ] -# labels = ["x A", "x B", "x C"] -# colors = ["red", "blue", "green"] -# half = True -# ax = half_violin( -# ax, data_list, x="" -# ) - -# EOF diff --git a/src/scitex/plt/ax/_style/__init__.py b/src/scitex/plt/ax/_style/__init__.py deleted file mode 100755 index 8e86d5213..000000000 --- a/src/scitex/plt/ax/_style/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:00:59 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/__init__.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_style/__init__.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from ._add_marginal_ax import add_marginal_ax -from ._add_panel import add_panel -from ._auto_scale_axis import auto_scale_axis -from ._extend import extend -from ._force_aspect import force_aspect -from ._format_label import format_label -from ._hide_spines import hide_spines -from ._map_ticks import map_ticks -from ._rotate_labels import rotate_labels -from ._sci_note import OOMFormatter, sci_note -from ._set_meta import export_metadata_yaml, set_figure_meta, set_meta -from ._set_n_ticks import set_n_ticks -from ._set_size import set_size -from ._set_supxyt import set_supxyt, set_supxytc -from ._set_ticks import set_ticks -from ._set_xyt import set_xyt, set_xytc -from ._share_axes import ( - get_global_xlim, - get_global_ylim, - set_xlims, - set_ylims, - sharex, - sharexy, - sharey, -) -from ._shift import shift -from ._show_spines import show_spines -from ._style_boxplot import style_boxplot -from ._style_violinplot import style_violinplot - -# EOF diff --git a/src/scitex/plt/ax/_style/_add_marginal_ax.py b/src/scitex/plt/ax/_style/_add_marginal_ax.py deleted file mode 100755 index f30d9e9e9..000000000 --- a/src/scitex/plt/ax/_style/_add_marginal_ax.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-04-30 20:18:52 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_add_marginal_ax.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_add_marginal_ax.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -from mpl_toolkits.axes_grid1 import make_axes_locatable - -from ....plt.utils import assert_valid_axis - - -def add_marginal_ax(axis, place, size=0.2, pad=0.1): - """ - Add a marginal axis to the specified side of an existing axis. - - Arguments: - axis (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The axis to which a marginal axis will be added. - place (str): Where to place the marginal axis ('top', 'right', 'bottom', or 'left'). - size (float, optional): Fractional size of the marginal axis relative to the main axis. Defaults to 0.2. - pad (float, optional): Padding between the axes. Defaults to 0.1. - - Returns: - matplotlib.axes.Axes: The newly created marginal axis. - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - divider = make_axes_locatable(axis) - - size_perc_str = f"{size * 100}%" - if place in ["left", "right"]: - size = 1.0 / size - - axis_marginal = divider.append_axes(place, size=size_perc_str, pad=pad) - axis_marginal.set_box_aspect(size) - - return axis_marginal - - -# EOF diff --git a/src/scitex/plt/ax/_style/_add_panel.py b/src/scitex/plt/ax/_style/_add_panel.py deleted file mode 100755 index 3353fb54d..000000000 --- a/src/scitex/plt/ax/_style/_add_panel.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-04-30 21:24:49 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_panel.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_panel.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -# Time-stamp: "2024-02-03 15:34:08 (ywatanabe)" - -import matplotlib.pyplot as plt - -from scitex.decorators import deprecated - - -def add_panel(tgt_width_mm=40, tgt_height_mm=None): - """Creates a fixed-size ax figure for panels.""" - - H_TO_W_RATIO = 0.7 - MM_TO_INCH_FACTOR = 1 / 25.4 - - if tgt_height_mm is None: - tgt_height_mm = H_TO_W_RATIO * tgt_width_mm - - # Convert target dimensions from millimeters to inches - tgt_width_in = tgt_width_mm * MM_TO_INCH_FACTOR - tgt_height_in = tgt_height_mm * MM_TO_INCH_FACTOR - - # Create a figure with the specified dimensions - fig = plt.figure(figsize=(tgt_width_in * 2, tgt_height_in * 2)) - - # Calculate the position and size of the axes in figure units (0 to 1) - left = (fig.get_figwidth() - tgt_width_in) / 2 / fig.get_figwidth() - bottom = (fig.get_figheight() - tgt_height_in) / 2 / fig.get_figheight() - ax = fig.add_axes( - [ - left, - bottom, - tgt_width_in / fig.get_figwidth(), - tgt_height_in / fig.get_figheight(), - ] - ) - - return fig, ax - - -@deprecated("Use add_panel instead") -def panel(tgt_width_mm=40, tgt_height_mm=None): - """Create a figure panel with specified dimensions (deprecated). - - This function is deprecated and maintained only for backward compatibility. - Please use `add_panel` instead. - - Parameters - ---------- - tgt_width_mm : float, optional - Target width in millimeters. Default is 40. - tgt_height_mm : float or None, optional - Target height in millimeters. If None, uses golden ratio. - Default is None. - - Returns - ------- - tuple - (fig, ax) - matplotlib figure and axes objects - - See Also - -------- - add_panel : The recommended function to use instead - - Examples - -------- - >>> # Deprecated usage - >>> fig, ax = panel(tgt_width_mm=40, tgt_height_mm=30) - - >>> # Recommended alternative - >>> fig, ax = add_panel(tgt_width_mm=40, tgt_height_mm=30) - """ - return add_panel(tgt_width_mm=40, tgt_height_mm=None) - - -if __name__ == "__main__": - # Example usage: - fig, ax = panel(tgt_width_mm=40, tgt_height_mm=40 * 0.7) - ax.plot([1, 2, 3], [4, 5, 6]) - ax.scatter([1, 2, 3], [4, 5, 6]) - # ... compatible with other ax plotting methods as well - plt.show() - -# EOF diff --git a/src/scitex/plt/ax/_style/_auto_scale_axis.py b/src/scitex/plt/ax/_style/_auto_scale_axis.py deleted file mode 100755 index adc77fea5..000000000 --- a/src/scitex/plt/ax/_style/_auto_scale_axis.py +++ /dev/null @@ -1,200 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2025-11-19 18:45:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_auto_scale_axis.py - -""" -Automatic axis scaling to factor out common powers of 10. - -This utility automatically detects when axis tick values are very small or very large -and factors out the appropriate power of 10, updating both the tick labels and axis label. - -Examples: - 0.0000, 0.0008, 0.0016, 0.0024 → 0, 0.8, 1.6, 2.4 with label "[×10⁻³]" - 10000, 20000, 30000, 40000 → 10, 20, 30, 40 with label "[×10³]" -""" - -from typing import Optional, Tuple - -import numpy as np - - -def detect_scale_factor( - values: np.ndarray, threshold: float = 1e-2 -) -> Tuple[int, bool]: - """ - Detect appropriate power of 10 to factor out from axis values. - - Parameters - ---------- - values : np.ndarray - Array of tick values on the axis - threshold : float - Threshold below which we consider factoring out (default: 0.01) - - Returns - ------- - power : int - Power of 10 to factor out (e.g., -3 for values like 0.001-0.009) - should_scale : bool - Whether scaling should be applied - - Examples - -------- - >>> detect_scale_factor(np.array([0.0, 0.0008, 0.0016, 0.0024])) - (-3, True) - >>> detect_scale_factor(np.array([10000, 20000, 30000])) - (3, True) - >>> detect_scale_factor(np.array([0, 1, 2, 3])) - (0, False) - """ - # Filter out zero values for calculation - nonzero_values = values[values != 0] - - if len(nonzero_values) == 0: - return 0, False - - # Get the order of magnitude of the maximum absolute value - max_abs = np.max(np.abs(nonzero_values)) - - # Check if values are very small (< threshold) or very large (> 1/threshold) - if max_abs < threshold: - # Values are very small - factor out negative power - power = int(np.floor(np.log10(max_abs))) - return power, True - elif max_abs > 1.0 / threshold: - # Values are very large - factor out positive power - power = int(np.floor(np.log10(max_abs))) - # Only scale if power >= 3 (thousands or larger) - if power >= 3: - return power, True - - return 0, False - - -def format_scale_factor(power: int) -> str: - """ - Format the scale factor for display in axis label. - - Parameters - ---------- - power : int - Power of 10 (e.g., -3, 3, 6) - - Returns - ------- - str - Formatted string using matplotlib mathtext (e.g., "×10$^{-3}$", "×10$^{6}$") - - Examples - -------- - >>> format_scale_factor(-3) - '×10$^{-3}$' - >>> format_scale_factor(6) - '×10$^{6}$' - """ - if power == 0: - return "" - - # Use matplotlib's mathtext for reliable rendering across all formats - return f"×10$^{{{power}}}$" - - -def auto_scale_axis(ax, axis: str = "both", threshold: float = 1e-2) -> None: - """ - Automatically scale axis to factor out common powers of 10. - - This function: - 1. Detects when tick values are very small or very large - 2. Factors out the appropriate power of 10 - 3. Updates tick labels to show factored values - 4. Appends the scale factor to the axis label - - Parameters - ---------- - ax : matplotlib.axes.Axes - Axes object to apply scaling to - axis : str, optional - Which axis to scale: 'x', 'y', or 'both' (default: 'both') - threshold : float, optional - Threshold for triggering scaling (default: 1e-2) - Values with max < threshold or max > 1/threshold will be scaled - - Examples - -------- - >>> import matplotlib.pyplot as plt - >>> fig, ax = plt.subplots() - >>> ax.plot([0, 1, 2], [0.0001, 0.0002, 0.0003]) - >>> ax.set_ylabel('Density') - >>> auto_scale_axis(ax, axis='y') - >>> # Y-axis now shows: 0.1, 0.2, 0.3 with label "Density [×10⁻³]" - - Notes - ----- - - Only scales if the range of values justifies it (very small or very large) - - Preserves the original axis label and appends the scale factor - - Uses Unicode superscripts for clean display (×10⁻³, ×10⁶, etc.) - """ - import matplotlib.ticker as ticker - - def scale_axis_impl(ax_obj, is_x_axis: bool): - """Internal implementation for scaling a single axis.""" - # Get current tick values - if is_x_axis: - tick_values = np.array(ax_obj.get_xticks()) - get_label = ax_obj.get_xlabel - set_label = ax_obj.set_xlabel - set_formatter = ax_obj.xaxis.set_major_formatter - else: - tick_values = np.array(ax_obj.get_yticks()) - get_label = ax_obj.get_ylabel - set_label = ax_obj.set_ylabel - set_formatter = ax_obj.yaxis.set_major_formatter - - # Detect if scaling is needed - power, should_scale = detect_scale_factor(tick_values, threshold) - - if not should_scale: - return - - # Create scaling factor - scale_factor = 10**power - - # Update tick formatter to show scaled values - def format_func(value, pos): - scaled_value = value / scale_factor - # Format with appropriate precision - if abs(scaled_value) < 10: - return f"{scaled_value:.1f}" - else: - return f"{scaled_value:.0f}" - - set_formatter(ticker.FuncFormatter(format_func)) - - # Update axis label with scale factor - current_label = get_label() - scale_str = format_scale_factor(power) - - # Check if label already has units in brackets - if "[" in current_label and "]" in current_label: - # Insert scale factor before the closing bracket - # e.g., "Density [a.u.]" → "Density [×10⁻³ a.u.]" - label_parts = current_label.rsplit("]", 1) - new_label = f"{label_parts[0]} {scale_str}]{label_parts[1]}" - else: - # Append scale factor in brackets - # e.g., "Density" → "Density [×10⁻³]" - new_label = ( - f"{current_label} [{scale_str}]" if current_label else f"[{scale_str}]" - ) - - set_label(new_label) - - # Apply to requested axes - if axis in ["x", "both"]: - scale_axis_impl(ax, is_x_axis=True) - if axis in ["y", "both"]: - scale_axis_impl(ax, is_x_axis=False) - - -# EOF diff --git a/src/scitex/plt/ax/_style/_extend.py b/src/scitex/plt/ax/_style/_extend.py deleted file mode 100755 index 6c8f9d0b0..000000000 --- a/src/scitex/plt/ax/_style/_extend.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:00:51 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_extend.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_style/_extend.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib - -from ....plt.utils import assert_valid_axis - - -def extend(axis, x_ratio=1.0, y_ratio=1.0): - """ - Extend or shrink a matplotlib axis or scitex axis wrapper while maintaining its center position. - - Args: - axis (matplotlib.axes._axes.Axes or scitex.plt._subplots.AxisWrapper): The axis to be modified. - x_ratio (float, optional): The ratio to scale the width. Default is 1.0. - y_ratio (float, optional): The ratio to scale the height. Default is 1.0. - - Returns: - matplotlib.axes._axes.Axes or scitex.plt._subplots.AxisWrapper: The modified axis. - - Raises: - AssertionError: If the first argument is not a valid axis. - """ - - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - assert x_ratio != 0, "x_ratio must not be 0." - assert y_ratio != 0, "y_ratio must not be 0." - - ## Original coordinates - bbox = axis.get_position() - left_orig = bbox.x0 - bottom_orig = bbox.y0 - width_orig = bbox.x1 - bbox.x0 - height_orig = bbox.y1 - bbox.y0 - g_orig = (left_orig + width_orig / 2.0, bottom_orig + height_orig / 2.0) - - ## Target coordinates - g_tgt = g_orig - width_tgt = width_orig * x_ratio - height_tgt = height_orig * y_ratio - left_tgt = g_tgt[0] - width_tgt / 2 - bottom_tgt = g_tgt[1] - height_tgt / 2 - - # Extend the axis - axis.set_position( - [ - left_tgt, - bottom_tgt, - width_tgt, - height_tgt, - ] - ) - return axis - - -# EOF diff --git a/src/scitex/plt/ax/_style/_force_aspect.py b/src/scitex/plt/ax/_style/_force_aspect.py deleted file mode 100755 index 6e00e13b1..000000000 --- a/src/scitex/plt/ax/_style/_force_aspect.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:00:52 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_force_aspect.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_style/_force_aspect.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib - -from ....plt.utils import assert_valid_axis - - -def force_aspect(axis, aspect=1): - """ - Forces aspect ratio of an axis based on the extent of the image. - - Arguments: - axis (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The axis to adjust. - aspect (float, optional): The aspect ratio to apply. Defaults to 1. - - Returns: - matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper: The axis with adjusted aspect ratio. - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - im = axis.get_images() - - extent = im[0].get_extent() - - axis.set_aspect(abs((extent[1] - extent[0]) / (extent[3] - extent[2])) / aspect) - return axis - - -# EOF diff --git a/src/scitex/plt/ax/_style/_format_label.py b/src/scitex/plt/ax/_style/_format_label.py deleted file mode 100755 index d7e012283..000000000 --- a/src/scitex/plt/ax/_style/_format_label.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-09-15 09:39:02 (ywatanabe)" -# /home/ywatanabe/proj/_scitex_repo_openhands/src/scitex/plt/ax/_format_label.py - - -def format_label(label): - """ - Format label by capitalizing first letter and replacing underscores with spaces. - """ - - # if isinstance(label, str): - # # Replace underscores with spaces - # label = label.replace("_", " ") - - # # Capitalize first letter of each word - # label = " ".join(word.capitalize() for word in label.split()) - - # # Special case for abbreviations (all caps) - # if label.isupper(): - # return label - - return label diff --git a/src/scitex/plt/ax/_style/_format_units.py b/src/scitex/plt/ax/_style/_format_units.py deleted file mode 100755 index 2424f01c5..000000000 --- a/src/scitex/plt/ax/_style/_format_units.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-11-19 15:10:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_format_units.py - -""" -Utility functions for formatting axis labels with proper unit notation. -""" - -from typing import Optional - - -def format_label(label: str, unit: Optional[str] = None) -> str: - """ - Format axis label with unit in brackets (publication standard). - - Parameters - ---------- - label : str - The label text (e.g., "Time", "Voltage") - unit : str, optional - The unit (e.g., "s", "mV", "Hz"). If None, returns label as-is. - - Returns - ------- - str - Formatted label with unit in brackets (e.g., "Time [s]") - - Examples - -------- - >>> stx.ax.format_label("Time", "s") - 'Time [s]' - - >>> stx.ax.format_label("Voltage", "mV") - 'Voltage [mV]' - - >>> stx.ax.format_label("Count") - 'Count' - - >>> # Direct usage with axis - >>> ax.set_xlabel(stx.ax.format_label("Time", "s")) - >>> ax.set_ylabel(stx.ax.format_label("Amplitude", "mV")) - - Notes - ----- - According to publication standards (Nature, Science, Cell), units should be - enclosed in square brackets, not parentheses: - - Correct: "Time [s]", "Voltage [mV]" - - Incorrect: "Time (s)", "Voltage (mV)" - """ - if unit is None or unit == "": - return label - return f"{label} [{unit}]" - - -def format_label_auto(text: str) -> str: - """ - Automatically convert parentheses-style units to bracket-style. - - This function detects units in parentheses and converts them to brackets. - - Parameters - ---------- - text : str - Label text, possibly with units in parentheses - - Returns - ------- - str - Label text with units in brackets - - Examples - -------- - >>> stx.ax.format_label_auto("Time (s)") - 'Time [s]' - - >>> stx.ax.format_label_auto("Voltage (mV)") - 'Voltage [mV]' - - >>> stx.ax.format_label_auto("Count") - 'Count' - - Notes - ----- - This is useful for automatically correcting existing labels that use - parentheses notation. - """ - import re - - # Pattern to match units in parentheses at the end of the string - # e.g., "Time (s)" or "Frequency (Hz)" - pattern = r"\s*\(([^)]+)\)\s*$" - - match = re.search(pattern, text) - if match: - unit = match.group(1) - label = text[: match.start()].strip() - return f"{label} [{unit}]" - - return text - - -# EOF diff --git a/src/scitex/plt/ax/_style/_hide_spines.py b/src/scitex/plt/ax/_style/_hide_spines.py deleted file mode 100755 index adf69fc8d..000000000 --- a/src/scitex/plt/ax/_style/_hide_spines.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-06-07 15:45:36 (ywatanabe)" -# File: /ssh:ywatanabe@sp:/home/ywatanabe/proj/.claude-worktree/scitex_repo/src/scitex/plt/ax/_style/_hide_spines.py -# ---------------------------------------- -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -# Time-stamp: "2024-04-26 20:03:45 (ywatanabe)" - -import matplotlib - -from ....plt.utils import assert_valid_axis - - -def hide_spines( - axis, - top=True, - bottom=False, - left=False, - right=True, - ticks=False, - labels=False, -): - """ - Hides the specified spines of a matplotlib Axes object or scitex axis wrapper and optionally removes the ticks and labels. - - This function is designed to work with matplotlib Axes objects or scitex axis wrappers. It allows for a cleaner, more minimalist - presentation of plots by hiding the spines (the lines denoting the boundaries of the plot area) and optionally - removing the ticks and labels from the axes. - - Arguments: - ax (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The axis for which the spines will be hidden. - top (bool, optional): If True, hides the top spine. Defaults to True. - bottom (bool, optional): If True, hides the bottom spine. Defaults to False. - left (bool, optional): If True, hides the left spine. Defaults to False. - right (bool, optional): If True, hides the right spine. Defaults to True. - ticks (bool, optional): If True, removes the ticks from the hidden spines' axes. Defaults to False. - labels (bool, optional): If True, removes the labels from the hidden spines' axes. Defaults to False. - - Returns: - matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper: The modified axis with the specified spines hidden. - - Example: - >>> fig, ax = plt.subplots() - >>> hide_spines(ax) - >>> plt.show() - """ - assert_valid_axis( - axis, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - tgts = [] - if top: - tgts.append("top") - if bottom: - tgts.append("bottom") - if left: - tgts.append("left") - if right: - tgts.append("right") - - for tgt in tgts: - # Spines - axis.spines[tgt].set_visible(False) - - # Ticks - if ticks: - if tgt == "bottom": - axis.xaxis.set_ticks_position("none") - elif tgt == "left": - axis.yaxis.set_ticks_position("none") - - # Labels - if labels: - if tgt == "bottom": - axis.set_xticklabels([]) - elif tgt == "left": - axis.set_yticklabels([]) - - return axis - - -# EOF diff --git a/src/scitex/plt/ax/_style/_map_ticks.py b/src/scitex/plt/ax/_style/_map_ticks.py deleted file mode 100755 index bf0bff5f1..000000000 --- a/src/scitex/plt/ax/_style/_map_ticks.py +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:00:56 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_map_ticks.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_style/_map_ticks.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np - -from ....plt.utils import assert_valid_axis - - -def map_ticks(ax, src, tgt, axis="x"): - """ - Maps source tick positions or labels to new target labels on a matplotlib Axes object. - Supports both numeric positions and string labels for source ticks ('src'), enabling the mapping - to new target labels ('tgt'). This ensures only the specified target ticks are displayed on the - final axis, enhancing the clarity and readability of plots. - - Parameters: - - ax (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The Axes object to modify. - - src (list of str or numeric): Source positions (if numeric) or labels (if str) to map from. - When using string labels, ensure they match the current tick labels on the axis. - - tgt (list of str): New target labels to apply to the axis. Must have the same length as 'src'. - - axis (str): Specifies which axis to apply the tick modifications ('x' or 'y'). - - Returns: - - ax (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The modified Axes object with adjusted tick labels. - - Examples: - -------- - Numeric Example: - fig, ax = plt.subplots() - x = np.linspace(0, 2 * np.pi, 100) - y = np.sin(x) - ax.plot(x, y) # Plot a sine wave - src = [0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi] # Numeric src positions - tgt = ['0', 'π/2', 'π', '3π/2', '2π'] # Corresponding target labels - map_ticks(ax, src, tgt, axis="x") # Map src to tgt on the x-axis - plt.show() - - String Example: - fig, ax = plt.subplots() - categories = ['A', 'B', 'C', 'D', 'E'] # Initial categories - values = [1, 3, 2, 5, 4] - ax.bar(categories, values) # Bar plot with string labels - src = ['A', 'B', 'C', 'D', 'E'] # Source labels to map from - tgt = ['Alpha', 'Beta', 'Gamma', 'Delta', 'Epsilon'] # New target labels - map_ticks(ax, src, tgt, axis="x") # Apply the mapping - plt.show() - """ - assert_valid_axis( - ax, "First argument must be a matplotlib axis or scitex axis wrapper" - ) - - if len(src) != len(tgt): - raise ValueError( - "Source ('src') and target ('tgt') must have the same number of elements." - ) - - # Determine tick positions if src is string data - if all(isinstance(item, str) for item in src): - if axis == "x": - all_labels = [label.get_text() for label in ax.get_xticklabels()] - else: - all_labels = [label.get_text() for label in ax.get_yticklabels()] - - # Find positions of src labels - src_positions = [all_labels.index(s) for s in src if s in all_labels] - else: - # Use src as positions directly if numeric - src_positions = src - - # Set the ticks and labels based on the specified axis - if axis == "x": - ax.set_xticks(src_positions) - ax.set_xticklabels(tgt) - elif axis == "y": - ax.set_yticks(src_positions) - ax.set_yticklabels(tgt) - else: - raise ValueError("Invalid axis argument. Use 'x' or 'y'.") - - return ax - - -def numeric_example(): - """Example demonstrating numeric tick mapping. - - Shows how to replace numeric tick positions with custom labels, - such as replacing radian values with pi notation in trigonometric plots. - - Returns - ------- - matplotlib.figure.Figure - Figure with two subplots showing before and after tick mapping. - - Examples - -------- - >>> fig = numeric_example() - >>> plt.show() - - Notes - ----- - The top subplot shows original numeric labels, while the bottom - subplot shows the same data with custom pi notation labels. - """ - fig, axs = plt.subplots(2, 1, figsize=(10, 6)) # Two rows, one column - - # Original plot - x = np.linspace(0, 2 * np.pi, 100) - y = np.sin(x) - axs[0].plot(x, y) # Plot a sine wave on the first row - axs[0].set_title("Original Numeric Labels") - - # Numeric src positions for ticks (e.g., multiples of pi) and target labels - src = [0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi] - tgt = ["0", "π/2", "π", "3π/2", "2π"] - - # Plot with mapped ticks - axs[1].plot(x, y) # Plot again on the second row for mapped labels - map_ticks(axs[1], src, tgt, axis="x") - axs[1].set_title("Mapped Numeric Labels") - - return fig - - -def string_example(): - """Example demonstrating string tick mapping. - - Shows how to replace categorical string labels with more descriptive - alternatives, useful for improving plot readability. - - Returns - ------- - matplotlib.figure.Figure - Figure with two subplots showing before and after tick mapping. - - Examples - -------- - >>> fig = string_example() - >>> plt.show() - - Notes - ----- - The top subplot shows original short category labels (A, B, C...), - while the bottom subplot shows the same data with descriptive Greek - letter names. - """ - fig, axs = plt.subplots(2, 1, figsize=(10, 6)) # Two rows, one column - - # Original plot with categorical string labels - categories = ["A", "B", "C", "D", "E"] - values = [1, 3, 2, 5, 4] - axs[0].bar(categories, values) - axs[0].set_title("Original String Labels") - - # src as the existing labels to change and target labels - src = categories - tgt = ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"] - - # Plot with mapped string labels - axs[1].bar(categories, values) # Bar plot again on the second row for mapped labels - map_ticks(axs[1], src, tgt, axis="x") - axs[1].set_title("Mapped String Labels") - - return fig - - -# Execute examples -if __name__ == "__main__": - fig_numeric = numeric_example() - fig_string = string_example() - - plt.tight_layout() - plt.show() - -# EOF diff --git a/src/scitex/plt/ax/_style/_rotate_labels.py b/src/scitex/plt/ax/_style/_rotate_labels.py deleted file mode 100755 index 4cd199faf..000000000 --- a/src/scitex/plt/ax/_style/_rotate_labels.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-09-24 13:22:52 (ywatanabe)" -# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_rotate_labels.py -# ---------------------------------------- -from __future__ import annotations - -import os - -__FILE__ = __file__ -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -# Time-stamp: "2024-10-27 13:24:32 (ywatanabe)" -# /home/ywatanabe/proj/_scitex_repo_openhands/src/scitex/plt/ax/_rotate_labels.py - -"""This script does XYZ.""" - -"""Imports""" -import numpy as np - - -def rotate_labels( - ax, - x=None, - y=None, - x_ha=None, - y_ha=None, - x_va=None, - y_va=None, - auto_adjust=True, - scientific_convention=True, - tight_layout=False, -): - """ - Rotate x and y axis labels of a matplotlib Axes object with automatic positioning. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The Axes object to modify. - x : float or None, optional - Rotation angle for x-axis labels in degrees. Default is None. - If 0 or None, x-axis labels are not rotated. - y : float or None, optional - Rotation angle for y-axis labels in degrees. Default is None. - If 0 or None, y-axis labels are not rotated. - x_ha : str, optional - Horizontal alignment for x-axis labels. If None, automatically determined. - y_ha : str, optional - Horizontal alignment for y-axis labels. If None, automatically determined. - x_va : str, optional - Vertical alignment for x-axis labels. If None, automatically determined. - y_va : str, optional - Vertical alignment for y-axis labels. If None, automatically determined. - auto_adjust : bool, optional - Whether to automatically adjust alignment based on rotation angle. Default is True. - scientific_convention : bool, optional - Whether to follow scientific plotting conventions. Default is True. - tight_layout : bool, optional - Whether to apply tight_layout to prevent overlapping. Default is False. - - Returns - ------- - matplotlib.axes.Axes - The modified Axes object. - - Example - ------- - fig, ax = plt.subplots() - ax.plot([1, 2, 3], [1, 2, 3]) - rotate_labels(ax) - plt.show() - - Notes - ----- - Scientific conventions for label rotation: - - X-axis labels: For angles 0-90°, use 'right' alignment; for 90-180°, use 'left' - - Y-axis labels: For angles 0-90°, use 'center' alignment; adjust vertical as needed - - Optimal readability maintained through automatic positioning - """ - # Determine which axes to rotate (skip if None or 0) - rotate_x = x is not None and x != 0 - rotate_y = y is not None and y != 0 - - # Get current tick positions - xticks = ax.get_xticks() - yticks = ax.get_yticks() - - # Set ticks explicitly - ax.set_xticks(xticks) - ax.set_yticks(yticks) - - # Auto-adjust alignment based on rotation angle and scientific conventions - if auto_adjust: - if rotate_x: - x_ha, x_va = _get_optimal_alignment( - "x", x, x_ha, x_va, scientific_convention - ) - if rotate_y: - y_ha, y_va = _get_optimal_alignment( - "y", y, y_ha, y_va, scientific_convention - ) - - # Apply defaults if not auto-adjusting - if rotate_x: - if x_ha is None: - x_ha = "center" - if x_va is None: - x_va = "center" - if rotate_y: - if y_ha is None: - y_ha = "center" - if y_va is None: - y_va = "center" - - # Check if this axis is part of a shared x-axis configuration - # If labels are already visible (bottom subplot or not shared), keep them visible - # This preserves matplotlib's default sharex behavior - x_labels_visible = ax.xaxis.get_tick_params()["labelbottom"] - y_labels_visible = ax.yaxis.get_tick_params()["labelleft"] - - # Set labels with rotation and proper alignment - # Only set labels if they're currently visible (respects sharex/sharey) - if x_labels_visible and rotate_x: - ax.set_xticklabels(ax.get_xticklabels(), rotation=x, ha=x_ha, va=x_va) - if y_labels_visible and rotate_y: - ax.set_yticklabels(ax.get_yticklabels(), rotation=y, ha=y_ha, va=y_va) - - # Auto-adjust subplot parameters for better layout if needed - if auto_adjust and scientific_convention: - # Only pass non-zero angles for adjustment - x_angle = x if rotate_x else 0 - y_angle = y if rotate_y else 0 - _adjust_subplot_params(ax, x_angle, y_angle) - - # Apply tight_layout if requested to prevent overlapping - if tight_layout: - fig = ax.get_figure() - try: - fig.tight_layout() - except Exception: - # Fallback to manual adjustment if tight_layout fails - x_angle = x if rotate_x else 0 - y_angle = y if rotate_y else 0 - _adjust_subplot_params(ax, x_angle, y_angle) - - return ax - - -def _get_optimal_alignment(axis, angle, ha, va, scientific_convention): - """ - Determine optimal alignment based on rotation angle and scientific conventions. - - Parameters - ---------- - axis : str - 'x' or 'y' axis - angle : float - Rotation angle in degrees - ha : str or None - Current horizontal alignment - va : str or None - Current vertical alignment - scientific_convention : bool - Whether to follow scientific conventions - - Returns - ------- - tuple - (horizontal_alignment, vertical_alignment) - """ - # Normalize angle to 0-360 range - angle = angle % 360 - - if axis == "x": - if scientific_convention: - # Scientific convention for x-axis labels - if 0 <= angle <= 30: - ha = ha or "center" - va = va or "top" - elif 30 < angle <= 60: - ha = ha or "right" - va = va or "top" - elif 60 < angle < 90: - ha = ha or "right" - va = va or "top" - elif angle == 90: - # Special case for exact 90 degrees - ha = ha or "right" - va = va or "top" - elif 90 < angle <= 120: - ha = ha or "right" - va = va or "center" - elif 120 < angle <= 150: - ha = ha or "right" - va = va or "bottom" - elif 150 < angle <= 210: - ha = ha or "center" - va = va or "bottom" - elif 210 < angle <= 240: - ha = ha or "left" - va = va or "bottom" - elif 240 < angle <= 300: - ha = ha or "left" - va = va or "center" - else: # 300-360 - ha = ha or "left" - va = va or "top" - else: - ha = ha or "center" - va = va or "top" - - else: # y-axis - if scientific_convention: - # Scientific convention for y-axis labels - if 0 <= angle <= 30: - ha = ha or "right" - va = va or "center" - elif 30 < angle <= 60: - ha = ha or "right" - va = va or "bottom" - elif 60 < angle <= 120: - ha = ha or "center" - va = va or "bottom" - elif 120 < angle <= 150: - ha = ha or "left" - va = va or "bottom" - elif 150 < angle <= 210: - ha = ha or "left" - va = va or "center" - elif 210 < angle <= 240: - ha = ha or "left" - va = va or "top" - elif 240 < angle <= 300: - ha = ha or "center" - va = va or "top" - else: # 300-360 - ha = ha or "right" - va = va or "top" - else: - ha = ha or "center" - va = va or "center" - - return ha, va - - -def _adjust_subplot_params(ax, x_angle, y_angle): - """ - Automatically adjust subplot parameters to accommodate rotated labels. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The axes object - x_angle : float - X-axis rotation angle - y_angle : float - Y-axis rotation angle - """ - fig = ax.get_figure() - - # Check if figure is using a layout engine that is incompatible with subplots_adjust - try: - # For matplotlib >= 3.6 - if hasattr(fig, "get_layout_engine"): - layout_engine = fig.get_layout_engine() - if layout_engine is not None: - # If using constrained_layout or tight_layout, don't adjust - return - except AttributeError: - pass - - # Check for constrained_layout (older matplotlib versions) - try: - if hasattr(fig, "get_constrained_layout"): - if fig.get_constrained_layout(): - # Constrained layout is active, don't adjust - return - except AttributeError: - pass - - # Calculate required margins based on rotation angles - # Special handling for 90-degree rotation - if x_angle == 90: - x_margin_factor = 0.3 # Maximum margin for 90 degrees - else: - # Increase margin more significantly for rotated x-axis labels to prevent xlabel overlap - x_margin_factor = abs(np.sin(np.radians(x_angle))) * 0.25 # Increased from 0.2 - - y_margin_factor = abs(np.sin(np.radians(y_angle))) * 0.15 - - # Get current subplot parameters - try: - subplotpars = fig.subplotpars - current_bottom = subplotpars.bottom - current_left = subplotpars.left - - # Adjust margins if they need to be increased - # Ensure more space for rotated x-labels and xlabel - new_bottom = max( - current_bottom, 0.2 + x_margin_factor - ) # Increased base from 0.15 - new_left = max(current_left, 0.1 + y_margin_factor) - - # Only adjust if we're increasing the margins significantly - if ( - new_bottom > current_bottom + 0.02 or new_left > current_left + 0.02 - ): # Reduced threshold - # Suppress warning and try to adjust - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - fig.subplots_adjust(bottom=new_bottom, left=new_left) - except Exception: - # Skip adjustment if there are issues - pass - - -# EOF diff --git a/src/scitex/plt/ax/_style/_rotate_labels_v01.py b/src/scitex/plt/ax/_style/_rotate_labels_v01.py deleted file mode 100755 index 335d1ad3d..000000000 --- a/src/scitex/plt/ax/_style/_rotate_labels_v01.py +++ /dev/null @@ -1,258 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-10-27 13:24:32 (ywatanabe)" -# /home/ywatanabe/proj/_scitex_repo_openhands/src/scitex/plt/ax/_rotate_labels.py - -"""This script does XYZ.""" - -"""Imports""" -import numpy as np - - -def rotate_labels( - ax, - x=45, - y=45, - x_ha=None, - y_ha=None, - x_va=None, - y_va=None, - auto_adjust=True, - scientific_convention=True, -): - """ - Rotate x and y axis labels of a matplotlib Axes object with automatic positioning. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The Axes object to modify. - x : float, optional - Rotation angle for x-axis labels in degrees. Default is 45. - y : float, optional - Rotation angle for y-axis labels in degrees. Default is 45. - x_ha : str, optional - Horizontal alignment for x-axis labels. If None, automatically determined. - y_ha : str, optional - Horizontal alignment for y-axis labels. If None, automatically determined. - x_va : str, optional - Vertical alignment for x-axis labels. If None, automatically determined. - y_va : str, optional - Vertical alignment for y-axis labels. If None, automatically determined. - auto_adjust : bool, optional - Whether to automatically adjust alignment based on rotation angle. Default is True. - scientific_convention : bool, optional - Whether to follow scientific plotting conventions. Default is True. - - Returns - ------- - matplotlib.axes.Axes - The modified Axes object. - - Example - ------- - fig, ax = plt.subplots() - ax.plot([1, 2, 3], [1, 2, 3]) - rotate_labels(ax) - plt.show() - - Notes - ----- - Scientific conventions for label rotation: - - X-axis labels: For angles 0-90°, use 'right' alignment; for 90-180°, use 'left' - - Y-axis labels: For angles 0-90°, use 'center' alignment; adjust vertical as needed - - Optimal readability maintained through automatic positioning - """ - # Get current tick positions - xticks = ax.get_xticks() - yticks = ax.get_yticks() - - # Set ticks explicitly - ax.set_xticks(xticks) - ax.set_yticks(yticks) - - # Auto-adjust alignment based on rotation angle and scientific conventions - if auto_adjust: - x_ha, x_va = _get_optimal_alignment("x", x, x_ha, x_va, scientific_convention) - y_ha, y_va = _get_optimal_alignment("y", y, y_ha, y_va, scientific_convention) - - # Apply defaults if not auto-adjusting - if x_ha is None: - x_ha = "center" - if y_ha is None: - y_ha = "center" - if x_va is None: - x_va = "center" - if y_va is None: - y_va = "center" - - # Check if this axis is part of a shared x-axis configuration - # If labels are already visible (bottom subplot or not shared), keep them visible - # This preserves matplotlib's default sharex behavior - x_labels_visible = ax.xaxis.get_tick_params()["labelbottom"] - y_labels_visible = ax.yaxis.get_tick_params()["labelleft"] - - # Set labels with rotation and proper alignment - # Only set labels if they're currently visible (respects sharex/sharey) - if x_labels_visible: - ax.set_xticklabels(ax.get_xticklabels(), rotation=x, ha=x_ha, va=x_va) - if y_labels_visible: - ax.set_yticklabels(ax.get_yticklabels(), rotation=y, ha=y_ha, va=y_va) - - # Auto-adjust subplot parameters for better layout if needed - if auto_adjust and scientific_convention: - _adjust_subplot_params(ax, x, y) - - return ax - - -def _get_optimal_alignment(axis, angle, ha, va, scientific_convention): - """ - Determine optimal alignment based on rotation angle and scientific conventions. - - Parameters - ---------- - axis : str - 'x' or 'y' axis - angle : float - Rotation angle in degrees - ha : str or None - Current horizontal alignment - va : str or None - Current vertical alignment - scientific_convention : bool - Whether to follow scientific conventions - - Returns - ------- - tuple - (horizontal_alignment, vertical_alignment) - """ - # Normalize angle to 0-360 range - angle = angle % 360 - - if axis == "x": - if scientific_convention: - # Scientific convention for x-axis labels - if 0 <= angle <= 30: - ha = ha or "center" - va = va or "top" - elif 30 < angle <= 60: - ha = ha or "right" - va = va or "top" - elif 60 < angle <= 120: - ha = ha or "right" - va = va or "center" - elif 120 < angle <= 150: - ha = ha or "right" - va = va or "bottom" - elif 150 < angle <= 210: - ha = ha or "center" - va = va or "bottom" - elif 210 < angle <= 240: - ha = ha or "left" - va = va or "bottom" - elif 240 < angle <= 300: - ha = ha or "left" - va = va or "center" - else: # 300-360 - ha = ha or "left" - va = va or "top" - else: - ha = ha or "center" - va = va or "top" - - else: # y-axis - if scientific_convention: - # Scientific convention for y-axis labels - if 0 <= angle <= 30: - ha = ha or "right" - va = va or "center" - elif 30 < angle <= 60: - ha = ha or "right" - va = va or "bottom" - elif 60 < angle <= 120: - ha = ha or "center" - va = va or "bottom" - elif 120 < angle <= 150: - ha = ha or "left" - va = va or "bottom" - elif 150 < angle <= 210: - ha = ha or "left" - va = va or "center" - elif 210 < angle <= 240: - ha = ha or "left" - va = va or "top" - elif 240 < angle <= 300: - ha = ha or "center" - va = va or "top" - else: # 300-360 - ha = ha or "right" - va = va or "top" - else: - ha = ha or "center" - va = va or "center" - - return ha, va - - -def _adjust_subplot_params(ax, x_angle, y_angle): - """ - Automatically adjust subplot parameters to accommodate rotated labels. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The axes object - x_angle : float - X-axis rotation angle - y_angle : float - Y-axis rotation angle - """ - fig = ax.get_figure() - - # Check if figure is using a layout engine that is incompatible with subplots_adjust - try: - # For matplotlib >= 3.6 - if hasattr(fig, "get_layout_engine"): - layout_engine = fig.get_layout_engine() - if layout_engine is not None: - # If using constrained_layout or tight_layout, don't adjust - return - except AttributeError: - pass - - # Check for constrained_layout (older matplotlib versions) - try: - if hasattr(fig, "get_constrained_layout"): - if fig.get_constrained_layout(): - # Constrained layout is active, don't adjust - return - except AttributeError: - pass - - # Calculate required margins based on rotation angles - x_margin_factor = abs(np.sin(np.radians(x_angle))) * 0.1 - y_margin_factor = abs(np.sin(np.radians(y_angle))) * 0.15 - - # Get current subplot parameters - try: - subplotpars = fig.subplotpars - current_bottom = subplotpars.bottom - current_left = subplotpars.left - - # Adjust margins if they need to be increased - new_bottom = max(current_bottom, 0.1 + x_margin_factor) - new_left = max(current_left, 0.1 + y_margin_factor) - - # Only adjust if we're increasing the margins significantly - if new_bottom > current_bottom + 0.05 or new_left > current_left + 0.05: - # Suppress warning and try to adjust - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - fig.subplots_adjust(bottom=new_bottom, left=new_left) - except Exception: - # Skip adjustment if there are issues - pass diff --git a/src/scitex/plt/ax/_style/_sci_note.py b/src/scitex/plt/ax/_style/_sci_note.py deleted file mode 100755 index 016e0781a..000000000 --- a/src/scitex/plt/ax/_style/_sci_note.py +++ /dev/null @@ -1,279 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-03 11:58:58 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_sci_note.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_style/_sci_note.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import numpy as np - - -class OOMFormatter(matplotlib.ticker.ScalarFormatter): - """Custom formatter for scientific notation with fixed order of magnitude. - - A matplotlib formatter that allows you to specify a fixed exponent for - scientific notation, rather than letting matplotlib choose it automatically. - Useful when you want consistent notation across multiple plots or specific - exponent values. - - Parameters - ---------- - order : int or None, optional - Fixed order of magnitude (exponent) to use. If None, calculated - automatically. Default is None. - fformat : str, optional - Format string for the mantissa. Default is "%1.1f". - offset : bool, optional - Whether to use offset notation. Default is True. - mathText : bool, optional - Whether to use mathtext rendering. Default is True. - - Attributes - ---------- - order : int or None - The fixed order of magnitude to use. - fformat : str - Format string for displaying numbers. - - Examples - -------- - >>> # Force all labels to use 10^3 notation - >>> formatter = OOMFormatter(order=3, fformat="%1.2f") - >>> ax.xaxis.set_major_formatter(formatter) - - >>> # Use 10^-6 for microvolts - >>> formatter = OOMFormatter(order=-6, fformat="%1.1f") - >>> ax.yaxis.set_major_formatter(formatter) - - See Also - -------- - matplotlib.ticker.ScalarFormatter : Base formatter class - sci_note : Convenience function using this formatter - """ - - def __init__(self, order=None, fformat="%1.1f", offset=True, mathText=True): - self.order = order - self.fformat = fformat - matplotlib.ticker.ScalarFormatter.__init__( - self, useOffset=offset, useMathText=mathText - ) - - def _set_order_of_magnitude(self): - if self.order is not None: - self.orderOfMagnitude = self.order - else: - super()._set_order_of_magnitude() - - def _set_format(self, vmin=None, vmax=None): - self.format = self.fformat - if self._useMathText: - self.format = r"$\mathdefault{%s}$" % self.format - - -def sci_note( - ax, - fformat="%1.1f", - x=False, - y=False, - scilimits=(-3, 3), - order_x=None, - order_y=None, - pad_x=-22, - pad_y=-20, -): - """ - Apply scientific notation to axis with optional manual order of magnitude. - - Parameters: - ----------- - ax : matplotlib Axes - The axes to apply scientific notation to - fformat : str - Format string for tick labels - x, y : bool - Whether to apply to x or y axis - scilimits : tuple - Scientific notation limits - order_x, order_y : int or None - Manual order of magnitude (exponent). If None, calculated automatically - pad_x, pad_y : int - Padding for the axis labels - """ - if x: - # Calculate order if not specified - if order_x is None: - order_x = np.floor(np.log10(np.max(np.abs(ax.get_xlim())) + 1e-5)) - - ax.xaxis.set_major_formatter(OOMFormatter(order=int(order_x), fformat=fformat)) - ax.ticklabel_format(axis="x", style="sci", scilimits=scilimits) - ax.xaxis.labelpad = pad_x - shift_x = (ax.get_xlim()[0] - ax.get_xlim()[1]) * 0.01 - ax.xaxis.get_offset_text().set_position((shift_x, 0)) - - if y: - # Calculate order if not specified - if order_y is None: - order_y = np.floor(np.log10(np.max(np.abs(ax.get_ylim())) + 1e-5)) - - ax.yaxis.set_major_formatter(OOMFormatter(order=int(order_y), fformat=fformat)) - ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits) - ax.yaxis.labelpad = pad_y - shift_y = (ax.get_ylim()[0] - ax.get_ylim()[1]) * 0.01 - ax.yaxis.get_offset_text().set_position((0, shift_y)) - - return ax - - -# import matplotlib -# import numpy as np - - -# class OOMFormatter(matplotlib.ticker.ScalarFormatter): -# def __init__(self, order=0, fformat="%1.1f", offset=True, mathText=True): -# self.order = order -# self.fformat = fformat -# matplotlib.ticker.ScalarFormatter.__init__( -# self, useOffset=offset, useMathText=mathText -# ) - -# def _set_order_of_magnitude(self): -# self.orderOfMagnitude = self.order - -# def _set_format(self, vmin=None, vmax=None): -# self.format = self.fformat -# if self._useMathText: -# self.format = r"$\mathdefault{%s}$" % self.format - - -# def sci_note(ax, fformat="%1.1f", x=False, y=False, scilimits=(-3, 3)): -# order_x = 0 -# order_y = 0 - -# if x: -# order_x = np.floor(np.log10(np.max(np.abs(ax.get_xlim())) + 1e-5)) -# ax.xaxis.set_major_formatter( -# OOMFormatter(order=int(order_x), fformat=fformat) -# ) -# ax.ticklabel_format(axis="x", style="sci", scilimits=scilimits) -# ax.xaxis.labelpad = -22 -# shift_x = (ax.get_xlim()[0] - ax.get_xlim()[1]) * 0.01 -# ax.xaxis.get_offset_text().set_position((shift_x, 0)) - -# if y: -# order_y = np.floor(np.log10(np.max(np.abs(ax.get_ylim())) + 1e-5)) -# ax.yaxis.set_major_formatter( -# OOMFormatter(order=int(order_y), fformat=fformat) -# ) -# ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits) -# ax.yaxis.labelpad = -20 -# shift_y = (ax.get_ylim()[0] - ax.get_ylim()[1]) * 0.01 -# ax.yaxis.get_offset_text().set_position((0, shift_y)) - -# return ax - - -# # class OOMFormatter(matplotlib.ticker.ScalarFormatter): -# # def __init__(self, order=0, fformat="%1.1f", offset=True, mathText=True): -# # self.order = order -# # self.fformat = fformat -# # matplotlib.ticker.ScalarFormatter.__init__( -# # self, useOffset=offset, useMathText=mathText -# # ) - -# # def _set_order_of_magnitude(self): -# # self.orderOfMagnitude = self.order - -# # def _set_format(self, vmin=None, vmax=None): -# # self.format = self.fformat -# # if self._useMathText: -# # self.format = r"$\mathdefault{%s}$" % self.format - - -# # def sci_note(ax, fformat="%1.1f", x=False, y=False, scilimits=(-3, 3)): -# # order_x = 0 -# # order_y = 0 - -# # if x: -# # order_x = np.floor(np.log10(np.max(np.abs(ax.get_xlim())) + 1e-5)) -# # ax.xaxis.set_major_formatter( -# # OOMFormatter(order=int(order_x), fformat=fformat) -# # ) -# # ax.ticklabel_format(axis="x", style="sci", scilimits=scilimits) - -# # if y: -# # order_y = np.floor(np.log10(np.max(np.abs(ax.get_ylim()) + 1e-5))) -# # ax.yaxis.set_major_formatter( -# # OOMFormatter(order=int(order_y), fformat=fformat) -# # ) -# # ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits) - -# # return ax - - -# # #!/usr/bin/env python3 - - -# # import matplotlib - - -# # class OOMFormatter(matplotlib.ticker.ScalarFormatter): -# # # https://stackoverflow.com/questions/42656139/set-scientific-notation-with-fixed-exponent-and-significant-digits-for-multiple -# # # def __init__(self, order=0, fformat="%1.1f", offset=True, mathText=True): -# # def __init__(self, order=0, fformat="%1.0d", offset=True, mathText=True): -# # self.oom = order -# # self.fformat = fformat -# # matplotlib.ticker.ScalarFormatter.__init__( -# # self, useOffset=offset, useMathText=mathText -# # ) - -# # def _set_order_of_magnitude(self): -# # self.orderOfMagnitude = self.oom - -# # def _set_format(self, vmin=None, vmax=None): -# # self.format = self.fformat -# # if self._useMathText: -# # self.format = r"$\mathdefault{%s}$" % self.format - - -# # def sci_note( -# # ax, -# # order, -# # fformat="%1.0d", -# # x=False, -# # y=False, -# # scilimits=(-3, 3), -# # ): -# # """ -# # Change the expression of the x- or y-axis to the scientific notation like *10^3 -# # , where 3 is the first argument, order. - -# # Example: -# # order = 4 # 10^4 -# # ax = sci_note( -# # ax, -# # order, -# # fformat="%1.0d", -# # x=True, -# # y=False, -# # scilimits=(-3, 3), -# # """ - -# # if x == True: -# # ax.xaxis.set_major_formatter( -# # OOMFormatter(order=order, fformat=fformat) -# # ) -# # ax.ticklabel_format(axis="x", style="sci", scilimits=scilimits) -# # if y == True: -# # ax.yaxis.set_major_formatter( -# # OOMFormatter(order=order, fformat=fformat) -# # ) -# # ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits) - -# # return ax - -# EOF diff --git a/src/scitex/plt/ax/_style/_set_log_scale.py b/src/scitex/plt/ax/_style/_set_log_scale.py deleted file mode 100755 index ca31d67fe..000000000 --- a/src/scitex/plt/ax/_style/_set_log_scale.py +++ /dev/null @@ -1,335 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2025-06-04 11:10:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_set_log_scale.py - -""" -Functionality: - Set logarithmic scale with proper minor ticks for scientific plots -Input: - Matplotlib axes object and scale parameters -Output: - Axes with properly configured logarithmic scale -Prerequisites: - matplotlib, numpy -""" - -from typing import List, Optional, Union - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.ticker import LogFormatter, LogLocator, NullFormatter - - -def set_log_scale( - ax, - axis: str = "both", - base: Union[int, float] = 10, - show_minor_ticks: bool = True, - minor_tick_length: float = 2.0, - major_tick_length: float = 4.0, - minor_tick_width: float = 0.5, - major_tick_width: float = 0.8, - grid: bool = False, - minor_grid: bool = False, - grid_alpha: float = 0.3, - minor_grid_alpha: float = 0.15, - format_minor_labels: bool = False, - scientific_notation: bool = True, -) -> object: - """ - Set logarithmic scale with comprehensive minor tick support. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The axes object to modify - axis : str, optional - Which axis to set: 'x', 'y', or 'both', by default 'both' - base : Union[int, float], optional - Logarithmic base, by default 10 - show_minor_ticks : bool, optional - Whether to show minor ticks, by default True - minor_tick_length : float, optional - Length of minor ticks in points, by default 2.0 - major_tick_length : float, optional - Length of major ticks in points, by default 4.0 - minor_tick_width : float, optional - Width of minor ticks in points, by default 0.5 - major_tick_width : float, optional - Width of major ticks in points, by default 0.8 - grid : bool, optional - Whether to show major grid lines, by default False - minor_grid : bool, optional - Whether to show minor grid lines, by default False - grid_alpha : float, optional - Alpha for major grid lines, by default 0.3 - minor_grid_alpha : float, optional - Alpha for minor grid lines, by default 0.15 - format_minor_labels : bool, optional - Whether to show labels on minor ticks, by default False - scientific_notation : bool, optional - Whether to use scientific notation for labels, by default True - - Returns - ------- - matplotlib.axes.Axes - The modified axes object - - Examples - -------- - >>> fig, ax = plt.subplots() - >>> ax.semilogy([1, 10, 100, 1000], [1, 2, 3, 4]) - >>> set_log_scale(ax, axis='y', show_minor_ticks=True, grid=True) - """ - - if axis in ["x", "both"]: - _configure_log_axis( - ax, - "x", - base, - show_minor_ticks, - minor_tick_length, - major_tick_length, - minor_tick_width, - major_tick_width, - grid, - minor_grid, - grid_alpha, - minor_grid_alpha, - format_minor_labels, - scientific_notation, - ) - - if axis in ["y", "both"]: - _configure_log_axis( - ax, - "y", - base, - show_minor_ticks, - minor_tick_length, - major_tick_length, - minor_tick_width, - major_tick_width, - grid, - minor_grid, - grid_alpha, - minor_grid_alpha, - format_minor_labels, - scientific_notation, - ) - - return ax - - -def _configure_log_axis( - ax, - axis_name: str, - base: Union[int, float], - show_minor_ticks: bool, - minor_tick_length: float, - major_tick_length: float, - minor_tick_width: float, - major_tick_width: float, - grid: bool, - minor_grid: bool, - grid_alpha: float, - minor_grid_alpha: float, - format_minor_labels: bool, - scientific_notation: bool, -) -> None: - """Configure a single axis for logarithmic scale.""" - - # Set the logarithmic scale - if axis_name == "x": - ax.set_xscale("log", base=base) - axis_obj = ax.xaxis - tick_params_kwargs = {"axis": "x"} - else: # y-axis - ax.set_yscale("log", base=base) - axis_obj = ax.yaxis - tick_params_kwargs = {"axis": "y"} - - # Configure major ticks - major_locator = LogLocator(base=base, numticks=12) - axis_obj.set_major_locator(major_locator) - - # Configure major tick formatting - if scientific_notation: - major_formatter = LogFormatter(base=base, labelOnlyBase=False) - else: - major_formatter = LogFormatter(base=base, labelOnlyBase=True) - axis_obj.set_major_formatter(major_formatter) - - # Configure minor ticks - if show_minor_ticks: - # Create minor tick positions - minor_locator = LogLocator(base=base, subs="all", numticks=100) - axis_obj.set_minor_locator(minor_locator) - - # Format minor tick labels - if format_minor_labels: - minor_formatter = LogFormatter(base=base, labelOnlyBase=False) - else: - minor_formatter = NullFormatter() # No labels on minor ticks - axis_obj.set_minor_formatter(minor_formatter) - - # Set minor tick appearance - ax.tick_params( - which="minor", - length=minor_tick_length, - width=minor_tick_width, - **tick_params_kwargs, - ) - - # Set major tick appearance - ax.tick_params( - which="major", - length=major_tick_length, - width=major_tick_width, - **tick_params_kwargs, - ) - - # Configure grid - if grid or minor_grid: - ax.grid(True, which="major", alpha=grid_alpha if grid else 0) - if minor_grid and show_minor_ticks: - ax.grid(True, which="minor", alpha=minor_grid_alpha) - - -def smart_log_limits( - data: Union[List, np.ndarray], - axis: str = "y", - base: Union[int, float] = 10, - padding_factor: float = 0.1, - min_decades: int = 1, -) -> tuple: - """ - Calculate smart logarithmic axis limits based on data. - - Parameters - ---------- - data : Union[List, np.ndarray] - Data values to calculate limits from - axis : str, optional - Axis name for reference, by default 'y' - base : Union[int, float], optional - Logarithmic base, by default 10 - padding_factor : float, optional - Padding as fraction of data range, by default 0.1 - min_decades : int, optional - Minimum number of decades to show, by default 1 - - Returns - ------- - tuple - (lower_limit, upper_limit) - - Examples - -------- - >>> smart_log_limits([1, 10, 100, 1000]) - (0.1, 10000.0) - """ - data_array = np.array(data) - positive_data = data_array[data_array > 0] - - if len(positive_data) == 0: - return 1, base**min_decades - - data_min = np.min(positive_data) - data_max = np.max(positive_data) - - # Calculate log range - log_min = np.log(data_min) / np.log(base) - log_max = np.log(data_max) / np.log(base) - log_range = log_max - log_min - - # Ensure minimum range - if log_range < min_decades: - log_center = (log_min + log_max) / 2 - log_min = log_center - min_decades / 2 - log_max = log_center + min_decades / 2 - log_range = min_decades - - # Add padding - padding = log_range * padding_factor - log_min_padded = log_min - padding - log_max_padded = log_max + padding - - # Convert back to linear scale - lower_limit = base**log_min_padded - upper_limit = base**log_max_padded - - return lower_limit, upper_limit - - -def add_log_scale_indicator( - ax, - axis: str = "y", - base: Union[int, float] = 10, - position: str = "auto", - fontsize: Union[str, int] = "small", - color: str = "gray", - alpha: float = 0.7, -) -> None: - """ - Add a log scale indicator to the plot. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The axes object - axis : str, optional - Which axis has log scale, by default 'y' - base : Union[int, float], optional - Logarithmic base, by default 10 - position : str, optional - Position of indicator: 'auto', 'top-left', 'top-right', 'bottom-left', 'bottom-right', by default 'auto' - fontsize : Union[str, int], optional - Font size for indicator, by default 'small' - color : str, optional - Color of indicator text, by default 'gray' - alpha : float, optional - Alpha transparency, by default 0.7 - - Examples - -------- - >>> add_log_scale_indicator(ax, axis='y', base=10) - """ - # Determine position - if position == "auto": - if axis == "y": - position = "top-left" - else: - position = "bottom-right" - - # Position mapping - positions = { - "top-left": (0.05, 0.95), - "top-right": (0.95, 0.95), - "bottom-left": (0.05, 0.05), - "bottom-right": (0.95, 0.05), - } - - x_pos, y_pos = positions.get(position, (0.05, 0.95)) - - # Create indicator text - if base == 10: - indicator_text = f"Log₁₀ scale ({axis}-axis)" - else: - indicator_text = f"Log_{{{base}}} scale ({axis}-axis)" - - # Add text - ax.text( - x_pos, - y_pos, - indicator_text, - transform=ax.transAxes, - fontsize=fontsize, - color=color, - alpha=alpha, - bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), - ) - - -# EOF diff --git a/src/scitex/plt/ax/_style/_set_meta.py b/src/scitex/plt/ax/_style/_set_meta.py deleted file mode 100755 index 896a73c1e..000000000 --- a/src/scitex/plt/ax/_style/_set_meta.py +++ /dev/null @@ -1,294 +0,0 @@ -#!./env/bin/python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2025-06-04 11:35:00 (ywatanabe)" -# Author: Yusuke Watanabe (ywatanabe@scitex.ai) - -""" -Scientific metadata management for figures with YAML export. -""" - -from typing import Any, Dict, List, Optional - -# Imports -import yaml - - -# Functions -def set_meta( - ax, - caption=None, - methods=None, - stats=None, - keywords=None, - experimental_details=None, - journal_style=None, - significance=None, - **kwargs, -): - """Set comprehensive scientific metadata for figures with YAML export - - Parameters - ---------- - ax : matplotlib.axes.Axes or scitex AxisWrapper - The axes to modify - caption : str, optional - Figure caption text - methods : str, optional - Experimental methods description - stats : str, optional - Statistical analysis details - keywords : List[str], optional - Keywords for categorization and search - experimental_details : Dict[str, Any], optional - Structured experimental parameters (n_samples, temperature, etc.) - journal_style : str, optional - Target journal style ('nature', 'science', 'ieee', 'cell', etc.) - significance : str, optional - Significance statement or implications - **kwargs : additional metadata - Any additional metadata fields - - Returns - ------- - ax : matplotlib.axes.Axes or scitex AxisWrapper - The modified axes - - Examples - -------- - >>> fig, ax = scitex.plt.subplots() - >>> ax.plot(x, y, id='neural_data') - >>> ax.set_xyt(x='Time (ms)', y='Voltage (mV)', t='Neural Recording') - >>> ax.set_meta( - ... caption='Intracellular recording showing action potentials.', - ... methods='Whole-cell patch-clamp in acute brain slices.', - ... stats='Statistical analysis using paired t-test (p<0.05).', - ... keywords=['electrophysiology', 'neural_recording', 'patch_clamp'], - ... experimental_details={ - ... 'n_samples': 15, - ... 'temperature': 32, - ... 'recording_duration': 600, - ... 'electrode_resistance': '3-5 MΩ' - ... }, - ... journal_style='nature', - ... significance='Demonstrates novel neural dynamics in layer 2/3 pyramidal cells.' - ... ) - >>> scitex.io.save(fig, 'neural_recording.png') # YAML metadata auto-saved - """ - - # Build comprehensive metadata dictionary - metadata = {} - - if caption is not None: - metadata["caption"] = caption - if methods is not None: - metadata["methods"] = methods - if stats is not None: - metadata["statistical_analysis"] = stats - if keywords is not None: - metadata["keywords"] = keywords if isinstance(keywords, list) else [keywords] - if experimental_details is not None: - metadata["experimental_details"] = experimental_details - if journal_style is not None: - metadata["journal_style"] = journal_style - if significance is not None: - metadata["significance"] = significance - - # Add any additional metadata - for key, value in kwargs.items(): - if value is not None: - metadata[key] = value - - # Add automatic metadata - import datetime - - metadata["created_timestamp"] = datetime.datetime.now().isoformat() - - # Get version dynamically - try: - import scitex - - metadata["scitex_version"] = getattr(scitex, "__version__", "unknown") - except ImportError: - metadata["scitex_version"] = "unknown" - - # Store metadata in figure for automatic saving - fig = ax.get_figure() - if not hasattr(fig, "_scitex_metadata"): - fig._scitex_metadata = {} - - # Use axis as key for panel-specific metadata - fig._scitex_metadata[ax] = metadata - - # Also store as YAML-ready structure - if not hasattr(fig, "_scitex_yaml_metadata"): - fig._scitex_yaml_metadata = {} - fig._scitex_yaml_metadata[ax] = metadata - - # Backward compatibility - store simple caption - if caption is not None: - if not hasattr(fig, "_scitex_captions"): - fig._scitex_captions = {} - fig._scitex_captions[ax] = caption - - return ax - - -def set_figure_meta( - ax, - caption=None, - methods=None, - stats=None, - significance=None, - funding=None, - conflicts=None, - data_availability=None, - **kwargs, -): - """Set figure-level metadata for multi-panel figures - - Parameters - ---------- - ax : matplotlib.axes.Axes or scitex AxisWrapper - Any axis in the figure (figure accessed via ax.get_figure()) - caption : str, optional - Figure-level caption - methods : str, optional - Overall experimental methods - stats : str, optional - Overall statistical approach - significance : str, optional - Significance and implications - funding : str, optional - Funding acknowledgments - conflicts : str, optional - Conflict of interest statement - data_availability : str, optional - Data availability statement - **kwargs : additional metadata - Any additional figure-level metadata - - Returns - ------- - ax : matplotlib.axes.Axes or scitex AxisWrapper - The modified axes - - Examples - -------- - >>> fig, ((ax1, ax2), (ax3, ax4)) = scitex.plt.subplots(2, 2) - >>> # Set individual panel metadata... - >>> ax1.set_meta(caption='Panel A analysis...') - >>> ax2.set_meta(caption='Panel B comparison...') - >>> - >>> # Set figure-level metadata - >>> ax1.set_figure_meta( - ... caption='Comprehensive analysis of neural dynamics...', - ... significance='This work demonstrates novel therapeutic targets.', - ... funding='Supported by NIH grant R01-NS123456.', - ... data_availability='Data available at doi:10.5061/dryad.example' - ... ) - """ - - # Build figure-level metadata - figure_metadata = {} - - if caption is not None: - figure_metadata["main_caption"] = caption - if methods is not None: - figure_metadata["overall_methods"] = methods - if stats is not None: - figure_metadata["overall_statistics"] = stats - if significance is not None: - figure_metadata["significance"] = significance - if funding is not None: - figure_metadata["funding"] = funding - if conflicts is not None: - figure_metadata["conflicts_of_interest"] = conflicts - if data_availability is not None: - figure_metadata["data_availability"] = data_availability - - # Add any additional metadata - for key, value in kwargs.items(): - if value is not None: - figure_metadata[key] = value - - # Add automatic metadata - import datetime - - figure_metadata["created_timestamp"] = datetime.datetime.now().isoformat() - - # Store in figure - fig = ax.get_figure() - fig._scitex_figure_metadata = figure_metadata - - # Backward compatibility - if caption is not None: - fig._scitex_main_caption = caption - - return ax - - -def export_metadata_yaml(fig, filepath): - """Export all figure metadata to YAML file - - Parameters - ---------- - fig : matplotlib.figure.Figure - Figure with metadata - filepath : str - Output YAML file path - """ - import datetime - - # Collect all metadata - export_data = { - "figure_metadata": {}, - "panel_metadata": {}, - "export_info": { - "timestamp": datetime.datetime.now().isoformat(), - "scitex_version": "1.11.0", - }, - } - - # Figure-level metadata - if hasattr(fig, "_scitex_figure_metadata"): - export_data["figure_metadata"] = fig._scitex_figure_metadata - - # Panel-level metadata - if hasattr(fig, "_scitex_yaml_metadata"): - for i, (ax, metadata) in enumerate(fig._scitex_yaml_metadata.items()): - panel_key = f"panel_{i + 1}" - export_data["panel_metadata"][panel_key] = metadata - - # Write YAML file - with open(filepath, "w") as f: - yaml.dump(export_data, f, default_flow_style=False, sort_keys=False, indent=2) - - -if __name__ == "__main__": - # Start - import sys - - import matplotlib.pyplot as plt - - import scitex - - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # Example usage - fig, ax = plt.subplots() - ax.plot([1, 2, 3], [1, 4, 2]) - - set_meta( - ax, - caption="Example figure showing data trends.", - methods="Synthetic data generated for demonstration.", - keywords=["example", "demo", "synthetic"], - experimental_details={"n_samples": 3, "data_type": "synthetic"}, - ) - - export_metadata_yaml(fig, "example_metadata.yaml") - - # Close - scitex.session.close(CONFIG) - -# EOF diff --git a/src/scitex/plt/ax/_style/_set_n_ticks.py b/src/scitex/plt/ax/_style/_set_n_ticks.py deleted file mode 100755 index 9efa12f20..000000000 --- a/src/scitex/plt/ax/_style/_set_n_ticks.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-04-29 12:02:14 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_set_n_ticks.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_set_n_ticks.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib - - -def set_n_ticks( - ax, - n_xticks=4, - n_yticks=4, -): - """ - Example: - ax = set_n_ticks(ax) - """ - - if n_xticks is not None: - ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(n_xticks)) - - if n_yticks is not None: - ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(n_yticks)) - - # Force the figure to redraw to reflect changes - ax.figure.canvas.draw() - - return ax - - -# EOF diff --git a/src/scitex/plt/ax/_style/_set_size.py b/src/scitex/plt/ax/_style/_set_size.py deleted file mode 100755 index 999413380..000000000 --- a/src/scitex/plt/ax/_style/_set_size.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2022-12-09 13:38:11 (ywatanabe)" - - -def set_size(ax, w, h): - """w, h: width, height in inches""" - # if not ax: ax=plt.gca() - l = ax.figure.subplotpars.left - r = ax.figure.subplotpars.right - t = ax.figure.subplotpars.top - b = ax.figure.subplotpars.bottom - figw = float(w) / (r - l) - figh = float(h) / (t - b) - ax.figure.set_size_inches(figw, figh) - return ax diff --git a/src/scitex/plt/ax/_style/_set_supxyt.py b/src/scitex/plt/ax/_style/_set_supxyt.py deleted file mode 100755 index b1be5943e..000000000 --- a/src/scitex/plt/ax/_style/_set_supxyt.py +++ /dev/null @@ -1,133 +0,0 @@ -#!./env/bin/python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-07-13 07:56:46 (ywatanabe)" -# Author: Yusuke Watanabe (ywatanabe@scitex.ai) - -""" -This script does XYZ. -""" - -# Imports -import matplotlib.pyplot as plt - -from ._format_label import format_label - - -# Functions -def set_supxyt(ax, xlabel=False, ylabel=False, title=False, format_labels=True): - """Sets xlabel, ylabel and title""" - fig = ax.get_figure() - - # if xlabel is not False: - # fig.supxlabel(xlabel) - - # if ylabel is not False: - # fig.supylabel(ylabel) - - # if title is not False: - # fig.suptitle(title) - if xlabel is not False: - xlabel = format_label(xlabel) if format_labels else xlabel - fig.supxlabel(xlabel) - - if ylabel is not False: - ylabel = format_label(ylabel) if format_labels else ylabel - fig.supylabel(ylabel) - - if title is not False: - title = format_label(title) if format_labels else title - fig.suptitle(title) - - return ax - - -def set_supxytc( - ax, - xlabel=False, - ylabel=False, - title=False, - caption=False, - methods=False, - stats=False, - significance=False, - format_labels=True, -): - """Sets figure-level xlabel, ylabel, title, and caption with SciTeX-Paper integration - - Parameters - ---------- - ax : matplotlib.axes.Axes or scitex AxisWrapper - The axes to modify (figure accessed via ax.get_figure()) - xlabel : str or False, optional - Figure-level X-axis label, by default False - ylabel : str or False, optional - Figure-level Y-axis label, by default False - title : str or False, optional - Figure-level title (suptitle), by default False - caption : str or False, optional - Figure-level caption to store for later use with scitex.io.save(), by default False - methods : str or False, optional - Overall methods description for SciTeX-Paper integration, by default False - stats : str or False, optional - Overall statistical analysis details for SciTeX-Paper integration, by default False - significance : str or False, optional - Significance statement for SciTeX-Paper integration, by default False - format_labels : bool, optional - Whether to apply automatic formatting, by default True - - Returns - ------- - ax : matplotlib.axes.Axes or scitex AxisWrapper - The modified axes - - Examples - -------- - >>> fig, ((ax1, ax2), (ax3, ax4)) = scitex.plt.subplots(2, 2) - >>> # Add plots to each panel... - >>> ax1.set_supxytc(xlabel='Time (s)', ylabel='Signal Amplitude', - ... title='Multi-Panel Analysis', - ... caption='Comprehensive analysis showing (A) raw data, (B) filtered signal, (C) power spectrum, and (D) phase analysis.', - ... methods='All experiments performed using standardized protocols.', - ... significance='This work demonstrates novel therapeutic targets.') - >>> scitex.io.save(fig, 'multi_panel.png') # Caption automatically saved - """ - # Set labels and title using existing function - set_supxyt( - ax, xlabel=xlabel, ylabel=ylabel, title=title, format_labels=format_labels - ) - - # Store figure-level caption and extended metadata - if ( - caption is not False - or methods is not False - or stats is not False - or significance is not False - ): - fig = ax.get_figure() - # Store comprehensive figure-level metadata - fig_metadata = { - "main_caption": caption if caption is not False else None, - "methods": methods if methods is not False else None, - "stats": stats if stats is not False else None, - "significance": significance if significance is not False else None, - } - - fig._scitex_figure_metadata = fig_metadata - - # Backward compatibility - also store simple caption - if caption is not False: - fig._scitex_main_caption = caption - - return ax - - -if __name__ == "__main__": - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # (YOUR AWESOME CODE) - - # Close - scitex.session.close(CONFIG) - -# EOF diff --git a/src/scitex/plt/ax/_style/_set_ticks.py b/src/scitex/plt/ax/_style/_set_ticks.py deleted file mode 100755 index 21bdc72cf..000000000 --- a/src/scitex/plt/ax/_style/_set_ticks.py +++ /dev/null @@ -1,276 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-04-27 20:04:55 (ywatanabe)" -# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_set_ticks.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_style/_set_ticks.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib.pyplot as plt -import numpy as np - -from scitex.dict._to_str import to_str -from scitex.types import is_listed_X - - -def set_ticks(ax, xvals=None, xticks=None, yvals=None, yticks=None): - """Set custom tick labels on both x and y axes. - - Convenience function to set tick positions and labels for both axes - at once. Automatically handles canvas updates for interactive backends. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The axes object to modify. - xvals : array-like, optional - Values corresponding to x-axis data points. - xticks : list, optional - Desired tick labels for the x-axis. - yvals : array-like, optional - Values corresponding to y-axis data points. - yticks : list, optional - Desired tick labels for the y-axis. - - Returns - ------- - matplotlib.axes.Axes - The modified axes object. - - Examples - -------- - >>> fig, ax = plt.subplots() - >>> x = np.linspace(0, 10, 100) - >>> ax.plot(x, np.sin(x)) - >>> ax = set_ticks(ax, xvals=x, xticks=[0, 5, 10], - ... yvals=[-1, 0, 1], yticks=['-1', '0', '1']) - - See Also - -------- - set_x_ticks : Set ticks for x-axis only - set_y_ticks : Set ticks for y-axis only - """ - ax = set_x_ticks(ax, x_vals=xvals, x_ticks=xticks) - ax = set_y_ticks(ax, y_vals=yvals, y_ticks=yticks) - canvas_type = type(ax.figure.canvas).__name__ - if "TkAgg" in canvas_type: - ax.get_figure().canvas.draw() # Redraw the canvas once after making all updates - return ax - - -def set_x_ticks(ax, x_vals=None, x_ticks=None): - """ - Set custom tick labels on the x and y axes based on specified values and desired ticks. - - Parameters: - - ax: The axis object to modify. - - x_vals: Array of x-axis values. - - x_ticks: List of desired tick labels on the x-axis. - - y_vals: Array of y-axis values. - - y_ticks: List of desired tick labels on the y-axis. - - Example: - import matplotlib.pyplot as plt - import numpy as np - - fig, axes = plt.subplots(nrows=4) - x = np.linspace(0, 10, 100) - y = np.sin(x) - for ax in axes: - ax.plot(x, y) # Plot a sine wave - - set_ticks(axes[0]) # Do nothing # OK - set_ticks(axes[1], x_vals=x+3) # OK - set_ticks(axes[2], x_ticks=[1,2]) # OK - set_ticks(axes[3], x_vals=x+3, x_ticks=[4,5]) # Auto-generate ticks across the range - fig.tight_layout() - plt.show() - """ - - def _avoid_overlaps(values): - values = np.array(values) - if ("int" in str(values.dtype)) or ("float" in str(values.dtype)): - values = values.astype(float) + np.arange(len(values)) * 1e-5 - return values - - def _set_x_vals(ax, x_vals): - x_vals = _avoid_overlaps(x_vals) - new_x_axis = np.linspace(*ax.get_xlim(), len(x_vals)) - ax.set_xticks(new_x_axis) - ax.set_xticklabels([f"{xv}" for xv in x_vals]) - return ax - - def _set_x_ticks(ax, x_ticks): - x_ticks = np.array(x_ticks) - if x_ticks.dtype.kind in ["U", "S", "O"]: # If x_ticks are strings - ax.set_xticks(range(len(x_ticks))) - ax.set_xticklabels(x_ticks) - else: - x_vals = np.array( - [label.get_text().replace("−", "-") for label in ax.get_xticklabels()] - ) - x_vals = x_vals.astype(float) - x_indi = np.argmin( - np.array(np.abs(x_vals[:, np.newaxis] - x_ticks[np.newaxis, :])), - axis=0, - ) - ax.set_xticks(ax.get_xticks()[x_indi]) - ax.set_xticklabels([f"{xt}" for xt in x_ticks]) - return ax - - x_vals_passed = x_vals is not None - x_ticks_passed = x_ticks is not None - - if is_listed_X(x_ticks, dict): - x_ticks = [to_str(xt, delimiter="\n") for xt in x_ticks] - - if (not x_vals_passed) and (not x_ticks_passed): - # Do nothing - pass - - elif x_vals_passed and (not x_ticks_passed): - # Replaces the x axis to x_vals - x_ticks = np.linspace(x_vals[0], x_vals[-1], 4) - ax = _set_x_vals(ax, x_ticks) - - elif (not x_vals_passed) and x_ticks_passed: - # Locates 'x_ticks' on the original x axis - ax.set_xticks(x_ticks) - - elif x_vals_passed and x_ticks_passed: - if isinstance(x_vals, str): - if x_vals == "auto": - x_vals = np.arange(len(x_ticks)) - - # Replaces the original x axis to 'x_vals' and locates the 'x_ticks' on the new axis - ax = _set_x_vals(ax, x_vals) - ax = _set_x_ticks(ax, x_ticks) - - return ax - - -def set_y_ticks(ax, y_vals=None, y_ticks=None): - """ - Set custom tick labels on the y-axis based on specified values and desired ticks. - - Parameters: - - ax: The axis object to modify. - - y_vals: Array of y-axis values where ticks should be placed. - - y_ticks: List of labels for ticks on the y-axis. - - Example: - import matplotlib.pyplot as plt - import numpy as np - - fig, ax = plt.subplots() - x = np.linspace(0, 10, 100) - y = np.sin(x) - ax.plot(x, y) # Plot a sine wave - - set_y_ticks(ax, y_vals=y, y_ticks=['Low', 'High']) # Set custom y-axis ticks - plt.show() - """ - - def _avoid_overlaps(values): - values = np.array(values) - if ("int" in str(values.dtype)) or ("float" in str(values.dtype)): - values = values.astype(float) + np.arange(len(values)) * 1e-5 - return values - - def _set_y_vals(ax, y_vals): - y_vals = _avoid_overlaps(y_vals) - new_y_axis = np.linspace(*ax.get_ylim(), len(y_vals)) - ax.set_yticks(new_y_axis) - ax.set_yticklabels([f"{yv:.2f}" for yv in y_vals]) - return ax - - # def _set_y_ticks(ax, y_ticks): - # y_ticks = np.array(y_ticks) - # y_vals = np.array( - # [ - # label.get_text().replace("−", "-") - # for label in ax.get_yticklabels() - # ] - # ) - # y_vals = y_vals.astype(float) - # y_indi = np.argmin( - # np.array(np.abs(y_vals[:, np.newaxis] - y_ticks[np.newaxis, :])), - # axis=0, - # ) - - # # y_indi = [np.argmin(np.abs(y_vals - yt)) for yt in y_ticks] - # ax.set_yticks(ax.get_yticks()[y_indi]) - # ax.set_yticklabels([f"{yt}" for yt in y_ticks]) - # return ax - def _set_y_ticks(ax, y_ticks): - y_ticks = np.array(y_ticks) - if y_ticks.dtype.kind in ["U", "S", "O"]: # If y_ticks are strings - ax.set_yticks(range(len(y_ticks))) - ax.set_yticklabels(y_ticks) - else: - y_vals = np.array( - [label.get_text().replace("−", "-") for label in ax.get_yticklabels()] - ) - y_vals = y_vals.astype(float) - y_indi = np.argmin( - np.array(np.abs(y_vals[:, np.newaxis] - y_ticks[np.newaxis, :])), - axis=0, - ) - ax.set_yticks(ax.get_yticks()[y_indi]) - ax.set_yticklabels([f"{yt}" for yt in y_ticks]) - return ax - - y_vals_passed = y_vals is not None - y_ticks_passed = y_ticks is not None - - if is_listed_X(y_ticks, dict): - y_ticks = [to_str(yt, delimiter="\n") for yt in y_ticks] - - if (not y_vals_passed) and (not y_ticks_passed): - # Do nothing - pass - - elif y_vals_passed and (not y_ticks_passed): - # Replaces the y axis to y_vals - ax = _set_y_vals(ax, y_vals) - - elif (not y_vals_passed) and y_ticks_passed: - # Locates 'y_ticks' on the original y axis - ax.set_yticks(y_ticks) - - elif y_vals_passed and y_ticks_passed: - # Replaces the original y axis to 'y_vals' and locates the 'y_ticks' on the new axis - if y_vals == "auto": - y_vals = np.arange(len(y_ticks)) - - ax = _set_y_vals(ax, y_vals) - ax = _set_y_ticks(ax, y_ticks) - return ax - - -if __name__ == "__main__": - import scitex - - xx, tt, fs = scitex.dsp.demo_sig() - pha, amp, freqs = scitex.dsp.wavelet(xx, fs) - - i_batch, i_ch = 0, 0 - ff = freqs[i_batch, i_ch] - fig, ax = scitex.plt.subplots() - - ax.image2d(amp[i_batch, i_ch]) - - ax = set_ticks( - ax, - x_vals=tt, - x_ticks=[0, 1, 2, 3, 4], - y_vals=ff, - y_ticks=[0, 128, 256], - ) - - plt.show() - -# EOF diff --git a/src/scitex/plt/ax/_style/_set_xyt.py b/src/scitex/plt/ax/_style/_set_xyt.py deleted file mode 100755 index b937e6bc6..000000000 --- a/src/scitex/plt/ax/_style/_set_xyt.py +++ /dev/null @@ -1,130 +0,0 @@ -#!./env/bin/python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-07-13 08:14:19 (ywatanabe)" -# Author: Yusuke Watanabe (ywatanabe@scitex.ai) - -""" -This script does XYZ. -""" - -# Imports -import matplotlib.pyplot as plt - -from ._format_label import format_label - - -# Functions -def set_xyt(ax, x=False, y=False, t=False, format_labels=True): - """Sets xlabel, ylabel and title""" - - if x is not False: - x = format_label(x) if format_labels else x - ax.set_xlabel(x) - - if y is not False: - y = format_label(y) if format_labels else y - ax.set_ylabel(y) - - if t is not False: - t = format_label(t) if format_labels else t - ax.set_title(t) - - return ax - - -def set_xytc( - ax, - x=False, - y=False, - t=False, - c=False, - methods=False, - stats=False, - format_labels=True, -): - """Sets xlabel, ylabel, title, and caption with SciTeX-Paper integration - - Parameters - ---------- - ax : matplotlib.axes.Axes or scitex AxisWrapper - The axes to modify - x : str or False, optional - X-axis label, by default False - y : str or False, optional - Y-axis label, by default False - t : str or False, optional - Title, by default False - c : str or False, optional - Caption to store for later use with scitex.io.save(), by default False - methods : str or False, optional - Methods description for SciTeX-Paper integration, by default False - stats : str or False, optional - Statistical analysis details for SciTeX-Paper integration, by default False - format_labels : bool, optional - Whether to apply automatic formatting, by default True - - Returns - ------- - ax : matplotlib.axes.Axes or scitex AxisWrapper - The modified axes - - Examples - -------- - >>> fig, ax = scitex.plt.subplots() - >>> ax.plot(x, y) - >>> ax.set_xytc(x='Time (s)', y='Voltage (mV)', - ... t='Neural Signal', - ... c='Example neural recording showing action potentials.', - ... methods='Intracellular recordings performed using patch-clamp technique.', - ... stats='Data analyzed using t-test with p<0.05 significance.') - >>> scitex.io.save(fig, 'neural_signal.png') # Caption automatically saved - """ - # Set labels and title using existing function - set_xyt(ax, x=x, y=y, t=t, format_labels=format_labels) - - # Store caption and extended metadata for later use by scitex.io.save() - if c is not False or methods is not False or stats is not False: - # Store comprehensive metadata as axis attribute for retrieval by save function - metadata = { - "caption": c if c is not False else None, - "methods": methods if methods is not False else None, - "stats": stats if stats is not False else None, - } - - if hasattr(ax, "_scitex_metadata"): - ax._scitex_metadata.update(metadata) - else: - # For matplotlib axes, store in figure metadata - fig = ax.get_figure() - if not hasattr(fig, "_scitex_metadata"): - fig._scitex_metadata = {} - # Use axis position as identifier - fig._scitex_metadata[ax] = metadata - - # Backward compatibility - also store simple caption - if c is not False: - if hasattr(ax, "_scitex_caption"): - ax._scitex_caption = c - else: - fig = ax.get_figure() - if not hasattr(fig, "_scitex_captions"): - fig._scitex_captions = {} - fig._scitex_captions[ax] = c - - return ax - - -if __name__ == "__main__": - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # (YOUR AWESOME CODE) - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/ssh:ywatanabe@444:/home/ywatanabe/proj/entrance/scitex/plt/ax/_set_lt.py -""" diff --git a/src/scitex/plt/ax/_style/_share_axes.py b/src/scitex/plt/ax/_style/_share_axes.py deleted file mode 100755 index 934d4b7bc..000000000 --- a/src/scitex/plt/ax/_style/_share_axes.py +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-01 08:47:27 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_share_axes.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_style/_share_axes.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib.pyplot as plt -import numpy as np - -import scitex - - -def sharexy(*multiple_axes): - """Share both x and y axis limits across multiple axes. - - Synchronizes both x and y axis limits across all provided axes objects, - ensuring they all display the same data range. Useful for comparing - multiple plots on the same scale. - - Parameters - ---------- - *multiple_axes : matplotlib.axes.Axes or array of Axes - Variable number of axes objects to synchronize. - - Examples - -------- - >>> fig, (ax1, ax2, ax3) = plt.subplots(1, 3) - >>> ax1.plot([1, 2, 3], [1, 4, 9]) - >>> ax2.plot([1, 2, 3], [2, 5, 8]) - >>> ax3.plot([1, 2, 3], [3, 6, 10]) - >>> sharexy(ax1, ax2, ax3) # All axes now show same range - - See Also - -------- - sharex : Share only x-axis limits - sharey : Share only y-axis limits - """ - sharex(*multiple_axes) - sharey(*multiple_axes) - - -def sharex(*multiple_axes): - """Share x-axis limits across multiple axes. - - Finds the global x-axis limits across all axes and applies them - to each axis, ensuring horizontal alignment of data. - - Parameters - ---------- - *multiple_axes : matplotlib.axes.Axes or array of Axes - Variable number of axes objects to synchronize. - - Returns - ------- - axes : axes object(s) - The modified axes with shared x-limits. - xlim : tuple - The (xmin, xmax) limits applied. - - Examples - -------- - >>> fig, axes = plt.subplots(2, 1) - >>> axes[0].plot([1, 5], [1, 2]) - >>> axes[1].plot([2, 4], [3, 4]) - >>> sharex(axes[0], axes[1]) # Both show x-range [1, 5] - """ - xlim = get_global_xlim(*multiple_axes) - return set_xlims(*multiple_axes, xlim=xlim) - - -def sharey(*multiple_axes): - """Share y-axis limits across multiple axes. - - Finds the global y-axis limits across all axes and applies them - to each axis, ensuring vertical alignment of data. - - Parameters - ---------- - *multiple_axes : matplotlib.axes.Axes or array of Axes - Variable number of axes objects to synchronize. - - Returns - ------- - axes : axes object(s) - The modified axes with shared y-limits. - ylim : tuple - The (ymin, ymax) limits applied. - - Examples - -------- - >>> fig, axes = plt.subplots(1, 2) - >>> axes[0].plot([1, 2], [1, 5]) - >>> axes[1].plot([1, 2], [2, 4]) - >>> sharey(axes[0], axes[1]) # Both show y-range [1, 5] - """ - ylim = get_global_ylim(*multiple_axes) - return set_ylims(*multiple_axes, ylim=ylim) - - -def get_global_xlim(*multiple_axes): - """Get the global x-axis limits across multiple axes. - - Scans all provided axes to find the minimum and maximum x-values - across all of them. Handles both single axes and arrays of axes. - - Parameters - ---------- - *multiple_axes : matplotlib.axes.Axes or array of Axes - Variable number of axes objects to scan. - - Returns - ------- - tuple - (xmin, xmax) representing the global x-axis limits. - - Examples - -------- - >>> fig, (ax1, ax2) = plt.subplots(1, 2) - >>> ax1.plot([1, 3], [1, 2]) # x-range: [1, 3] - >>> ax2.plot([2, 5], [1, 2]) # x-range: [2, 5] - >>> xlim = get_global_xlim(ax1, ax2) - >>> print(xlim) # (1, 5) - - Notes - ----- - There appears to be a bug in the current implementation where - get_ylim() is called instead of get_xlim(). This should be fixed. - """ - xmin, xmax = np.inf, -np.inf - for axes in multiple_axes: - # axes - if isinstance(axes, (np.ndarray, scitex.plt._subplots.AxesWrapper)): - for ax in axes.flat: - _xmin, _xmax = ax.get_xlim() # Fixed: was get_ylim() - xmin = min(xmin, _xmin) - xmax = max(xmax, _xmax) - # axis - else: - ax = axes - _xmin, _xmax = ax.get_xlim() # Fixed: was get_ylim() - xmin = min(xmin, _xmin) - xmax = max(xmax, _xmax) - - return (xmin, xmax) - - -# def get_global_xlim(*multiple_axes): -# xmin, xmax = np.inf, -np.inf -# for axes in multiple_axes: -# for ax in axes.flat: -# _xmin, _xmax = ax.get_xlim() -# xmin = min(xmin, _xmin) -# xmax = max(xmax, _xmax) -# return (xmin, xmax) - - -def get_global_ylim(*multiple_axes): - """Get the global y-axis limits across multiple axes. - - Scans all provided axes to find the minimum and maximum y-values - across all of them. Handles both single axes and arrays of axes. - - Parameters - ---------- - *multiple_axes : matplotlib.axes.Axes or array of Axes - Variable number of axes objects to scan. - - Returns - ------- - tuple - (ymin, ymax) representing the global y-axis limits. - - Examples - -------- - >>> fig, (ax1, ax2) = plt.subplots(1, 2) - >>> ax1.plot([1, 2], [1, 3]) # y-range: [1, 3] - >>> ax2.plot([1, 2], [2, 5]) # y-range: [2, 5] - >>> ylim = get_global_ylim(ax1, ax2) - >>> print(ylim) # (1, 5) - """ - ymin, ymax = np.inf, -np.inf - for axes in multiple_axes: - # axes - if isinstance(axes, (np.ndarray, scitex.plt._subplots.AxesWrapper)): - for ax in axes.flat: - _ymin, _ymax = ax.get_ylim() - ymin = min(ymin, _ymin) - ymax = max(ymax, _ymax) - # axis - else: - ax = axes - _ymin, _ymax = ax.get_ylim() - ymin = min(ymin, _ymin) - ymax = max(ymax, _ymax) - - return (ymin, ymax) - - -def set_xlims(*multiple_axes, xlim=None): - if xlim is None: - raise ValueError("Please set xlim. get_global_xlim() might be useful.") - - for axes in multiple_axes: - # axes - if isinstance(axes, (np.ndarray, scitex.plt._subplots.AxesWrapper)): - for ax in axes.flat: - ax.set_xlim(xlim) - # axis - else: - ax = axes - ax.set_xlim(xlim) - - # Return - if len(multiple_axes) == 1: - return multiple_axes[0], xlim - else: - return multiple_axes, xlim - - -def set_ylims(*multiple_axes, ylim=None): - if ylim is None: - raise ValueError("Please set ylim. get_global_xlim() might be useful.") - - for axes in multiple_axes: - # axes - if isinstance(axes, (np.ndarray, scitex.plt._subplots.AxesWrapper)): - for ax in axes.flat: - ax.set_ylim(ylim) - - # axis - else: - ax = axes - ax.set_ylim(ylim) - - # Return - if len(multiple_axes) == 1: - return multiple_axes[0], ylim - else: - return multiple_axes, ylim - - -def main(): - pass - - -if __name__ == "__main__": - # # Argument Parser - # import argparse - import sys - - # parser = argparse.ArgumentParser(description='') - # parser.add_argument('--var', '-v', type=int, default=1, help='') - # parser.add_argument('--flag', '-f', action='store_true', default=False, help='') - # args = parser.parse_args() - # Main - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( - sys, plt, verbose=False - ) - main() - scitex.session.close(CONFIG, verbose=False, notify=False) - -# EOF diff --git a/src/scitex/plt/ax/_style/_shift.py b/src/scitex/plt/ax/_style/_shift.py deleted file mode 100755 index 1cd108163..000000000 --- a/src/scitex/plt/ax/_style/_shift.py +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:00:54 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_shift.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/plt/ax/_style/_shift.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - - -def shift(ax, dx=0, dy=0): - """ - Adjusts the position of an Axes object within a Figure by specified offsets in centimeters. - - This function modifies the position of a given matplotlib.axes.Axes object by shifting it horizontally and vertically within its parent figure. The shift amounts are specified in centimeters, and the function converts these values into the figure's coordinate system to perform the adjustment. - - Parameters: - - ax (matplotlib.axes.Axes): The Axes object to modify. This must be an instance of a Matplotlib Axes. - - dx (float): The horizontal offset in centimeters. Positive values shift the Axes to the right, while negative values shift it to the left. - - dy (float): The vertical offset in centimeters. Positive values shift the Axes up, while negative values shift it down. - - Returns: - - matplotlib.axes.Axes: The modified Axes object with the adjusted position. - """ - - bbox = ax.get_position() - - # Convert centimeters to inches for consistency with matplotlib dimensions - dx_in, dy_in = dx / 2.54, dy / 2.54 - - # Calculate delta ratios relative to the figure size - fig = ax.get_figure() - fig_dx_in, fig_dy_in = fig.get_size_inches() - dx_ratio, dy_ratio = dx_in / fig_dx_in, dy_in / fig_dy_in - - # Determine updated bbox position and optionally adjust dimensions - left = bbox.x0 + dx_ratio - bottom = bbox.y0 + dy_ratio - width = bbox.width - height = bbox.height - - # Main - ax.set_position([left, bottom, width, height]) - - return ax - - -# def adjust_axes_position_and_dimension( -# ax, dx, dy, adjust_width_for_dx=False, adjust_height_for_dy=False -# ): - -# def set_pos(ax, x_cm, y_cm, extend_x=False, extend_y=False): -# """ -# Adjusts the position of an Axes object within a Figure by a specified offset in centimeters. - -# Parameters: -# - ax (matplotlib.axes.Axes): The Axes object to modify. -# - x_cm (float): The horizontal offset in centimeters to adjust the Axes position. -# - y_cm (float): The vertical offset in centimeters to adjust the Axes position. -# - extend_x (bool): If True, reduces the width of the Axes by the horizontal offset. -# - extend_y (bool): If True, reduces the height of the Axes by the vertical offset. - -# Returns: -# - ax (matplotlib.axes.Axes): The modified Axes object with the adjusted position. -# """ - -# bbox = ax.get_position() - -# # Inches -# x_in, y_in = x_cm / 2.54, y_cm / 2.54 - -# # Calculates delta ratios -# fig = ax.get_figure() -# fig_x_in, fig_y_in = fig.get_size_inches() -# x_ratio, y_ratio = x_in / fig_x_in, y_in / fig_y_in - -# # Determines updated bbox position -# left = bbox.x0 + x_ratio -# bottom = bbox.y0 + y_ratio -# width = bbox.width -# height = bbox.height - -# if extend_x: -# width -= x_ratio - -# if extend_y: -# height -= y_ratio - -# ax.set_position([left, bottom, width, height]) - -# return ax - - -# def set_pos( -# fig, -# ax, -# x_cm, -# y_cm, -# dragh=False, -# dragv=False, -# ): - -# bbox = ax.get_position() - -# ## Calculates delta ratios -# fig_x_in, fig_y_in = fig.get_size_inches() - -# x_in = float(x_cm) / 2.54 -# y_in = float(y_cm) / 2.54 - -# x_ratio = x_in / fig_x_in -# y_ratio = y_in / fig_x_in - -# ## Determines updated bbox position -# left = bbox.x0 + x_ratio -# bottom = bbox.y0 + y_ratio -# width = bbox.x1 - bbox.x0 -# height = bbox.y1 - bbox.y0 - -# if dragh: -# width -= x_ratio - -# if dragv: -# height -= y_ratio - -# ax.set_pos( -# [ -# left, -# bottom, -# width, -# height, -# ] -# ) - -# return ax - -# EOF diff --git a/src/scitex/plt/ax/_style/_show_spines.py b/src/scitex/plt/ax/_style/_show_spines.py deleted file mode 100755 index 76130ba36..000000000 --- a/src/scitex/plt/ax/_style/_show_spines.py +++ /dev/null @@ -1,335 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2025-06-04 11:15:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_show_spines.py - -""" -Functionality: - Show spines for matplotlib axes with intuitive API -Input: - Matplotlib axes object and spine visibility parameters -Output: - Axes with specified spines made visible -Prerequisites: - matplotlib -""" - -from typing import List, Union - -import matplotlib - - -def show_spines( - axis, - top: bool = True, - bottom: bool = True, - left: bool = True, - right: bool = True, - ticks: bool = True, - labels: bool = True, - restore_defaults: bool = True, - spine_width: float = None, - spine_color: str = None, -): - """ - Shows the specified spines of a matplotlib Axes object and optionally restores ticks and labels. - - This function provides the intuitive counterpart to hide_spines. It's especially useful when - you have spines hidden by default (as in scitex configuration) and want to selectively show them - for clearer scientific plots or specific visualization needs. - - Parameters - ---------- - axis : matplotlib.axes.Axes - The Axes object for which the spines will be shown. - top : bool, optional - If True, shows the top spine. Defaults to True. - bottom : bool, optional - If True, shows the bottom spine. Defaults to True. - left : bool, optional - If True, shows the left spine. Defaults to True. - right : bool, optional - If True, shows the right spine. Defaults to True. - ticks : bool, optional - If True, restores ticks on the shown spines' axes. Defaults to True. - labels : bool, optional - If True, restores labels on the shown spines' axes. Defaults to True. - restore_defaults : bool, optional - If True, restores default tick positions and labels. Defaults to True. - spine_width : float, optional - Width of the spines to show. If None, uses matplotlib default. - spine_color : str, optional - Color of the spines to show. If None, uses matplotlib default. - - Returns - ------- - matplotlib.axes.Axes - The modified Axes object with the specified spines shown. - - Examples - -------- - >>> fig, ax = plt.subplots() - >>> # Show only bottom and left spines (classic scientific plot style) - >>> show_spines(ax, top=False, right=False) - >>> plt.show() - - >>> # Show all spines with custom styling - >>> show_spines(ax, spine_width=1.5, spine_color='black') - >>> plt.show() - - >>> # Show spines but without ticks/labels (for clean overlay plots) - >>> show_spines(ax, ticks=False, labels=False) - >>> plt.show() - - Notes - ----- - This function is designed to work seamlessly with scitex plotting where spines are hidden - by default. It provides an intuitive API for showing spines without needing to remember - that hide_spines(top=False, right=False) shows top and right spines. - """ - # Handle both matplotlib axes and scitex AxisWrapper - if hasattr(axis, "_axis_mpl"): - # This is an scitex AxisWrapper, get the underlying matplotlib axis - axis = axis._axis_mpl - - assert isinstance( - axis, matplotlib.axes._axes.Axes - ), "First argument must be a matplotlib axis or scitex AxisWrapper" - - # Define which spines to show - spine_settings = {"top": top, "bottom": bottom, "left": left, "right": right} - - for spine_name, should_show in spine_settings.items(): - # Set spine visibility - axis.spines[spine_name].set_visible(should_show) - - if should_show: - # Set spine width if specified - if spine_width is not None: - axis.spines[spine_name].set_linewidth(spine_width) - - # Set spine color if specified - if spine_color is not None: - axis.spines[spine_name].set_color(spine_color) - - # Restore ticks if requested - if ticks and restore_defaults: - # Determine tick positions based on which spines are shown - if bottom and not top: - axis.xaxis.set_ticks_position("bottom") - elif top and not bottom: - axis.xaxis.set_ticks_position("top") - elif bottom and top: - axis.xaxis.set_ticks_position("both") - - if left and not right: - axis.yaxis.set_ticks_position("left") - elif right and not left: - axis.yaxis.set_ticks_position("right") - elif left and right: - axis.yaxis.set_ticks_position("both") - - # Restore labels if requested and restore_defaults is True - if labels and restore_defaults: - # Only restore if we haven't explicitly hidden them - # This preserves any custom tick labels that might have been set - current_xticks = axis.get_xticks() - current_yticks = axis.get_yticks() - - if len(current_xticks) > 0 and (bottom or top): - # Generate default labels for x-axis - if not hasattr(axis, "_original_xticklabels"): - axis.set_xticks(current_xticks) - - if len(current_yticks) > 0 and (left or right): - # Generate default labels for y-axis - if not hasattr(axis, "_original_yticklabels"): - axis.set_yticks(current_yticks) - - return axis - - -def show_all_spines( - axis, - spine_width: float = None, - spine_color: str = None, - ticks: bool = True, - labels: bool = True, -): - """ - Convenience function to show all spines with optional styling. - - Parameters - ---------- - axis : matplotlib.axes.Axes - The Axes object to modify. - spine_width : float, optional - Width of all spines. - spine_color : str, optional - Color of all spines. - ticks : bool, optional - Whether to show ticks. Defaults to True. - labels : bool, optional - Whether to show labels. Defaults to True. - - Returns - ------- - matplotlib.axes.Axes - The modified Axes object. - - Examples - -------- - >>> show_all_spines(ax, spine_width=1.2, spine_color='gray') - """ - return show_spines( - axis, - top=True, - bottom=True, - left=True, - right=True, - ticks=ticks, - labels=labels, - spine_width=spine_width, - spine_color=spine_color, - ) - - -def show_classic_spines( - axis, - spine_width: float = None, - spine_color: str = None, - ticks: bool = True, - labels: bool = True, -): - """ - Show only bottom and left spines (classic scientific plot style). - - Parameters - ---------- - axis : matplotlib.axes.Axes - The Axes object to modify. - spine_width : float, optional - Width of the spines. - spine_color : str, optional - Color of the spines. - ticks : bool, optional - Whether to show ticks. Defaults to True. - labels : bool, optional - Whether to show labels. Defaults to True. - - Returns - ------- - matplotlib.axes.Axes - The modified Axes object. - - Examples - -------- - >>> show_classic_spines(ax) # Shows only bottom and left spines - """ - return show_spines( - axis, - top=False, - bottom=True, - left=True, - right=False, - ticks=ticks, - labels=labels, - spine_width=spine_width, - spine_color=spine_color, - ) - - -def show_box_spines( - axis, - spine_width: float = None, - spine_color: str = None, - ticks: bool = True, - labels: bool = True, -): - """ - Show all four spines to create a box around the plot. - - This is an alias for show_all_spines but with more descriptive naming - for when you specifically want a boxed appearance. - - Parameters - ---------- - axis : matplotlib.axes.Axes - The Axes object to modify. - spine_width : float, optional - Width of the box spines. - spine_color : str, optional - Color of the box spines. - ticks : bool, optional - Whether to show ticks. Defaults to True. - labels : bool, optional - Whether to show labels. Defaults to True. - - Returns - ------- - matplotlib.axes.Axes - The modified Axes object. - - Examples - -------- - >>> show_box_spines(ax, spine_width=1.0, spine_color='black') - """ - return show_all_spines(axis, spine_width, spine_color, ticks, labels) - - -def toggle_spines( - axis, top: bool = None, bottom: bool = None, left: bool = None, right: bool = None -): - """ - Toggle the visibility of spines (show if hidden, hide if shown). - - Parameters - ---------- - axis : matplotlib.axes.Axes - The Axes object to modify. - top : bool, optional - If specified, sets top spine visibility. If None, toggles current state. - bottom : bool, optional - If specified, sets bottom spine visibility. If None, toggles current state. - left : bool, optional - If specified, sets left spine visibility. If None, toggles current state. - right : bool, optional - If specified, sets right spine visibility. If None, toggles current state. - - Returns - ------- - matplotlib.axes.Axes - The modified Axes object. - - Examples - -------- - >>> toggle_spines(ax) # Toggles all spines - >>> toggle_spines(ax, top=True, right=True) # Shows top and right, toggles others - """ - spine_names = ["top", "bottom", "left", "right"] - spine_params = [top, bottom, left, right] - - for spine_name, param in zip(spine_names, spine_params): - if param is None: - # Toggle current state - current_state = axis.spines[spine_name].get_visible() - axis.spines[spine_name].set_visible(not current_state) - else: - # Set specific state - axis.spines[spine_name].set_visible(param) - - return axis - - -# Convenient aliases for common use cases -def scientific_spines(axis, **kwargs): - """Alias for show_classic_spines - shows only bottom and left spines.""" - return show_classic_spines(axis, **kwargs) - - -def clean_spines(axis, **kwargs): - """Alias for showing no spines - useful for overlay plots or clean visualizations.""" - return show_spines(axis, top=False, bottom=False, left=False, right=False, **kwargs) - - -# EOF diff --git a/src/scitex/plt/ax/_style/_style_barplot.py b/src/scitex/plt/ax/_style/_style_barplot.py deleted file mode 100755 index 390e9c7fa..000000000 --- a/src/scitex/plt/ax/_style/_style_barplot.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_style_barplot.py - -""" -Style bar plot elements with millimeter-based control. - -Default values are loaded from SCITEX_STYLE.yaml via presets.py. -""" - -from typing import List, Optional, Union - -from scitex.plt.styles.presets import SCITEX_STYLE - -# Get defaults from centralized config -_DEFAULT_EDGE_THICKNESS_MM = SCITEX_STYLE.get("bar_edge_thickness_mm", 0.2) - - -def style_barplot( - bar_container, - edge_thickness_mm: float = None, - edgecolor: Optional[Union[str, List[str]]] = "black", -): - """ - Apply consistent styling to matplotlib bar plot elements. - - Parameters - ---------- - bar_container : BarContainer - Container returned by ax.bar() or ax.barh() - edge_thickness_mm : float, optional - Edge line thickness in millimeters (default: 0.2mm) - edgecolor : str or list of str, optional - Edge color(s) for bars. If None, uses default matplotlib colors. - - Returns - ------- - bar_container : BarContainer - The styled bar container - - Examples - -------- - >>> fig, ax = stx.plt.subplots(**stx.plt.presets.NATURE_STYLE) - >>> bars = ax.bar(x, heights) - >>> stx.plt.ax.style_barplot(bars, edge_thickness_mm=0.2, edgecolor='black') - """ - from scitex.plt.utils import mm_to_pt - - # Use centralized default if not specified - if edge_thickness_mm is None: - edge_thickness_mm = _DEFAULT_EDGE_THICKNESS_MM - - # Convert mm to points - lw_pt = mm_to_pt(edge_thickness_mm) - - # Style each bar - for i, bar in enumerate(bar_container): - bar.set_linewidth(lw_pt) - if edgecolor is not None: - if isinstance(edgecolor, list): - bar.set_edgecolor(edgecolor[i % len(edgecolor)]) - else: - bar.set_edgecolor(edgecolor) - - return bar_container - - -# EOF diff --git a/src/scitex/plt/ax/_style/_style_boxplot.py b/src/scitex/plt/ax/_style/_style_boxplot.py deleted file mode 100755 index c4fc77e2d..000000000 --- a/src/scitex/plt/ax/_style/_style_boxplot.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_style_boxplot.py - -""" -Style boxplot elements with millimeter-based control. - -Default values are loaded from SCITEX_STYLE.yaml via presets.py. -""" - -from typing import Optional - -from scitex.plt.styles.presets import SCITEX_STYLE - -# Get defaults from centralized config -_DEFAULT_LINEWIDTH_MM = SCITEX_STYLE.get("trace_thickness_mm", 0.2) -_DEFAULT_FLIER_SIZE_MM = SCITEX_STYLE.get("marker_size_mm", 0.8) - - -def style_boxplot( # noqa: C901 - boxplot_dict, - linewidth_mm: float = None, - flier_size_mm: float = None, - median_color: str = "black", - edge_color: str = "black", - colors: Optional[list] = None, - add_legend: bool = False, - labels: Optional[list] = None, -): - """Apply publication-quality styling to matplotlib boxplot elements. - - This function modifies boxplots to: - - Set consistent line widths for all elements - - Set median line to black for visibility - - Set edge colors to black - - Apply consistent outlier marker styling - - Use scitex color palette by default for box fills - - Parameters - ---------- - boxplot_dict : dict - Dictionary returned by ax.boxplot(). - linewidth_mm : float, default 0.2 - Line width in millimeters for all elements. - flier_size_mm : float, default 0.8 - Outlier (flier) marker size in millimeters. - median_color : str, default "black" - Color for the median line inside boxes. - edge_color : str, default "black" - Color for box edges, whiskers, and caps. - colors : list, optional - List of colors for each box fill. If None, uses scitex color palette. - add_legend : bool, default False - Whether to add a legend. - labels : list, optional - Labels for legend entries (required if add_legend=True). - - Returns - ------- - boxplot_dict : dict - The styled boxplot dictionary. - - Examples - -------- - >>> import scitex as stx - >>> import numpy as np - >>> fig, ax = stx.plt.subplots() - >>> box_data = [np.random.normal(0, 1, 100) for _ in range(4)] - >>> bp = ax.boxplot(box_data, patch_artist=True) - >>> stx.plt.ax.style_boxplot(bp, median_color="black") - """ - from scitex.plt.color import HEX - from scitex.plt.utils import mm_to_pt - - # Use centralized defaults if not specified - if linewidth_mm is None: - linewidth_mm = _DEFAULT_LINEWIDTH_MM - if flier_size_mm is None: - flier_size_mm = _DEFAULT_FLIER_SIZE_MM - - # Convert mm to points - lw_pt = mm_to_pt(linewidth_mm) - flier_size_pt = mm_to_pt(flier_size_mm) - - # Use scitex color palette by default - if colors is None: - colors = [ - HEX["blue"], - HEX["red"], - HEX["green"], - HEX["yellow"], - HEX["purple"], - HEX["orange"], - HEX["lightblue"], - HEX["pink"], - ] - - # Style box elements with line width - for element_name in ["boxes", "whiskers", "caps"]: - if element_name in boxplot_dict: - for element in boxplot_dict[element_name]: - element.set_linewidth(lw_pt) - element.set_color(edge_color) - - # Style medians with specified color - if "medians" in boxplot_dict: - for median in boxplot_dict["medians"]: - median.set_linewidth(lw_pt) - median.set_color(median_color) - - # Style fliers (outliers) with marker size - if "fliers" in boxplot_dict: - for flier in boxplot_dict["fliers"]: - flier.set_markersize(flier_size_pt) - flier.set_markeredgewidth(lw_pt) - flier.set_markeredgecolor(edge_color) - flier.set_markerfacecolor("none") # Open circles - - # Apply fill colors to boxes - for i, box in enumerate(boxplot_dict.get("boxes", [])): - color = colors[i % len(colors)] - if hasattr(box, "set_facecolor"): - box.set_facecolor(color) - box.set_edgecolor(edge_color) - - # Add legend if requested - if add_legend and labels is not None: - # Create proxy artists for legend - import matplotlib.patches as mpatches - - if colors is not None: - legend_elements = [ - mpatches.Patch( - facecolor="none", edgecolor=color, linewidth=lw_pt, label=label - ) - for color, label in zip(colors, labels) - ] - else: - legend_elements = [ - mpatches.Patch( - facecolor="none", edgecolor="C0", linewidth=lw_pt, label=label - ) - for label in labels - ] - # Get the axes from one of the box elements - if boxplot_dict.get("boxes"): - ax = boxplot_dict["boxes"][0].axes - ax.legend(handles=legend_elements) - - return boxplot_dict - - -# EOF diff --git a/src/scitex/plt/ax/_style/_style_errorbar.py b/src/scitex/plt/ax/_style/_style_errorbar.py deleted file mode 100755 index 4039c0eaf..000000000 --- a/src/scitex/plt/ax/_style/_style_errorbar.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_style_errorbar.py - -""" -Style error bar elements with millimeter-based control. - -Default values are loaded from SCITEX_STYLE.yaml via presets.py. -""" - -from typing import Optional - -from scitex.plt.styles.presets import SCITEX_STYLE - -# Get defaults from centralized config -_DEFAULT_THICKNESS_MM = SCITEX_STYLE.get("trace_thickness_mm", 0.2) -_DEFAULT_CAP_WIDTH_MM = SCITEX_STYLE.get("errorbar_cap_width_mm", 0.8) - - -def style_errorbar( - errorbar_container, - thickness_mm: float = None, - cap_width_mm: float = None, -): - """ - Apply consistent styling to matplotlib errorbar elements. - - Parameters - ---------- - errorbar_container : ErrorbarContainer - Container returned by ax.errorbar() - thickness_mm : float, optional - Line thickness for error bars in millimeters (default: 0.2mm) - cap_width_mm : float, optional - Cap width in millimeters (default: 0.8mm) - - Returns - ------- - errorbar_container : ErrorbarContainer - The styled errorbar container - - Examples - -------- - >>> fig, ax = stx.plt.subplots(**stx.plt.presets.NATURE_STYLE) - >>> eb = ax.errorbar(x, y, yerr=yerr) - >>> stx.plt.ax.style_errorbar(eb, thickness_mm=0.2, cap_width_mm=0.8) - """ - from scitex.plt.utils import mm_to_pt - - # Use centralized defaults if not specified - if thickness_mm is None: - thickness_mm = _DEFAULT_THICKNESS_MM - if cap_width_mm is None: - cap_width_mm = _DEFAULT_CAP_WIDTH_MM - - # Convert mm to points - lw_pt = mm_to_pt(thickness_mm) - cap_width_pt = mm_to_pt(cap_width_mm) - - # Style the data line - if errorbar_container[0] is not None: - errorbar_container[0].set_linewidth(lw_pt) - - # Style the error bar lines - if len(errorbar_container) > 2 and errorbar_container[2] is not None: - for line_collection in errorbar_container[2]: - if line_collection is not None: - line_collection.set_linewidth(lw_pt) - - # Style the caps - if len(errorbar_container) > 1 and errorbar_container[1] is not None: - for cap in errorbar_container[1]: - if cap is not None: - cap.set_linewidth(lw_pt) # Cap line thickness same as error bar - # Set cap marker size (width) - cap.set_markersize(cap_width_pt) - - return errorbar_container - - -# EOF diff --git a/src/scitex/plt/ax/_style/_style_scatter.py b/src/scitex/plt/ax/_style/_style_scatter.py deleted file mode 100755 index 7aaa285fa..000000000 --- a/src/scitex/plt/ax/_style/_style_scatter.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_style_scatter.py - -""" -Style scatter plot elements with millimeter-based control. - -Default values are loaded from SCITEX_STYLE.yaml via presets.py. -""" - -from typing import Optional - -from scitex.plt.styles.presets import SCITEX_STYLE - -# Get defaults from centralized config -_DEFAULT_SIZE_MM = SCITEX_STYLE.get("scatter_size_mm", 0.8) -_DEFAULT_EDGE_THICKNESS_MM = SCITEX_STYLE.get("marker_edge_width_mm", 0.0) - - -def style_scatter( - path_collection, - size_mm: float = None, - edge_thickness_mm: float = None, -): - """ - Apply consistent styling to matplotlib scatter plot elements. - - Parameters - ---------- - path_collection : PathCollection - Collection returned by ax.scatter() - size_mm : float, optional - Marker size in millimeters (default: 0.8mm) - edge_thickness_mm : float, optional - Edge line thickness in millimeters (default: 0.0mm = no border) - - Returns - ------- - path_collection : PathCollection - The styled path collection - - Examples - -------- - >>> fig, ax = stx.plt.subplots(**stx.plt.presets.NATURE_STYLE) - >>> scatter = ax.scatter(x, y) - >>> stx.ax.style_scatter(scatter, size_mm=0.8) - - Notes - ----- - Matplotlib scatter uses marker size in points squared. - We convert mm to points, then square for the area. - By default, no border is applied (edge_thickness_mm=0). - """ - from scitex.plt.utils import mm_to_pt - - # Use centralized defaults if not specified - if size_mm is None: - size_mm = _DEFAULT_SIZE_MM - if edge_thickness_mm is None: - edge_thickness_mm = _DEFAULT_EDGE_THICKNESS_MM - - # Convert mm to points - size_pt = mm_to_pt(size_mm) - - # Matplotlib scatter uses area (points^2) - # For a marker of diameter d, area = (d/2)^2 * pi - # But matplotlib's 's' parameter is already area-like - # So we use size_pt^2 to get the right visual size - marker_area = size_pt**2 - - # Set marker size - path_collection.set_sizes([marker_area]) - - # Set edge thickness (0 by default = no border) - edge_width_pt = mm_to_pt(edge_thickness_mm) - path_collection.set_linewidths(edge_width_pt) - - return path_collection - - -# EOF diff --git a/src/scitex/plt/ax/_style/_style_suptitles.py b/src/scitex/plt/ax/_style/_style_suptitles.py deleted file mode 100755 index 1dea88c8d..000000000 --- a/src/scitex/plt/ax/_style/_style_suptitles.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-11-19 15:20:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_style_suptitles.py - -""" -Style figure-level titles and labels with proper font sizes. -""" - -from typing import Optional - - -def style_suptitles( - fig, - suptitle_font_size_pt: float = 7, - font_family: str = "DejaVu Sans", -): - """ - Apply consistent styling to figure-level titles and labels. - - Parameters - ---------- - fig : matplotlib.figure.Figure or FigWrapper - The figure to style - suptitle_font_size_pt : float, optional - Font size in points for suptitle, supxlabel, supylabel (default: 7) - font_family : str, optional - Font family to use (default: "DejaVu Sans") - - Returns - ------- - fig : matplotlib.figure.Figure or FigWrapper - The styled figure - - Examples - -------- - >>> fig, axes = stx.plt.subplots(2, 2, **stx.plt.presets.NATURE_STYLE) - >>> fig.suptitle("Main Title") - >>> fig.supxlabel("X Axis Label") - >>> fig.supylabel("Y Axis Label") - >>> stx.ax.style_suptitles(fig) - - Notes - ----- - This function applies font styling to: - - fig.suptitle() - Main figure title - - fig.supxlabel() - Figure-level X axis label - - fig.supylabel() - Figure-level Y axis label - - All are set to the same font size (default 7pt for publication). - """ - # Unwrap FigWrapper if needed - if hasattr(fig, "_fig_mpl"): - fig_mpl = fig._fig_mpl - else: - fig_mpl = fig - - # Style suptitle - if fig_mpl._suptitle is not None: - fig_mpl._suptitle.set_fontsize(suptitle_font_size_pt) - fig_mpl._suptitle.set_fontfamily(font_family) - - # Style supxlabel (if it exists) - if hasattr(fig_mpl, "_supxlabel") and fig_mpl._supxlabel is not None: - fig_mpl._supxlabel.set_fontsize(suptitle_font_size_pt) - fig_mpl._supxlabel.set_fontfamily(font_family) - - # Style supylabel (if it exists) - if hasattr(fig_mpl, "_supylabel") and fig_mpl._supylabel is not None: - fig_mpl._supylabel.set_fontsize(suptitle_font_size_pt) - fig_mpl._supylabel.set_fontfamily(font_family) - - return fig - - -# EOF diff --git a/src/scitex/plt/ax/_style/_style_violinplot.py b/src/scitex/plt/ax/_style/_style_violinplot.py deleted file mode 100755 index 111d8f1d8..000000000 --- a/src/scitex/plt/ax/_style/_style_violinplot.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# File: ./src/scitex/plt/ax/_style/_style_violinplot.py - -"""Style violin plot elements with millimeter-based control. - -Default values are loaded from SCITEX_STYLE.yaml via presets.py. -""" - -from typing import Optional, Union - -from matplotlib.axes import Axes - -from scitex.plt.styles.presets import SCITEX_STYLE - -# Get defaults from centralized config -_DEFAULT_LINEWIDTH_MM = SCITEX_STYLE.get("trace_thickness_mm", 0.2) - - -def style_violinplot( - ax: Union[Axes, "AxisWrapper"], - linewidth_mm: float = None, - edge_color: str = "black", - median_color: str = "black", - remove_caps: bool = True, -) -> Union[Axes, "AxisWrapper"]: - """Apply publication-quality styling to seaborn violin plots. - - This function modifies violin plots created by seaborn.violinplot() to: - - Add borders to the KDE (violin body) edges - - Remove caps from the internal boxplot whiskers - - Set median line to black for better visibility - - Apply consistent line widths - - Parameters - ---------- - ax : matplotlib.axes.Axes or AxisWrapper - The axes containing the violin plot. - linewidth_mm : float, default 0.2 - Line width in millimeters for violin edges and boxplot elements. - edge_color : str, default "black" - Color for the violin body edges. - median_color : str, default "black" - Color for the median line inside the boxplot. - remove_caps : bool, default True - Whether to remove the caps (horizontal lines) from boxplot whiskers. - - Returns - ------- - ax : matplotlib.axes.Axes or AxisWrapper - The axes with styled violin plot. - - Examples - -------- - >>> import seaborn as sns - >>> import scitex as stx - >>> fig, ax = stx.plt.subplots() - >>> sns.violinplot(data=df, x="group", y="value", ax=ax) - >>> stx.plt.ax.style_violinplot(ax) - """ - from scitex.plt.utils import mm_to_pt - - # Use centralized default if not specified - if linewidth_mm is None: - linewidth_mm = _DEFAULT_LINEWIDTH_MM - - lw_pt = mm_to_pt(linewidth_mm) - - # Style violin bodies (PolyCollection) - for collection in ax.collections: - # Check if it's a violin body (PolyCollection with filled area) - if hasattr(collection, "set_edgecolor"): - collection.set_edgecolor(edge_color) - collection.set_linewidth(lw_pt) - - # Style internal boxplot elements (Line2D objects) - # Seaborn violin plot lines: whiskers (vertical), caps (horizontal), median (short horizontal) - lines = list(ax.lines) - n_violins = len( - [ - c - for c in ax.collections - if hasattr(c, "get_paths") and len(c.get_paths()) > 0 - ] - ) - - for line in lines: - # Get line data to identify element type - xdata = line.get_xdata() - ydata = line.get_ydata() - - if len(ydata) != 2: - continue - - # Caps are horizontal lines (same y-value for both points) with wider x-span - is_horizontal = ydata[0] == ydata[1] - x_span = abs(xdata[1] - xdata[0]) if len(xdata) == 2 else 0 - - if is_horizontal: - if remove_caps and x_span > 0.05: - # This is likely a cap (wider horizontal line at whisker ends) - line.set_visible(False) - else: - # This is likely a median line (short horizontal line) - line.set_color(median_color) - line.set_linewidth(lw_pt) - else: - # Vertical lines (whiskers) - line.set_linewidth(lw_pt) - - return ax - - -# EOF diff --git a/src/scitex/plt/styles/__init__.py b/src/scitex/plt/styles/__init__.py index e1390af34..7694207d5 100755 --- a/src/scitex/plt/styles/__init__.py +++ b/src/scitex/plt/styles/__init__.py @@ -1,34 +1,21 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 21:00:00 (ywatanabe)" # File: ./src/scitex/plt/styles/__init__.py """SciTeX plot styling module. -This module centralizes all plot-specific default styling, including: -- Pre-processing: Default kwargs applied before matplotlib method calls -- Post-processing: Styling applied after matplotlib method calls -- Style configuration with priority resolution: direct → yaml → env → default +Style configuration with priority resolution: direct -> yaml -> env -> default. Usage: - from scitex.plt.styles import apply_plot_defaults, apply_plot_postprocess + from scitex.plt.styles import SCITEX_STYLE, load_style, resolve_style_value - # In AxisWrapper.__getattr__ wrapper: - apply_plot_defaults(method_name, kwargs, id_value, ax) - result = orig_method(*args, **kwargs) - apply_plot_postprocess(method_name, result, ax, kwargs) - - # Style configuration - from scitex.plt.styles import SCITEX_STYLE, load_style - fig, ax = stx.plt.subplots(**SCITEX_STYLE) - - # Custom YAML - style = load_style("path/to/my_style.yaml") + # Load style as subplots kwargs + style = load_style() fig, ax = stx.plt.subplots(**style) + + # Resolve individual values + dpi = resolve_style_value("output.dpi", None, 300) """ -from ._plot_defaults import apply_plot_defaults -from ._plot_postprocess import apply_plot_postprocess from .presets import ( # DPI utilities DPI_DISPLAY, DPI_PREVIEW, @@ -46,9 +33,6 @@ ) __all__ = [ - # Styling functions - "apply_plot_defaults", - "apply_plot_postprocess", # Style configuration "SCITEX_STYLE", "STYLE", diff --git a/src/scitex/plt/styles/_plot_defaults.py b/src/scitex/plt/styles/_plot_defaults.py deleted file mode 100755 index 08fc75a9c..000000000 --- a/src/scitex/plt/styles/_plot_defaults.py +++ /dev/null @@ -1,210 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-12-01 10:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_plot_defaults.py - -"""Pre-processing default kwargs for plot methods. - -This module centralizes all default styling applied BEFORE matplotlib -methods are called. Each function modifies kwargs in-place. - -Priority: direct kwarg → env var → YAML config → default - -Style values use the key format from YAML (e.g., 'lines.trace_mm'). -Env vars: SCITEX_PLT_LINES_TRACE_MM (prefix + dots→underscores + uppercase) -""" - -from scitex.plt.styles.presets import resolve_style_value -from scitex.plt.utils import mm_to_pt - -# Default alpha for fill regions (0.3 = semi-transparent) -DEFAULT_FILL_ALPHA = 0.3 - - -# ============================================================================ -# Style helper function -# ============================================================================ -def _get_style_value(key, default, style_dict=None): - """Get style value with priority: style_dict → active_style → env → yaml → default. - - Args: - key: YAML-style key (e.g., 'lines.trace_mm') - default: Fallback default value - style_dict: Optional user-provided style dict (overrides all) - - Returns: - Resolved style value - """ - flat_key = _yaml_key_to_flat(key) - - # Priority 1: User passed explicit style dict - if style_dict is not None and flat_key in style_dict: - return style_dict[flat_key] - - # Priority 2: Check active style set via set_style() - from scitex.plt.styles.presets import _active_style - - if _active_style is not None and flat_key in _active_style: - return _active_style[flat_key] - - # Priority 3: Use resolve_style_value for: env → yaml → default - return resolve_style_value(key, None, default) - - -def _yaml_key_to_flat(key): - """Convert YAML key to flat SCITEX_STYLE key. - - Examples: - 'lines.trace_mm' -> 'trace_thickness_mm' - 'markers.size_mm' -> 'marker_size_mm' - """ - # Mapping from YAML keys to flat keys used in SCITEX_STYLE - mapping = { - "lines.trace_mm": "trace_thickness_mm", - "lines.errorbar_mm": "errorbar_thickness_mm", - "lines.errorbar_cap_mm": "errorbar_cap_width_mm", - "markers.size_mm": "marker_size_mm", - } - return mapping.get(key, key) - - -# ============================================================================ -# Pre-processing functions -# ============================================================================ -def apply_plot_defaults(method_name, kwargs, id_value=None, ax=None): - """Apply default kwargs for a plot method before calling matplotlib. - - Args: - method_name: Name of the matplotlib method being called - kwargs: Keyword arguments dict (modified in-place) - id_value: Optional id passed to the method - ax: The matplotlib axes (for methods needing axis setup) - - Returns: - Modified kwargs dict - - Note: - Priority: direct kwarg → style dict → env var → yaml → default - Users can pass `style=dict` kwarg to override env/yaml defaults. - """ - # Extract optional style dict (removes 'style' key from kwargs) - style_dict = kwargs.pop("style", None) - - # Dispatch to method-specific defaults - if method_name == "plot": - _apply_plot_line_defaults(kwargs, id_value, style_dict) - elif method_name in ("bar", "barh"): - _apply_bar_defaults(kwargs, style_dict) - elif method_name == "errorbar": - _apply_errorbar_defaults(kwargs, style_dict) - elif method_name in ("fill_between", "fill_betweenx"): - _apply_fill_defaults(kwargs) - elif method_name in ("quiver", "streamplot"): - _apply_vector_field_defaults(method_name, kwargs, ax, style_dict) - elif method_name == "boxplot": - _apply_boxplot_defaults(kwargs) - elif method_name == "violinplot": - _apply_violinplot_defaults(kwargs) - - return kwargs - - -def _apply_plot_line_defaults(kwargs, id_value=None, style_dict=None): - """Apply defaults for ax.plot() method.""" - line_width_mm = _get_style_value("lines.trace_mm", 0.2, style_dict) - - # Default line width - if "linewidth" not in kwargs and "lw" not in kwargs: - kwargs["linewidth"] = mm_to_pt(line_width_mm) - - # KDE-specific styling when id contains "kde" - if id_value and "kde" in str(id_value).lower(): - if "linestyle" not in kwargs and "ls" not in kwargs: - kwargs["linestyle"] = "--" - if "color" not in kwargs and "c" not in kwargs: - kwargs["color"] = "black" - - -def _apply_bar_defaults(kwargs, style_dict=None): - """Apply defaults for ax.bar() and ax.barh() methods.""" - line_width_mm = _get_style_value("lines.trace_mm", 0.2, style_dict) - - # Set error bar line thickness - if "error_kw" not in kwargs: - kwargs["error_kw"] = {} - if "elinewidth" not in kwargs.get("error_kw", {}): - kwargs["error_kw"]["elinewidth"] = mm_to_pt(line_width_mm) - if "capthick" not in kwargs.get("error_kw", {}): - kwargs["error_kw"]["capthick"] = mm_to_pt(line_width_mm) - # Set a temporary capsize that will be adjusted in post-processing - if "capsize" not in kwargs: - kwargs["capsize"] = 5 # Placeholder, adjusted later to 33% of bar width - - -def _apply_errorbar_defaults(kwargs, style_dict=None): - """Apply defaults for ax.errorbar() method.""" - line_width_mm = _get_style_value("lines.trace_mm", 0.2, style_dict) - cap_size_mm = _get_style_value("lines.errorbar_cap_mm", 0.8, style_dict) - - if "capsize" not in kwargs: - kwargs["capsize"] = mm_to_pt(cap_size_mm) - if "capthick" not in kwargs: - kwargs["capthick"] = mm_to_pt(line_width_mm) - if "elinewidth" not in kwargs: - kwargs["elinewidth"] = mm_to_pt(line_width_mm) - - -def _apply_fill_defaults(kwargs): - """Apply defaults for ax.fill_between() and ax.fill_betweenx() methods.""" - if "alpha" not in kwargs: - kwargs["alpha"] = DEFAULT_FILL_ALPHA # Transparent to see overlapping data - - -def _apply_vector_field_defaults(method_name, kwargs, ax, style_dict=None): - """Apply defaults for ax.quiver() and ax.streamplot() methods.""" - line_width_mm = _get_style_value("lines.trace_mm", 0.2, style_dict) - marker_size_mm = _get_style_value("markers.size_mm", 0.8, style_dict) - - # Set equal aspect ratio for proper vector display - if ax is not None: - ax.set_aspect("equal", adjustable="datalim") - - if method_name == "streamplot": - if "arrowsize" not in kwargs: - # arrowsize is a scaling factor; scale relative to default - kwargs["arrowsize"] = mm_to_pt(marker_size_mm) / 3 - if "linewidth" not in kwargs: - kwargs["linewidth"] = mm_to_pt(line_width_mm) - - elif method_name == "quiver": - if "width" not in kwargs: - kwargs["width"] = 0.003 # Narrow arrow shaft (axes fraction) - if "headwidth" not in kwargs: - kwargs["headwidth"] = 3 # Head width relative to shaft - if "headlength" not in kwargs: - kwargs["headlength"] = 4 - if "headaxislength" not in kwargs: - kwargs["headaxislength"] = 3.5 - - -def _apply_boxplot_defaults(kwargs): - """Apply defaults for ax.boxplot() method.""" - # Enable patch_artist for fillable boxes - if "patch_artist" not in kwargs: - kwargs["patch_artist"] = True - - -def _apply_violinplot_defaults(kwargs): - """Apply defaults for ax.violinplot() method.""" - # Default to showing boxplot overlay (can be disabled with boxplot=False) - # Store the boxplot setting for post-processing, then remove from kwargs - # so it doesn't get passed to matplotlib's violinplot - if "boxplot" not in kwargs: - kwargs["boxplot"] = True # Default: add boxplot overlay - - # Default to hiding extrema (min/max bars) when boxplot is shown - if "showextrema" not in kwargs: - kwargs["showextrema"] = False - - -# EOF diff --git a/src/scitex/plt/styles/_plot_postprocess.py b/src/scitex/plt/styles/_plot_postprocess.py deleted file mode 100755 index 19b2fe9c0..000000000 --- a/src/scitex/plt/styles/_plot_postprocess.py +++ /dev/null @@ -1,487 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2026-01-13 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_plot_postprocess.py - -"""Post-processing styling for plot methods. - -This module centralizes all styling applied AFTER matplotlib methods -are called. Each function modifies the plot result or axes in-place. - -All default values are loaded from SCITEX_STYLE.yaml via presets.py. -Delegates to figrecipe styling functions when available. -""" - -from matplotlib.category import StrCategoryConverter, UnitData -from matplotlib.ticker import FixedLocator, MaxNLocator - -from scitex.plt.styles._postprocess_helpers import ( - calculate_cap_width_from_bar, - calculate_cap_width_from_box, - make_errorbar_one_sided, -) -from scitex.plt.styles.presets import SCITEX_STYLE -from scitex.plt.utils import mm_to_pt - -# ============================================================================ -# Constants (loaded from centralized SCITEX_STYLE.yaml) -# ============================================================================ -DEFAULT_LINE_WIDTH_MM = SCITEX_STYLE.get("trace_thickness_mm", 0.2) -DEFAULT_MARKER_SIZE_MM = SCITEX_STYLE.get("marker_size_mm", 0.8) -DEFAULT_N_TICKS = SCITEX_STYLE.get("n_ticks", 4) - 1 # nbins = n_ticks - 1 -SPINE_ZORDER = 1000 - - -# ============================================================================ -# Main post-processing function -# ============================================================================ -def apply_plot_postprocess(method_name, result, ax, kwargs, args=None): # noqa: C901 - """Apply post-processing styling after matplotlib method call. - - Args: - method_name: Name of the matplotlib method that was called - result: Return value from the matplotlib method - ax: The matplotlib axes - kwargs: Original kwargs passed to the method - args: Original positional args passed to the method (needed for violinplot) - - Returns - ------- - The result (possibly modified) - """ - # Always ensure spines are on top - _ensure_spines_on_top(ax) - - # Apply tick locator for numerical axes - _apply_tick_locator(ax) - - # Method-specific post-processing - if method_name == "pie" and result is not None: - _postprocess_pie(result) - elif method_name == "stem" and result is not None: - _postprocess_stem(result) - elif method_name == "violinplot" and result is not None: - _postprocess_violin(result, ax, kwargs, args) - elif method_name == "boxplot" and result is not None: - _postprocess_boxplot(result, ax) - elif method_name == "scatter" and result is not None: - _postprocess_scatter(result, kwargs) - elif method_name == "bar" and result is not None: - _postprocess_bar(result, ax, kwargs) - elif method_name == "barh" and result is not None: - _postprocess_barh(result, ax, kwargs) - elif method_name == "errorbar" and result is not None: - _postprocess_errorbar(result) - elif method_name == "hist" and result is not None: - _postprocess_hist(result, ax) - elif method_name == "fill_between" and result is not None: - _postprocess_fill_between(result, kwargs) - - return result - - -# ============================================================================ -# General post-processing -# ============================================================================ -def _ensure_spines_on_top(ax): - """Ensure axes spines are always drawn in front of plot elements.""" - try: - ax.set_axisbelow(False) - - # Set very high z-order for spines - for spine in ax.spines.values(): - spine.set_zorder(SPINE_ZORDER) - - # Set z-order for tick marks - ax.tick_params(zorder=SPINE_ZORDER) - - # Ensure plot patches have lower z-order than spines - # But preserve intentionally set z-orders (e.g., boxplot in violin) - for patch in ax.patches: - current_z = patch.get_zorder() - # Only lower z-order if it's >= SPINE_ZORDER or is at matplotlib default (1) - if current_z >= SPINE_ZORDER: - patch.set_zorder(current_z - SPINE_ZORDER) - elif current_z == 1: - # Default matplotlib z-order, lower it - patch.set_zorder(0.5) - # Otherwise, preserve the intentionally set z-order - - # Set axes patch behind everything - ax.patch.set_zorder(-1) - except Exception: - pass - - -def _apply_tick_locator(ax): - """Apply MaxNLocator only to numerical (non-categorical) axes. - - Target: 3-4 ticks per axis for clean publication figures. - MaxNLocator's nbins=3 gives approximately 3-4 tick marks. - min_n_ticks=3 ensures at least 3 ticks (never 2). - """ - try: - - def is_categorical_axis(axis): - # Use get_converter() for matplotlib 3.10+ compatibility - converter = getattr(axis, "get_converter", lambda: axis.converter)() - if isinstance(converter, StrCategoryConverter): - return True - if hasattr(axis, "units") and isinstance(axis.units, UnitData): - return True - if isinstance(axis.get_major_locator(), FixedLocator): - return True - return False - - if not is_categorical_axis(ax.xaxis): - ax.xaxis.set_major_locator( - MaxNLocator( - nbins=DEFAULT_N_TICKS, min_n_ticks=3, integer=False, prune=None - ) - ) - - if not is_categorical_axis(ax.yaxis): - ax.yaxis.set_major_locator( - MaxNLocator( - nbins=DEFAULT_N_TICKS, min_n_ticks=3, integer=False, prune=None - ) - ) - except Exception: - pass - - -# ============================================================================ -# Method-specific post-processing -# ============================================================================ -def _postprocess_pie(result): - """Apply styling for pie charts.""" - # pie returns (wedges, texts, autotexts) when autopct is used - if len(result) >= 3: - autotexts = result[2] - for autotext in autotexts: - autotext.set_fontsize(6) # 6pt for inline percentages - - -def _postprocess_stem(result): - """Apply styling for stem plots.""" - baseline = result.baseline - if baseline is not None: - baseline.set_color("black") - baseline.set_linestyle("--") - - -def _postprocess_errorbar(result): - """Apply styling for errorbar plots. - - Simplifies the legend to show only a line (no caps/bars). - """ - import matplotlib.legend as mlegend - from matplotlib.container import ErrorbarContainer - from matplotlib.legend_handler import HandlerErrorbar, HandlerLine2D - - # Custom handler that shows only a simple line for errorbar - class SimpleLineHandler(HandlerErrorbar): - def create_artists( - self, - legend, - orig_handle, - xdescent, - ydescent, - width, - height, - fontsize, - trans, - ): - # Use HandlerLine2D to create just a line - line_handler = HandlerLine2D() - # Get the data line from the ErrorbarContainer - data_line = orig_handle[0] - if data_line is not None: - return line_handler.create_artists( - legend, - data_line, - xdescent, - ydescent, - width, - height, - fontsize, - trans, - ) - return [] - - # Register the handler globally for ErrorbarContainer - mlegend.Legend.update_default_handler_map({ErrorbarContainer: SimpleLineHandler()}) - - -def _postprocess_violin(result, ax, kwargs, args): # noqa: C901 - """Apply styling for violin plots with optional boxplot overlay.""" - # Get scitex palette for coloring - from scitex.plt.color import HEX - - palette = [ - HEX["blue"], - HEX["red"], - HEX["green"], - HEX["yellow"], - HEX["purple"], - HEX["orange"], - HEX["lightblue"], - HEX["pink"], - ] - - if "bodies" in result: - for i, body in enumerate(result["bodies"]): - body.set_facecolor(palette[i % len(palette)]) - body.set_edgecolor("black") - body.set_linewidth(mm_to_pt(DEFAULT_LINE_WIDTH_MM)) - body.set_alpha(1.0) - - # Add boxplot overlay by default (disable with boxplot=False) - add_boxplot = kwargs.pop("boxplot", True) - if add_boxplot and args: - try: - # Get data from first positional argument - data = args[0] - # Get positions if specified, otherwise use default - positions = kwargs.get("positions", None) - if positions is None: - positions = range(1, len(data) + 1) - - # Calculate boxplot width dynamically from violin width - # Get violin width from kwargs or use matplotlib default (0.5) - violin_widths = kwargs.get("widths", 0.5) - if hasattr(violin_widths, "__iter__"): - violin_widths = violin_widths[0] if len(violin_widths) > 0 else 0.5 - # Boxplot width = 20% of violin width - boxplot_widths = violin_widths * 0.2 - - # Draw boxplot overlay with styling - line_width = mm_to_pt(DEFAULT_LINE_WIDTH_MM) - marker_size = mm_to_pt(DEFAULT_MARKER_SIZE_MM) - - # Call matplotlib's boxplot directly to avoid recursive post-processing - # which would override our gray styling with the default blue - if hasattr(ax, "_axes_mpl"): - mpl_ax = ax._axes_mpl - else: - mpl_ax = ax - bp = mpl_ax.boxplot( - data, - positions=list(positions), - widths=boxplot_widths, - patch_artist=True, - manage_ticks=False, # Don't modify existing ticks - ) - - # Style the boxplot: scitex gray fill with black edges for visibility - # Set high z-order so boxplot appears on top of violin bodies - boxplot_zorder = 10 - for box in bp.get("boxes", []): - box.set_facecolor(HEX["gray"]) # Scitex gray fill - box.set_edgecolor("black") - box.set_alpha(1.0) - box.set_linewidth(line_width) - box.set_zorder(boxplot_zorder) - for median in bp.get("medians", []): - median.set_color("black") # Black median line - median.set_linewidth(line_width) # 0.2mm thickness - median.set_zorder(boxplot_zorder + 1) - for whisker in bp.get("whiskers", []): - whisker.set_color("black") - whisker.set_linewidth(line_width) - whisker.set_zorder(boxplot_zorder) - for cap in bp.get("caps", []): - cap.set_color("black") - cap.set_linewidth(line_width) - cap.set_zorder(boxplot_zorder) - for flier in bp.get("fliers", []): - flier.set_markerfacecolor("none") # No fill (open circles) - flier.set_markeredgecolor("black") - flier.set_markersize(marker_size) # 0.8mm - flier.set_markeredgewidth(line_width) # 0.2mm - flier.set_zorder(boxplot_zorder + 2) - except Exception: - pass # Silently continue if boxplot overlay fails - - -def _postprocess_boxplot(result, ax): - """Apply styling for boxplots (standalone, not violin overlay).""" - # Use the centralized style_boxplot function for consistent styling - from scitex.plt.ax import style_boxplot - - style_boxplot(result) - - # Cap width: 33% of box width - if "caps" in result and "boxes" in result and len(result["boxes"]) > 0: - try: - cap_width_pts = calculate_cap_width_from_box(result["boxes"][0], ax) - for cap in result["caps"]: - cap.set_markersize(cap_width_pts) - except Exception: - pass - - -def _postprocess_scatter(result, kwargs): - """Apply styling for scatter plots.""" - # Apply default 0.8mm marker size if 's' not specified - if "s" not in kwargs: - size_pt = mm_to_pt(DEFAULT_MARKER_SIZE_MM) - marker_area = size_pt**2 - result.set_sizes([marker_area]) - - -def _postprocess_hist(result, ax): - """Apply styling for histogram plots. - - Ensures histogram bars have proper edge color and alpha for visibility. - Delegates edge styling to figrecipe when available. - """ - # Delegate edge styling to figrecipe with fallback - from scitex.plt.styles._postprocess_helpers import apply_hist_edge_style - - apply_hist_edge_style(ax, DEFAULT_LINE_WIDTH_MM) - - # Additionally ensure alpha is at least 0.7 for visibility - if len(result) >= 3: - patches = result[2] - if hasattr(patches, "__iter__"): - for patch_group in patches: - if hasattr(patch_group, "__iter__"): - for patch in patch_group: - if patch.get_alpha() is None or patch.get_alpha() < 0.7: - patch.set_alpha(1.0) - else: - if patch_group.get_alpha() is None or patch_group.get_alpha() < 0.7: - patch_group.set_alpha(1.0) - - -def _postprocess_fill_between(result, kwargs): - """Apply styling for fill_between plots. - - Ensures shaded regions have proper alpha for visibility. - """ - # result is a PolyCollection - if result is not None: - # Only set edge if not already specified - if "edgecolor" not in kwargs and "ec" not in kwargs: - result.set_edgecolor("none") - - # Ensure alpha is reasonable (default 0.3 is common for fill_between) - if "alpha" not in kwargs: - result.set_alpha(0.3) - - -def _postprocess_bar(result, ax, kwargs): # noqa: C901 - """Apply styling for bar plots with colors and error bars.""" - # Apply scitex palette only if color not explicitly set - if "color" not in kwargs and "c" not in kwargs: - from scitex.plt.color import HEX - - palette = [ - HEX["blue"], - HEX["red"], - HEX["green"], - HEX["yellow"], - HEX["purple"], - HEX["orange"], - HEX["lightblue"], - HEX["pink"], - ] - - for i, patch in enumerate(result.patches): - patch.set_facecolor(palette[i % len(palette)]) - - # Always apply SCITEX edge styling (black, 0.2mm) - delegate to figrecipe - from scitex.plt.styles._postprocess_helpers import apply_bar_edge_style - - apply_bar_edge_style(ax, DEFAULT_LINE_WIDTH_MM) - - if "yerr" not in kwargs or kwargs["yerr"] is None: - return - - try: - errorbar = result.errorbar - if errorbar is None: - return - - lines = errorbar.lines - if not lines or len(lines) < 3: - return - - caplines = lines[1] - if caplines and len(caplines) >= 2: - # Hide lower caps (one-sided error bars) - caplines[0].set_visible(False) - - # Adjust cap width to 33% of bar width - if len(result.patches) > 0: - cap_width_pts = calculate_cap_width_from_bar( - result.patches[0], ax, "width" - ) - for cap in caplines[1:]: - cap.set_markersize(cap_width_pts) - - # Make error bar lines one-sided - barlinecols = lines[2] - make_errorbar_one_sided(barlinecols, "vertical") - except Exception: - pass - - -def _postprocess_barh(result, ax, kwargs): # noqa: C901 - """Apply styling for horizontal bar plots with colors and error bars.""" - # Apply scitex palette only if color not explicitly set - if "color" not in kwargs and "c" not in kwargs: - from scitex.plt.color import HEX - - palette = [ - HEX["blue"], - HEX["red"], - HEX["green"], - HEX["yellow"], - HEX["purple"], - HEX["orange"], - HEX["lightblue"], - HEX["pink"], - ] - - for i, patch in enumerate(result.patches): - patch.set_facecolor(palette[i % len(palette)]) - - # Always apply SCITEX edge styling (black, 0.2mm) - delegate to figrecipe - from scitex.plt.styles._postprocess_helpers import apply_bar_edge_style - - apply_bar_edge_style(ax, DEFAULT_LINE_WIDTH_MM) - - if "xerr" not in kwargs or kwargs["xerr"] is None: - return - - try: - errorbar = result.errorbar - if errorbar is None: - return - - lines = errorbar.lines - if not lines or len(lines) < 3: - return - - caplines = lines[1] - if caplines and len(caplines) >= 2: - # Hide left caps (one-sided error bars) - caplines[0].set_visible(False) - - # Adjust cap width to 33% of bar height - if len(result.patches) > 0: - cap_width_pts = calculate_cap_width_from_bar( - result.patches[0], ax, "height" - ) - for cap in caplines[1:]: - cap.set_markersize(cap_width_pts) - - # Make error bar lines one-sided - barlinecols = lines[2] - make_errorbar_one_sided(barlinecols, "horizontal") - except Exception: - pass - - -# EOF diff --git a/src/scitex/plt/styles/_postprocess_helpers.py b/src/scitex/plt/styles/_postprocess_helpers.py deleted file mode 100755 index 27fa5aa1e..000000000 --- a/src/scitex/plt/styles/_postprocess_helpers.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2026-01-13 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_postprocess_helpers.py - -"""Helper functions for plot post-processing. - -Extracted from _plot_postprocess.py to keep modules within line limits. -Delegates to figrecipe styling functions when available. -""" - -import numpy as np - -# Try to import figrecipe styling (delegate when available) -try: - from figrecipe.styles._plot_styles import ( - apply_barplot_style as _fr_apply_barplot_style, - ) - from figrecipe.styles._plot_styles import ( - apply_histogram_style as _fr_apply_histogram_style, - ) - - FIGRECIPE_AVAILABLE = True -except ImportError: - FIGRECIPE_AVAILABLE = False - -# ============================================================================ -# Constants -# ============================================================================ -CAP_WIDTH_RATIO = 1 / 3 # 33% of bar/box width - - -# ============================================================================ -# Helper functions -# ============================================================================ -def calculate_cap_width_from_box(box, ax): - """Calculate cap width as 33% of box width in points.""" - # Get box width from path - if hasattr(box, "get_path"): - path = box.get_path() - vertices = path.vertices - x_coords = vertices[:, 0] - box_width_data = x_coords.max() - x_coords.min() - elif hasattr(box, "get_xdata"): - x_data = box.get_xdata() - box_width_data = max(x_data) - min(x_data) - else: - box_width_data = 0.5 # Default - - return data_width_to_points(box_width_data, ax, "x") * CAP_WIDTH_RATIO - - -def calculate_cap_width_from_bar(patch, ax, dimension): - """Calculate cap width as 33% of bar width/height in points.""" - if dimension == "width": - bar_size = patch.get_width() - return data_width_to_points(bar_size, ax, "x") * CAP_WIDTH_RATIO - else: # height - bar_size = patch.get_height() - return data_width_to_points(bar_size, ax, "y") * CAP_WIDTH_RATIO - - -def data_width_to_points(data_size, ax, axis="x"): - """Convert a data-space size to points.""" - fig = ax.get_figure() - bbox = ax.get_position() - - if axis == "x": - ax_size_inches = bbox.width * fig.get_figwidth() - lim = ax.get_xlim() - else: - ax_size_inches = bbox.height * fig.get_figheight() - lim = ax.get_ylim() - - data_range = lim[1] - lim[0] - size_inches = (data_size / data_range) * ax_size_inches - return size_inches * 72 # 72 points per inch - - -def make_errorbar_one_sided(barlinecols, direction): - """Make error bar line segments one-sided (outward only).""" - if not barlinecols or len(barlinecols) == 0: - return - - for lc in barlinecols: - if not hasattr(lc, "get_segments"): - continue - - segs = lc.get_segments() - new_segs = [] - for seg in segs: - if len(seg) < 2: - continue - - if direction == "vertical": - # Keep upper half - bottom_y = min(seg[0][1], seg[1][1]) - top_y = max(seg[0][1], seg[1][1]) - mid_y = (bottom_y + top_y) / 2 - new_seg = np.array([[seg[0][0], mid_y], [seg[0][0], top_y]]) - else: # horizontal - # Keep right half - left_x = min(seg[0][0], seg[1][0]) - right_x = max(seg[0][0], seg[1][0]) - mid_x = (left_x + right_x) / 2 - new_seg = np.array([[mid_x, seg[0][1]], [right_x, seg[0][1]]]) - - new_segs.append(new_seg) - - if new_segs: - lc.set_segments(new_segs) - - -def apply_bar_edge_style(ax, line_width_mm): - """Apply bar edge styling, delegating to figrecipe if available. - - Parameters - ---------- - ax : matplotlib Axes or AxisWrapper - The axes containing bar patches. - line_width_mm : float - Line width in millimeters. - """ - from scitex.plt.utils import mm_to_pt - - ax_mpl = getattr(ax, "_axis_mpl", ax) - - if FIGRECIPE_AVAILABLE: - _fr_apply_barplot_style(ax_mpl, {"barplot_edge_mm": line_width_mm}) - else: - # Fallback: apply edge styling directly - from matplotlib.patches import Rectangle - - line_width_pt = mm_to_pt(line_width_mm) - for patch in ax_mpl.patches: - if isinstance(patch, Rectangle): - patch.set_edgecolor("black") - patch.set_linewidth(line_width_pt) - - -def apply_hist_edge_style(ax, line_width_mm): - """Apply histogram edge styling, delegating to figrecipe if available.""" - from scitex.plt.utils import mm_to_pt - - ax_mpl = getattr(ax, "_axis_mpl", ax) - - if FIGRECIPE_AVAILABLE: - _fr_apply_histogram_style(ax_mpl, {"histogram_edge_mm": line_width_mm}) - else: - from matplotlib.patches import Rectangle - - line_width_pt = mm_to_pt(line_width_mm) - for patch in ax_mpl.patches: - if isinstance(patch, Rectangle): - patch.set_edgecolor("black") - patch.set_linewidth(line_width_pt) - - -# EOF diff --git a/src/scitex/scholar/citation_graph/visualization.py b/src/scitex/scholar/citation_graph/visualization.py index 04d8ae98b..9a50f1078 100755 --- a/src/scitex/scholar/citation_graph/visualization.py +++ b/src/scitex/scholar/citation_graph/visualization.py @@ -17,8 +17,8 @@ # ── Backend availability flags ─────────────────────────────────────────────── try: - from figrecipe._graph import draw_graph as _fr_draw_graph - from figrecipe._graph._presets import get_preset as _fr_get_preset + from figrecipe import get_graph_preset as _fr_get_preset + from figrecipe._graph import draw_graph as _fr_draw_graph # not yet in public API _FIGRECIPE_AVAILABLE = True except ImportError: diff --git a/src/scitex/stats/_figrecipe_integration.py b/src/scitex/stats/_figrecipe_integration.py index 317fa770b..ca6212b6f 100755 --- a/src/scitex/stats/_figrecipe_integration.py +++ b/src/scitex/stats/_figrecipe_integration.py @@ -6,13 +6,11 @@ from typing import Any, Dict, List, Optional, Union try: - from figrecipe._integrations._scitex_stats import ( - annotate_from_stats as _fr_annotate, - ) - from figrecipe._integrations._scitex_stats import from_scitex_stats as _fr_convert - from figrecipe._integrations._scitex_stats import ( + from figrecipe._integrations._scitex_stats import ( # not yet in public API load_stats_bundle as _fr_load_bundle, ) + from figrecipe.utils import annotate_from_stats as _fr_annotate + from figrecipe.utils import from_scitex_stats as _fr_convert _AVAILABLE = True except ImportError: diff --git a/tests/custom/test_axes_wrapper_flat_property.py b/tests/custom/test_axes_wrapper_flat_property.py deleted file mode 100644 index 51cdb27ba..000000000 --- a/tests/custom/test_axes_wrapper_flat_property.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python3 -"""Test the flat property of AxesWrapper""" - -import pytest - -pytest.importorskip("zarr") -import matplotlib.pyplot as plt -import numpy as np - -from scitex.plt._subplots._AxesWrapper import AxesWrapper -from scitex.plt._subplots._FigWrapper import FigWrapper - - -class TestAxesWrapperFlat: - """Test suite for AxesWrapper.flat property""" - - def test_flat_returns_iterator(self): - """Test that axes.flat returns a proper iterator like numpy arrays""" - fig, axes_array = plt.subplots(2, 3) - fig_wrapped = FigWrapper(fig) - axes = AxesWrapper(fig_wrapped, axes_array) - - # Check that flat returns a numpy flatiter - assert hasattr(axes, "flat") - flat_result = axes.flat - assert type(flat_result).__name__ == "flatiter" - - # Verify we can iterate over it - flat_list = list(flat_result) - assert len(flat_list) == 6 # 2x3 = 6 axes - - # Verify all items are matplotlib axes - for ax in flat_list: - assert hasattr(ax, "plot") - assert hasattr(ax, "set_xlabel") - - def test_flat_matches_numpy_behavior(self): - """Test that axes.flat behaves like numpy array flat""" - fig, axes_array = plt.subplots(3, 4) - fig_wrapped = FigWrapper(fig) - axes = AxesWrapper(fig_wrapped, axes_array) - - # Compare with numpy behavior - axes_flat = list(axes.flat) - numpy_flat = list(axes_array.flat) - - assert len(axes_flat) == len(numpy_flat) - assert len(axes_flat) == 12 # 3x4 = 12 - - # Verify order is the same (row-major) - for i, (ax1, ax2) in enumerate(zip(axes_flat, numpy_flat)): - assert ax1 is ax2 - - def test_flat_not_list_of_lists(self): - """Regression test: ensure flat doesn't return list of lists""" - fig, axes_array = plt.subplots(2, 2) - fig_wrapped = FigWrapper(fig) - axes = AxesWrapper(fig_wrapped, axes_array) - - flat_result = axes.flat - flat_list = list(flat_result) - - # Should NOT be a list of lists - assert not isinstance(flat_list[0], list) - # Should be matplotlib axes - assert hasattr(flat_list[0], "plot") - - def test_flat_with_single_axis(self): - """Test flat property with a single axis""" - fig, ax = plt.subplots() - # Need to make it an array to match expected structure - axes_array = np.array([[ax]]) - fig_wrapped = FigWrapper(fig) - axes = AxesWrapper(fig_wrapped, axes_array) - - flat_list = list(axes.flat) - assert len(flat_list) == 1 - assert flat_list[0] is ax - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/custom/test_imports.py b/tests/custom/test_imports.py index 334d9bf54..4e960b758 100755 --- a/tests/custom/test_imports.py +++ b/tests/custom/test_imports.py @@ -122,25 +122,9 @@ def test_torch_not_imported_at_module_level(self): # torch should still not be in sys.modules # (it might be if used elsewhere, but not from the types module) - assert ( - "torch" not in sys.modules or torch_modules - ), "torch should not be imported at types module level" - - def test_joypy_not_imported_at_module_level(self): - """Test that joypy is not imported when loading scitex.plt.ax.""" - # Clear joypy from sys.modules if it exists - joypy_modules = [m for m in sys.modules if m.startswith("joypy")] - for module in joypy_modules: - del sys.modules[module] - - # Import scitex.plt.ax - from scitex.plt import ax # noqa: F401 - - # joypy should not be in sys.modules after importing the module - joypy_in_modules = any(m.startswith("joypy") for m in sys.modules) - assert ( - not joypy_in_modules - ), "joypy should not be imported at plt.ax module level (should be lazy)" + assert "torch" not in sys.modules or torch_modules, ( + "torch should not be imported at types module level" + ) class TestIsArrayLike: @@ -191,22 +175,6 @@ def test_is_array_like_with_scalar(self): assert is_array_like("string") is False -class TestPlotJoyplotLazyImport: - """Test that stx_joyplot function works with lazy joypy import.""" - - def test_stx_joyplot_import(self): - """Test that stx_joyplot can be imported.""" - from scitex.plt.ax._plot import stx_joyplot - - assert callable(stx_joyplot) - - def test_stx_joyplot_function_callable(self): - """Test that stx_joyplot is callable.""" - from scitex.plt.ax._plot._stx_joyplot import stx_joyplot - - assert callable(stx_joyplot) - - # ============================================================================== # Comprehensive Import Tests # ============================================================================== diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__labels.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__labels.py deleted file mode 100644 index fd818beb3..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__labels.py +++ /dev/null @@ -1,280 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_labels.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _labels.py - Label rotation and legend handling -# -# """Mixin for label rotation and legend positioning.""" -# -# import os -# -# from scitex import logging -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# logger = logging.getLogger(__name__) -# -# -# class LabelsMixin: -# """Mixin for label rotation and legend positioning.""" -# -# def _get_ax_module(self): -# """Lazy import ax module to avoid circular imports.""" -# from .....plt import ax as ax_module -# return ax_module -# -# def rotate_labels( -# self, -# x: float = None, -# y: float = None, -# x_ha: str = None, -# y_ha: str = None, -# x_va: str = None, -# y_va: str = None, -# auto_adjust: bool = True, -# scientific_convention: bool = True, -# tight_layout: bool = False, -# ) -> None: -# """Rotate x and y axis labels with automatic positioning. -# -# Parameters -# ---------- -# x : float or None, optional -# Rotation angle for x-axis labels in degrees. -# y : float or None, optional -# Rotation angle for y-axis labels in degrees. -# x_ha, y_ha : str or None, optional -# Horizontal alignment for x/y-axis labels. -# x_va, y_va : str or None, optional -# Vertical alignment for x/y-axis labels. -# auto_adjust : bool, optional -# Whether to automatically adjust alignment. Default is True. -# scientific_convention : bool, optional -# Whether to follow scientific conventions. Default is True. -# tight_layout : bool, optional -# Whether to apply tight_layout. Default is False. -# """ -# self._axis_mpl = self._get_ax_module().rotate_labels( -# self._axis_mpl, -# x=x, -# y=y, -# x_ha=x_ha, -# y_ha=y_ha, -# x_va=x_va, -# y_va=y_va, -# auto_adjust=auto_adjust, -# scientific_convention=scientific_convention, -# tight_layout=tight_layout, -# ) -# -# def legend( -# self, *args, loc: str = "best", check_overlap: bool = False, **kwargs -# ) -> None: -# """Places legend at specified location, with support for outside positions. -# -# Parameters -# ---------- -# *args : tuple -# Positional arguments (handles, labels) as in matplotlib -# loc : str -# Legend position. Default is "best" (matplotlib auto-placement). -# Special positions: -# - "best": Matplotlib automatic placement -# - "outer": Place outside plot area (right side) -# - "separate": Save legend as a separate figure file -# - upper/lower/center variants: e.g. "upper right out" -# check_overlap : bool -# If True, checks for overlap between legend and data. -# **kwargs : dict -# Additional keyword arguments passed to legend() -# """ -# import matplotlib.pyplot as plt -# -# if loc == "outer": -# legend = self._axis_mpl.legend( -# *args, loc="center left", bbox_to_anchor=(1.02, 0.5), **kwargs -# ) -# if hasattr(self, "_figure_wrapper") and self._figure_wrapper: -# self._figure_wrapper._fig_mpl.tight_layout() -# self._figure_wrapper._fig_mpl.subplots_adjust(right=0.85) -# return legend -# -# elif loc == "separate": -# handles, labels = self._axis_mpl.get_legend_handles_labels() -# if not handles: -# logger.warning("No legend handles found.") -# return None -# -# fig = self._axis_mpl.get_figure() -# if not hasattr(fig, "_separate_legend_params"): -# fig._separate_legend_params = [] -# -# figsize = kwargs.pop("figsize", (4, 3)) -# dpi = kwargs.pop("dpi", 150) -# frameon = kwargs.pop("frameon", True) -# fancybox = kwargs.pop("fancybox", True) -# shadow = kwargs.pop("shadow", True) -# -# axis_id = self._get_axis_id(fig) -# -# fig._separate_legend_params.append({ -# "axis": self._axis_mpl, -# "axis_id": axis_id, -# "handles": handles, -# "labels": labels, -# "figsize": figsize, -# "dpi": dpi, -# "frameon": frameon, -# "fancybox": fancybox, -# "shadow": shadow, -# "kwargs": kwargs, -# }) -# -# if self._axis_mpl.get_legend(): -# self._axis_mpl.get_legend().remove() -# -# return None -# -# outside_positions = { -# "upper right out": ("center left", (1.15, 0.85)), -# "right upper out": ("center left", (1.15, 0.85)), -# "center right out": ("center left", (1.15, 0.5)), -# "right out": ("center left", (1.15, 0.5)), -# "right": ("center left", (1.05, 0.5)), -# "lower right out": ("center left", (1.15, 0.15)), -# "right lower out": ("center left", (1.15, 0.15)), -# "upper left out": ("center right", (-0.25, 0.85)), -# "left upper out": ("center right", (-0.25, 0.85)), -# "center left out": ("center right", (-0.25, 0.5)), -# "left out": ("center right", (-0.25, 0.5)), -# "left": ("center right", (-0.15, 0.5)), -# "lower left out": ("center right", (-0.25, 0.15)), -# "left lower out": ("center right", (-0.25, 0.15)), -# "upper center out": ("lower center", (0.5, 1.25)), -# "upper out": ("lower center", (0.5, 1.25)), -# "lower center out": ("upper center", (0.5, -0.25)), -# "lower out": ("upper center", (0.5, -0.25)), -# } -# -# if loc in outside_positions: -# location, bbox = outside_positions[loc] -# legend_obj = self._axis_mpl.legend( -# *args, loc=location, bbox_to_anchor=bbox, **kwargs -# ) -# else: -# legend_obj = self._axis_mpl.legend(*args, loc=loc, **kwargs) -# -# if check_overlap and legend_obj is not None: -# self._check_legend_overlap(legend_obj) -# -# return legend_obj -# -# def _get_axis_id(self, fig): -# """Get unique axis identifier for separate legend handling.""" -# axis_id = None -# -# try: -# fig_axes = fig.get_axes() -# for idx, ax in enumerate(fig_axes): -# if ax is self._axis_mpl: -# axis_id = f"ax_{idx:02d}" -# break -# except: -# pass -# -# if axis_id is None and hasattr(self._axis_mpl, "get_subplotspec"): -# try: -# spec = self._axis_mpl.get_subplotspec() -# if spec is not None: -# gridspec = spec.get_gridspec() -# nrows, ncols = gridspec.get_geometry() -# rowspan = spec.rowspan -# colspan = spec.colspan -# row_start = rowspan.start if hasattr(rowspan, "start") else rowspan -# col_start = colspan.start if hasattr(colspan, "start") else colspan -# flat_idx = row_start * ncols + col_start -# axis_id = f"ax_{flat_idx:02d}" -# except: -# pass -# -# if axis_id is None: -# axis_id = f"ax_{len(fig._separate_legend_params):02d}" -# -# return axis_id -# -# def _check_legend_overlap(self, legend_obj): -# """Check if legend overlaps with plotted data and issue warning if needed.""" -# import warnings -# import matplotlib.transforms as transforms -# import numpy as np -# -# try: -# fig = self._axis_mpl.get_figure() -# fig.canvas.draw() -# -# legend_bbox = legend_obj.get_window_extent(fig.canvas.get_renderer()) -# inv_transform = self._axis_mpl.transData.inverted() -# legend_bbox_data = legend_bbox.transformed(inv_transform) -# -# data_bboxes = [] -# -# for line in self._axis_mpl.get_lines(): -# if line.get_visible(): -# try: -# data = line.get_xydata() -# if len(data) > 0: -# data_bboxes.append(data) -# except: -# pass -# -# for collection in self._axis_mpl.collections: -# if collection.get_visible(): -# try: -# offsets = collection.get_offsets() -# if len(offsets) > 0: -# data_bboxes.append(offsets) -# except: -# pass -# -# if data_bboxes: -# all_data = np.vstack(data_bboxes) -# -# x_overlap = (all_data[:, 0] >= legend_bbox_data.x0) & ( -# all_data[:, 0] <= legend_bbox_data.x1 -# ) -# y_overlap = (all_data[:, 1] >= legend_bbox_data.y0) & ( -# all_data[:, 1] <= legend_bbox_data.y1 -# ) -# overlap_points = np.sum(x_overlap & y_overlap) -# overlap_pct = (overlap_points / len(all_data)) * 100 -# -# if overlap_pct > 5: -# logger.warning( -# f"Legend overlaps with {overlap_pct:.1f}% of data points. " -# f"Consider using loc='outer' or loc='separate'." -# ) -# return True -# -# except Exception: -# pass -# -# return False -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_labels.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__metadata.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__metadata.py deleted file mode 100644 index f1b787368..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__metadata.py +++ /dev/null @@ -1,229 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_metadata.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _metadata.py - Axis metadata and labels -# -# """Mixin for axis labels, titles, and metadata.""" -# -# import os -# from typing import Optional -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# class MetadataMixin: -# """Mixin for setting axis labels, titles, and metadata.""" -# -# def _get_ax_module(self): -# """Lazy import ax module to avoid circular imports.""" -# from .....plt import ax as ax_module -# return ax_module -# -# def set_xyt( -# self, -# x: Optional[str] = None, -# y: Optional[str] = None, -# t: Optional[str] = None, -# format_labels: bool = True, -# ) -> None: -# """Set xlabel, ylabel, and title.""" -# self._axis_mpl = self._get_ax_module().set_xyt( -# self._axis_mpl, -# x=x, -# y=y, -# t=t, -# format_labels=format_labels, -# ) -# -# def set_xytc( -# self, -# x: Optional[str] = None, -# y: Optional[str] = None, -# t: Optional[str] = None, -# c: Optional[str] = None, -# format_labels: bool = True, -# ) -> None: -# """Set xlabel, ylabel, title, and caption for automatic saving. -# -# Parameters -# ---------- -# x : str, optional -# X-axis label -# y : str, optional -# Y-axis label -# t : str, optional -# Title -# c : str, optional -# Caption to be saved automatically with scitex.io.save() -# format_labels : bool, optional -# Whether to apply automatic formatting, by default True -# """ -# self._axis_mpl = self._get_ax_module().set_xytc( -# self._axis_mpl, -# x=x, -# y=y, -# t=t, -# c=c, -# format_labels=format_labels, -# ) -# -# if c is not False and c is not None: -# self._scitex_caption = c -# -# def set_supxyt( -# self, -# xlabel: Optional[str] = None, -# ylabel: Optional[str] = None, -# title: Optional[str] = None, -# format_labels: bool = True, -# ) -> None: -# """Set figure-level xlabel, ylabel, and title (suptitle).""" -# self._axis_mpl = self._get_ax_module().set_supxyt( -# self._axis_mpl, -# xlabel=xlabel, -# ylabel=ylabel, -# title=title, -# format_labels=format_labels, -# ) -# -# def set_supxytc( -# self, -# xlabel: Optional[str] = None, -# ylabel: Optional[str] = None, -# title: Optional[str] = None, -# caption: Optional[str] = None, -# format_labels: bool = True, -# ) -> None: -# """Set figure-level xlabel, ylabel, title, and caption. -# -# Parameters -# ---------- -# xlabel : str, optional -# Figure-level X-axis label -# ylabel : str, optional -# Figure-level Y-axis label -# title : str, optional -# Figure-level title (suptitle) -# caption : str, optional -# Figure-level caption for automatic saving -# format_labels : bool, optional -# Whether to apply automatic formatting -# """ -# self._axis_mpl = self._get_ax_module().set_supxytc( -# self._axis_mpl, -# xlabel=xlabel, -# ylabel=ylabel, -# title=title, -# caption=caption, -# format_labels=format_labels, -# ) -# -# if caption is not False and caption is not None: -# fig = self._axis_mpl.get_figure() -# fig._scitex_main_caption = caption -# -# def set_meta( -# self, -# caption=None, -# methods=None, -# stats=None, -# keywords=None, -# experimental_details=None, -# journal_style=None, -# significance=None, -# **kwargs, -# ) -> None: -# """Set comprehensive scientific metadata with YAML export capability. -# -# Parameters -# ---------- -# caption : str, optional -# Figure caption text -# methods : str, optional -# Experimental methods description -# stats : str, optional -# Statistical analysis details -# keywords : List[str], optional -# Keywords for categorization -# experimental_details : Dict[str, Any], optional -# Structured experimental parameters -# journal_style : str, optional -# Target journal style -# significance : str, optional -# Significance statement -# **kwargs : additional metadata -# """ -# self._axis_mpl = self._get_ax_module().set_meta( -# self._axis_mpl, -# caption=caption, -# methods=methods, -# stats=stats, -# keywords=keywords, -# experimental_details=experimental_details, -# journal_style=journal_style, -# significance=significance, -# **kwargs, -# ) -# -# def set_figure_meta( -# self, -# caption=None, -# methods=None, -# stats=None, -# significance=None, -# funding=None, -# conflicts=None, -# data_availability=None, -# **kwargs, -# ) -> None: -# """Set figure-level metadata for multi-panel figures. -# -# Parameters -# ---------- -# caption : str, optional -# Figure-level caption -# methods : str, optional -# Overall experimental methods -# stats : str, optional -# Overall statistical approach -# significance : str, optional -# Significance and implications -# funding : str, optional -# Funding acknowledgments -# conflicts : str, optional -# Conflict of interest statement -# data_availability : str, optional -# Data availability statement -# **kwargs : additional metadata -# """ -# self._axis_mpl = self._get_ax_module().set_figure_meta( -# self._axis_mpl, -# caption=caption, -# methods=methods, -# stats=stats, -# significance=significance, -# funding=funding, -# conflicts=conflicts, -# data_availability=data_availability, -# **kwargs, -# ) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_metadata.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__visual.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__visual.py deleted file mode 100644 index 072066869..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/test__visual.py +++ /dev/null @@ -1,144 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_visual.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _visual.py - Visual adjustments (ticks, spines, position) -# -# """Mixin for visual adjustments including ticks, spines, and positioning.""" -# -# import os -# from typing import List, Optional, Union -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# class VisualAdjustmentMixin: -# """Mixin for visual adjustments to axis appearance.""" -# -# def _get_ax_module(self): -# """Lazy import ax module to avoid circular imports.""" -# from .....plt import ax as ax_module -# return ax_module -# -# def set_ticks( -# self, -# xvals: Optional[List[Union[int, float]]] = None, -# xticks: Optional[List[str]] = None, -# yvals: Optional[List[Union[int, float]]] = None, -# yticks: Optional[List[str]] = None, -# ) -> None: -# """Set custom tick positions and labels. -# -# Parameters -# ---------- -# xvals : list of numbers, optional -# Positions for x-axis ticks -# xticks : list of str, optional -# Labels for x-axis ticks -# yvals : list of numbers, optional -# Positions for y-axis ticks -# yticks : list of str, optional -# Labels for y-axis ticks -# """ -# self._axis_mpl = self._get_ax_module().set_ticks( -# self._axis_mpl, -# xvals=xvals, -# xticks=xticks, -# yvals=yvals, -# yticks=yticks, -# ) -# -# def set_n_ticks(self, n_xticks: int = 4, n_yticks: int = 4) -> None: -# """Set the number of ticks on each axis. -# -# Parameters -# ---------- -# n_xticks : int, optional -# Number of ticks on x-axis, by default 4 -# n_yticks : int, optional -# Number of ticks on y-axis, by default 4 -# """ -# self._axis_mpl = self._get_ax_module().set_n_ticks( -# self._axis_mpl, n_xticks=n_xticks, n_yticks=n_yticks -# ) -# -# def hide_spines( -# self, -# top: bool = True, -# bottom: bool = False, -# left: bool = False, -# right: bool = True, -# ticks: bool = False, -# labels: bool = False, -# ) -> None: -# """Hide specific spines and optionally ticks/labels. -# -# Parameters -# ---------- -# top : bool, optional -# Hide top spine, by default True -# bottom : bool, optional -# Hide bottom spine, by default False -# left : bool, optional -# Hide left spine, by default False -# right : bool, optional -# Hide right spine, by default True -# ticks : bool, optional -# Hide all ticks, by default False -# labels : bool, optional -# Hide all tick labels, by default False -# """ -# self._axis_mpl = self._get_ax_module().hide_spines( -# self._axis_mpl, -# top=top, -# bottom=bottom, -# left=left, -# right=right, -# ticks=ticks, -# labels=labels, -# ) -# -# def extend(self, x_ratio: float = 1.0, y_ratio: float = 1.0) -> None: -# """Extend axis limits by a ratio. -# -# Parameters -# ---------- -# x_ratio : float, optional -# Ratio to extend x-axis by, by default 1.0 -# y_ratio : float, optional -# Ratio to extend y-axis by, by default 1.0 -# """ -# self._axis_mpl = self._get_ax_module().extend( -# self._axis_mpl, x_ratio=x_ratio, y_ratio=y_ratio -# ) -# -# def shift(self, dx: float = 0, dy: float = 0) -> None: -# """Shift axis position. -# -# Parameters -# ---------- -# dx : float, optional -# Horizontal shift, by default 0 -# dy : float, optional -# Vertical shift, by default 0 -# """ -# self._axis_mpl = self._get_ax_module().shift(self._axis_mpl, dx=dx, dy=dy) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_AdjustmentMixin/_visual.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__base.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__base.py deleted file mode 100644 index 42c29b8e2..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__base.py +++ /dev/null @@ -1,50 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_base.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _base.py - Core helper methods for MatplotlibPlotMixin -# -# """Base mixin with core helper methods for plotting.""" -# -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# class PlotBaseMixin: -# """Base mixin with core helper methods for plotting.""" -# -# def _get_ax_module(self): -# """Lazy import ax module to avoid circular imports.""" -# from .....plt import ax as ax_module -# return ax_module -# -# def _apply_scitex_postprocess( -# self, method_name, result=None, kwargs=None, args=None -# ): -# """Apply scitex post-processing styling after plotting. -# -# This ensures all scitex wrapper methods get the same styling -# as matplotlib methods going through __getattr__ (tick locator, spines, etc.). -# """ -# from scitex.plt.styles import apply_plot_postprocess -# apply_plot_postprocess(method_name, result, self._axis_mpl, kwargs or {}, args) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_base.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__scientific.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__scientific.py deleted file mode 100644 index e36fe7d38..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__scientific.py +++ /dev/null @@ -1,609 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_scientific.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _scientific.py - Scientific/specialized plot methods -# -# """Scientific and domain-specific plotting methods.""" -# -# import os -# from typing import Any, Dict, List, Optional, Tuple -# -# import matplotlib -# import numpy as np -# import pandas as pd -# from scipy.stats import gaussian_kde -# -# from scitex.pd import to_xyz -# from scitex.types import ArrayLike -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# class ScientificPlotMixin: -# """Mixin for scientific and domain-specific plotting methods. -# -# Provides specialized visualizations for: -# - Image display with colorbars -# - Kernel density estimation -# - Confusion matrices -# - Raster plots (spike trains) -# - ECDF plots -# - Joint distributions (scatter + marginal histograms) -# - Heatmaps with annotations -# """ -# -# def stx_image( -# self, -# data: ArrayLike, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> "Axes": -# """Display a 2D array as an image with SciTeX styling. -# -# Parameters -# ---------- -# data : array-like -# 2D array to display as an image. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the image function. -# Common options: cmap, vmin, vmax, aspect, colorbar. -# -# Returns -# ------- -# Axes -# The axes with the image displayed. -# -# See Also -# -------- -# stx_imshow : Lower-level image display. -# stx_heatmap : Annotated heatmap. -# sns_heatmap : DataFrame-based heatmap. -# -# Examples -# -------- -# >>> ax.stx_image(matrix, cmap='viridis', colorbar=True) -# """ -# method_name = "stx_image" -# -# with self._no_tracking(): -# self._axis_mpl = self._get_ax_module().stx_image( -# self._axis_mpl, data, **kwargs -# ) -# -# tracked_dict = {"image_df": pd.DataFrame(data)} -# if kwargs.get("xyz", False): -# tracked_dict["image_df"] = to_xyz(tracked_dict["image_df"]) -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl -# -# def stx_kde( -# self, -# data: ArrayLike, -# *, -# cumulative: bool = False, -# fill: bool = False, -# xlim: Optional[Tuple[float, float]] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> "Axes": -# """Plot a kernel density estimate of the data. -# -# Parameters -# ---------- -# data : array-like -# 1D array of values for density estimation. -# cumulative : bool, default False -# If True, plot cumulative distribution instead of density. -# fill : bool, default False -# If True, fill the area under the curve. -# xlim : tuple of float, optional -# Range for the x-axis. If None, uses data range. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the plot function. -# Common options: color, linewidth, linestyle, label. -# -# Returns -# ------- -# Axes -# The axes with the KDE plot. -# -# See Also -# -------- -# sns_kdeplot : DataFrame-based KDE plot. -# stx_ecdf : Empirical cumulative distribution function. -# hist : Histogram alternative. -# -# Examples -# -------- -# >>> ax.stx_kde(samples, fill=True, alpha=0.3) -# >>> ax.stx_kde(data, cumulative=True, label='CDF') -# """ -# method_name = "stx_kde" -# -# n_samples = (~np.isnan(data)).sum() -# if kwargs.get("label"): -# kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" -# -# if xlim is None: -# xlim = (np.nanmin(data), np.nanmax(data)) -# -# xx = np.linspace(xlim[0], xlim[1], int(1e3)) -# density = gaussian_kde(data)(xx) -# density /= density.sum() -# -# if cumulative: -# density = np.cumsum(density) -# -# with self._no_tracking(): -# from scitex.plt.utils import mm_to_pt -# -# if "linewidth" not in kwargs and "lw" not in kwargs: -# kwargs["linewidth"] = mm_to_pt(0.2) -# if "color" not in kwargs and "c" not in kwargs: -# kwargs["color"] = "black" -# if "linestyle" not in kwargs and "ls" not in kwargs: -# kwargs["linestyle"] = "--" -# -# if fill: -# self._axis_mpl.fill_between(xx, density, **kwargs) -# else: -# self._axis_mpl.plot(xx, density, **kwargs) -# -# tracked_dict = {"x": xx, "kde": density, "n": n_samples} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl -# -# def stx_conf_mat( -# self, -# data: ArrayLike, -# *, -# x_labels: Optional[List[str]] = None, -# y_labels: Optional[List[str]] = None, -# title: str = "Confusion Matrix", -# cmap: str = "Blues", -# cbar: bool = True, -# cbar_kw: Optional[Dict[str, Any]] = None, -# label_rotation_xy: Tuple[float, float] = (15, 15), -# x_extend_ratio: float = 1.0, -# y_extend_ratio: float = 1.0, -# calc_bacc: bool = False, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", Optional[float]]: -# """Plot a confusion matrix with optional balanced accuracy calculation. -# -# Parameters -# ---------- -# data : array-like -# 2D confusion matrix array. -# x_labels : list of str, optional -# Labels for x-axis (predicted classes). -# y_labels : list of str, optional -# Labels for y-axis (true classes). -# title : str, default 'Confusion Matrix' -# Title for the plot. -# cmap : str, default 'Blues' -# Colormap for the heatmap. -# cbar : bool, default True -# Whether to show the colorbar. -# cbar_kw : dict, optional -# Additional keyword arguments for the colorbar. -# label_rotation_xy : tuple of float, default (15, 15) -# Rotation angles for (x, y) axis labels. -# x_extend_ratio : float, default 1.0 -# Ratio to extend x-axis limits. -# y_extend_ratio : float, default 1.0 -# Ratio to extend y-axis limits. -# calc_bacc : bool, default False -# Whether to calculate and return balanced accuracy. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the heatmap function. -# -# Returns -# ------- -# tuple -# (Axes, balanced_accuracy) - balanced_accuracy is None if calc_bacc=False. -# -# Examples -# -------- -# >>> ax.stx_conf_mat(cm, x_labels=['A', 'B'], y_labels=['A', 'B']) -# >>> ax, bacc = ax.stx_conf_mat(cm, calc_bacc=True) -# """ -# method_name = "stx_conf_mat" -# -# if cbar_kw is None: -# cbar_kw = {} -# -# with self._no_tracking(): -# self._axis_mpl, bacc_val = self._get_ax_module().stx_conf_mat( -# self._axis_mpl, -# data, -# x_labels=x_labels, -# y_labels=y_labels, -# title=title, -# cmap=cmap, -# cbar=cbar, -# cbar_kw=cbar_kw, -# label_rotation_xy=label_rotation_xy, -# x_extend_ratio=x_extend_ratio, -# y_extend_ratio=y_extend_ratio, -# calc_bacc=calc_bacc, -# **kwargs, -# ) -# -# tracked_dict = { -# "args": [data], -# "balanced_accuracy": bacc_val, -# "x_labels": x_labels, -# "y_labels": y_labels, -# } -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, bacc_val -# -# def stx_raster( -# self, -# spike_times: List[ArrayLike], -# *, -# time: Optional[ArrayLike] = None, -# labels: Optional[List[str]] = None, -# colors: Optional[List[str]] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", pd.DataFrame]: -# """Plot a raster plot (spike train visualization). -# -# Parameters -# ---------- -# spike_times : list of array-like -# List of arrays, each containing spike times for one unit/neuron. -# time : array-like, optional -# Time axis reference. If None, uses spike time range. -# labels : list of str, optional -# Labels for each unit/row. -# colors : list of str, optional -# Colors for each unit/row. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the raster function. -# -# Returns -# ------- -# tuple -# (Axes, DataFrame) - The axes and digitized raster data. -# -# Examples -# -------- -# >>> ax.stx_raster([spikes_unit1, spikes_unit2], labels=['Unit 1', 'Unit 2']) -# """ -# method_name = "stx_raster" -# -# with self._no_tracking(): -# self._axis_mpl, raster_digit_df = self._get_ax_module().stx_raster( -# self._axis_mpl, spike_times, time=time -# ) -# -# tracked_dict = {"raster_digit_df": raster_digit_df} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, raster_digit_df -# -# def stx_ecdf( -# self, -# data: ArrayLike, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", pd.DataFrame]: -# """Plot an empirical cumulative distribution function (ECDF). -# -# Parameters -# ---------- -# data : array-like -# 1D array of values. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the ECDF function. -# Common options: color, linewidth, label. -# -# Returns -# ------- -# tuple -# (Axes, DataFrame) - The axes and ECDF data (x, y columns). -# -# See Also -# -------- -# stx_kde : Kernel density estimate (continuous). -# hist : Histogram (discrete bins). -# -# Examples -# -------- -# >>> ax.stx_ecdf(samples, label='Distribution A') -# """ -# method_name = "stx_ecdf" -# -# with self._no_tracking(): -# self._axis_mpl, ecdf_df = self._get_ax_module().stx_ecdf( -# self._axis_mpl, data, **kwargs -# ) -# -# tracked_dict = {"ecdf_df": ecdf_df} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, ecdf_df -# -# def stx_joyplot( -# self, -# data: ArrayLike, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> "Axes": -# """Plot a joyplot (ridgeline plot) for distribution comparison. -# -# Parameters -# ---------- -# data : array-like -# 2D array where each row is a distribution to plot. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the joyplot function. -# -# Returns -# ------- -# Axes -# The axes with the joyplot. -# -# Examples -# -------- -# >>> ax.stx_joyplot(distributions_2d, overlap=0.5) -# """ -# method_name = "stx_joyplot" -# -# with self._no_tracking(): -# self._axis_mpl = self._get_ax_module().stx_joyplot( -# self._axis_mpl, data, **kwargs -# ) -# -# tracked_dict = {"joyplot_data": data} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl -# -# def stx_scatter_hist( -# self, -# x: ArrayLike, -# y: ArrayLike, -# *, -# hist_bins: int = 20, -# scatter_alpha: float = 0.6, -# scatter_size: float = 20, -# scatter_color: str = "blue", -# hist_color_x: str = "blue", -# hist_color_y: str = "red", -# hist_alpha: float = 0.5, -# scatter_ratio: float = 0.8, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", "Axes", "Axes", Dict]: -# """Plot a scatter plot with marginal histograms. -# -# Parameters -# ---------- -# x : array-like -# X coordinates of the scatter points. -# y : array-like -# Y coordinates of the scatter points. -# hist_bins : int, default 20 -# Number of bins for the marginal histograms. -# scatter_alpha : float, default 0.6 -# Transparency of scatter points. -# scatter_size : float, default 20 -# Size of scatter points. -# scatter_color : str, default 'blue' -# Color of scatter points. -# hist_color_x : str, default 'blue' -# Color of x-marginal histogram. -# hist_color_y : str, default 'red' -# Color of y-marginal histogram. -# hist_alpha : float, default 0.5 -# Transparency of histograms. -# scatter_ratio : float, default 0.8 -# Ratio of scatter plot area to total area. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the scatter function. -# -# Returns -# ------- -# tuple -# (main_ax, hist_x_ax, hist_y_ax, hist_data) - Axes and histogram data. -# -# See Also -# -------- -# stx_scatter : Simple scatter plot. -# sns_jointplot : Seaborn joint plot. -# -# Examples -# -------- -# >>> ax, ax_hx, ax_hy, data = ax.stx_scatter_hist(x, y, hist_bins=30) -# """ -# method_name = "stx_scatter_hist" -# -# with self._no_tracking(): -# self._axis_mpl, ax_histx, ax_histy, hist_data = ( -# self._get_ax_module().stx_scatter_hist( -# self._axis_mpl, -# x, -# y, -# hist_bins=hist_bins, -# scatter_alpha=scatter_alpha, -# scatter_size=scatter_size, -# scatter_color=scatter_color, -# hist_color_x=hist_color_x, -# hist_color_y=hist_color_y, -# hist_alpha=hist_alpha, -# scatter_ratio=scatter_ratio, -# **kwargs, -# ) -# ) -# -# tracked_dict = { -# "x": x, -# "y": y, -# "hist_x": hist_data["hist_x"], -# "hist_y": hist_data["hist_y"], -# "bin_edges_x": hist_data["bin_edges_x"], -# "bin_edges_y": hist_data["bin_edges_y"], -# } -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, ax_histx, ax_histy, hist_data -# -# def stx_heatmap( -# self, -# data: ArrayLike, -# *, -# x_labels: Optional[List[str]] = None, -# y_labels: Optional[List[str]] = None, -# cmap: str = "viridis", -# cbar_label: str = "ColorBar Label", -# value_format: str = "{x:.1f}", -# show_annot: bool = True, -# annot_color_lighter: str = "white", -# annot_color_darker: str = "black", -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", matplotlib.image.AxesImage, matplotlib.colorbar.Colorbar]: -# """Plot an annotated heatmap. -# -# Parameters -# ---------- -# data : array-like -# 2D array of values to display. -# x_labels : list of str, optional -# Labels for x-axis (columns). -# y_labels : list of str, optional -# Labels for y-axis (rows). -# cmap : str, default 'viridis' -# Colormap name. -# cbar_label : str, default 'ColorBar Label' -# Label for the colorbar. -# value_format : str, default '{x:.1f}' -# Format string for cell annotations. -# show_annot : bool, default True -# Whether to show value annotations in cells. -# annot_color_lighter : str, default 'white' -# Annotation color for dark backgrounds. -# annot_color_darker : str, default 'black' -# Annotation color for light backgrounds. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the heatmap function. -# -# Returns -# ------- -# tuple -# (Axes, AxesImage, Colorbar) - The axes, image, and colorbar objects. -# -# See Also -# -------- -# sns_heatmap : DataFrame-based heatmap. -# stx_conf_mat : Confusion matrix heatmap. -# stx_image : Simple image display. -# -# Examples -# -------- -# >>> ax, im, cbar = ax.stx_heatmap(matrix, x_labels=['A', 'B'], cmap='coolwarm') -# """ -# method_name = "stx_heatmap" -# -# with self._no_tracking(): -# ax, im, cbar = self._get_ax_module().stx_heatmap( -# self._axis_mpl, -# data, -# x_labels=x_labels, -# y_labels=y_labels, -# cmap=cmap, -# cbar_label=cbar_label, -# value_format=value_format, -# show_annot=show_annot, -# annot_color_lighter=annot_color_lighter, -# annot_color_darker=annot_color_darker, -# **kwargs, -# ) -# -# tracked_dict = { -# "data": data, -# "x_labels": x_labels, -# "y_labels": y_labels, -# } -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return ax, im, cbar -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_scientific.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__statistical.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__statistical.py deleted file mode 100644 index 7cb7bad02..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__statistical.py +++ /dev/null @@ -1,670 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_statistical.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _statistical.py - Statistical plot methods -# -# """Statistical plotting methods including line plots, box plots, and violin plots.""" -# -# import os -# from typing import List, Optional, Sequence, Tuple, Union -# -# import numpy as np -# import pandas as pd -# -# from scitex.types import ArrayLike -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# class StatisticalPlotMixin: -# """Mixin for statistical plotting methods. -# -# Provides methods for: -# - Distribution plots (boxplot, violin) -# - Line plots with uncertainty (mean±std, mean±CI, median±IQR) -# - Histograms with bin alignment -# - Geometric shapes (rectangles, filled regions) -# """ -# -# def stx_rectangle( -# self, -# x: float, -# y: float, -# width: float, -# height: float, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> "Axes": -# """Draw a rectangle on the axes. -# -# Parameters -# ---------- -# x : float -# X coordinate of the lower-left corner. -# y : float -# Y coordinate of the lower-left corner. -# width : float -# Width of the rectangle. -# height : float -# Height of the rectangle. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the rectangle function. -# -# Returns -# ------- -# Axes -# The axes with the rectangle added. -# -# Examples -# -------- -# >>> ax.stx_rectangle(0, 0, 1, 2, color='blue', alpha=0.5) -# """ -# method_name = "stx_rectangle" -# -# with self._no_tracking(): -# self._axis_mpl = self._get_ax_module().stx_rectangle( -# self._axis_mpl, x, y, width, height, **kwargs -# ) -# -# tracked_dict = {"x": x, "y": y, "width": width, "height": height} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl -# -# def stx_fillv( -# self, -# starts: ArrayLike, -# ends: ArrayLike, -# *, -# color: str = "red", -# alpha: float = 0.2, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> "Axes": -# """Fill vertical spans between start and end positions. -# -# Parameters -# ---------- -# starts : array-like -# Start x-coordinates of each span. -# ends : array-like -# End x-coordinates of each span. -# color : str, default 'red' -# Fill color. -# alpha : float, default 0.2 -# Transparency level. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the fill function. -# -# Returns -# ------- -# Axes -# The axes with the filled spans added. -# -# Examples -# -------- -# >>> ax.stx_fillv([0, 2, 4], [1, 3, 5], color='green') -# """ -# method_name = "stx_fillv" -# -# with self._no_tracking(): -# self._axis_mpl = self._get_ax_module().stx_fillv( -# self._axis_mpl, starts, ends, color=color, alpha=alpha -# ) -# -# tracked_dict = {"starts": starts, "ends": ends} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl -# -# def stx_box( -# self, -# data: Union[ArrayLike, Sequence[ArrayLike]], -# *, -# colors: Optional[List[str]] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> dict: -# """Create a boxplot with SciTeX styling and tracking. -# -# Parameters -# ---------- -# data : array-like or sequence of array-like -# Data for the boxplot. Can be a single array or list of arrays -# where each array represents a group. -# colors : list of str, optional -# Colors for each box. If None, uses default palette. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.boxplot`. -# -# Returns -# ------- -# dict -# Dictionary mapping component names ('boxes', 'whiskers', etc.) -# to lists of Line2D or Patch artists. -# -# See Also -# -------- -# stx_boxplot : Alias for this method. -# sns_boxplot : DataFrame-based boxplot. -# stx_violin : Violin plot alternative. -# -# Examples -# -------- -# >>> ax.stx_box([data1, data2, data3], labels=['A', 'B', 'C']) -# >>> ax.stx_box(data, notch=True, patch_artist=True) -# """ -# method_name = "stx_box" -# -# _data = data.copy() -# -# if kwargs.get("label"): -# n_per_group = [len(g) for g in data] -# n_min, n_max = min(n_per_group), max(n_per_group) -# n_str = str(n_min) if n_min == n_max else f"{n_min}-{n_max}" -# kwargs["label"] = kwargs["label"] + f" ($n$={n_str})" -# -# if "patch_artist" not in kwargs: -# kwargs["patch_artist"] = True -# -# with self._no_tracking(): -# result = self._axis_mpl.boxplot(data, **kwargs) -# -# n_per_group = [len(g) for g in data] -# tracked_dict = {"data": _data, "n": n_per_group} -# self._track(track, id, method_name, tracked_dict, None) -# -# from scitex.plt.ax import style_boxplot -# -# style_boxplot(result, colors=colors) -# -# self._apply_scitex_postprocess(method_name, result) -# -# return result -# -# def hist( -# self, -# x: ArrayLike, -# *, -# bins: Union[int, str, ArrayLike] = 10, -# range: Optional[Tuple[float, float]] = None, -# align_bins: bool = True, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple[np.ndarray, np.ndarray, "BarContainer"]: -# """Plot a histogram with optional bin alignment across multiple histograms. -# -# Parameters -# ---------- -# x : array-like -# Input data for the histogram. -# bins : int, str, or array-like, default 10 -# Number of bins, binning strategy ('auto', 'fd', etc.), or bin edges. -# range : tuple of float, optional -# Lower and upper range of the bins. If None, uses data range. -# align_bins : bool, default True -# When True, aligns bins across multiple histograms on the same axes. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.hist`. -# -# Returns -# ------- -# tuple -# (counts, bin_edges, patches) from matplotlib hist. -# -# See Also -# -------- -# sns_histplot : DataFrame-based histogram with KDE support. -# -# Examples -# -------- -# >>> ax.hist(data, bins=20, density=True) -# >>> ax.hist(data, bins='auto', alpha=0.7, label='Group A') -# """ -# method_name = "hist" -# -# axis_id = str(hash(self._axis_mpl)) -# hist_id = id if id is not None else str(self.id) -# -# if align_bins: -# from .....plt.utils import histogram_bin_manager -# -# bins, range = histogram_bin_manager.register_histogram( -# axis_id, hist_id, x, bins, range -# ) -# -# with self._no_tracking(): -# hist_data = self._axis_mpl.hist(x, bins=bins, range=range, **kwargs) -# -# tracked_dict = { -# "args": (x,), -# "hist_result": (hist_data[0], hist_data[1]), -# "bins": bins, -# "range": range, -# } -# self._track(track, id, method_name, tracked_dict, kwargs) -# self._apply_scitex_postprocess(method_name, hist_data) -# -# return hist_data -# -# def stx_violin( -# self, -# data: Union[pd.DataFrame, List, ArrayLike], -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# hue: Optional[str] = None, -# labels: Optional[List[str]] = None, -# colors: Optional[List[str]] = None, -# half: bool = False, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> "Axes": -# """Create a violin plot with SciTeX styling and tracking. -# -# Parameters -# ---------- -# data : DataFrame, list, or array-like -# Data for the violin plot. Can be: -# - List of arrays (one per violin) -# - DataFrame with columns specified by x, y, hue -# x : str, optional -# Column name for x-axis grouping (DataFrame input). -# y : str, optional -# Column name for y-axis values (DataFrame input). -# hue : str, optional -# Column name for color grouping (DataFrame input). -# labels : list of str, optional -# Labels for each violin (list input). -# colors : list of str, optional -# Colors for each violin. -# half : bool, default False -# If True, draw half-violins (useful for paired comparisons). -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the violin function. -# -# Returns -# ------- -# Axes -# The axes with the violin plot. -# -# See Also -# -------- -# stx_violinplot : Alias for this method. -# sns_violinplot : DataFrame-based violin plot. -# stx_box : Boxplot alternative. -# -# Examples -# -------- -# >>> ax.stx_violin([data1, data2], labels=['A', 'B']) -# >>> ax.stx_violin(df, x='group', y='value', hue='category') -# """ -# method_name = "stx_violin" -# -# with self._no_tracking(): -# if isinstance(data, list) and all( -# isinstance(item, (list, np.ndarray)) for item in data -# ): -# self._axis_mpl = self._get_ax_module().stx_violin( -# self._axis_mpl, -# values_list=data, -# labels=labels, -# colors=colors, -# half=half, -# **kwargs, -# ) -# else: -# self._axis_mpl = self._get_ax_module().stx_violin( -# self._axis_mpl, -# data=data, -# x=x, -# y=y, -# hue=hue, -# half=half, -# **kwargs, -# ) -# -# tracked_dict = { -# "data": data, -# "x": x, -# "y": y, -# "hue": hue, -# "half": half, -# "labels": labels, -# "colors": colors, -# } -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl -# -# def stx_line( -# self, -# y: ArrayLike, -# *, -# x: Optional[ArrayLike] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", pd.DataFrame]: -# """Plot a simple line with SciTeX styling. -# -# Parameters -# ---------- -# y : array-like -# Y values for the line. -# x : array-like, optional -# X values for the line. If None, uses integer indices. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the line plot function. -# -# Returns -# ------- -# tuple -# (Axes, DataFrame) - The axes and a DataFrame with the plotted data. -# -# See Also -# -------- -# stx_mean_std : Line with standard deviation shading. -# stx_shaded_line : Line with custom shaded region. -# sns_lineplot : DataFrame-based line plot. -# -# Examples -# -------- -# >>> ax.stx_line(y_values) -# >>> ax.stx_line(y, x=x, label='Series A', color='blue') -# """ -# method_name = "stx_line" -# -# with self._no_tracking(): -# self._axis_mpl, plot_df = self._get_ax_module().stx_line( -# self._axis_mpl, y, xx=x, **kwargs -# ) -# -# tracked_dict = {"plot_df": plot_df} -# self._track(track, id, method_name, tracked_dict, kwargs) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, plot_df -# -# def stx_mean_std( -# self, -# data: ArrayLike, -# *, -# x: Optional[ArrayLike] = None, -# sd: float = 1.0, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", pd.DataFrame]: -# """Plot mean line with standard deviation shading. -# -# Parameters -# ---------- -# data : array-like -# 2D array where each row is an observation and columns are time points. -# x : array-like, optional -# X values. If None, uses integer indices. -# sd : float, default 1.0 -# Number of standard deviations for the shaded region. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the plot function. -# -# Returns -# ------- -# tuple -# (Axes, DataFrame) - The axes and a DataFrame with mean, upper, lower. -# -# See Also -# -------- -# stx_mean_ci : Mean with confidence interval. -# stx_median_iqr : Median with interquartile range. -# stx_shaded_line : Custom shaded line. -# -# Examples -# -------- -# >>> ax.stx_mean_std(data_2d, sd=2, label='Mean±2SD') -# """ -# method_name = "stx_mean_std" -# -# with self._no_tracking(): -# self._axis_mpl, plot_df = self._get_ax_module().stx_mean_std( -# self._axis_mpl, data, xx=x, sd=sd, **kwargs -# ) -# -# tracked_dict = {"plot_df": plot_df} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, plot_df -# -# def stx_mean_ci( -# self, -# data: ArrayLike, -# *, -# x: Optional[ArrayLike] = None, -# ci: float = 95.0, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", pd.DataFrame]: -# """Plot mean line with confidence interval shading. -# -# Parameters -# ---------- -# data : array-like -# 2D array where each row is an observation and columns are time points. -# x : array-like, optional -# X values. If None, uses integer indices. -# ci : float, default 95.0 -# Confidence interval percentage (e.g., 95 for 95% CI). -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the plot function. -# -# Returns -# ------- -# tuple -# (Axes, DataFrame) - The axes and a DataFrame with mean, upper, lower. -# -# See Also -# -------- -# stx_mean_std : Mean with standard deviation. -# stx_median_iqr : Median with interquartile range. -# -# Examples -# -------- -# >>> ax.stx_mean_ci(data_2d, ci=99, label='Mean±99%CI') -# """ -# method_name = "stx_mean_ci" -# -# with self._no_tracking(): -# self._axis_mpl, plot_df = self._get_ax_module().stx_mean_ci( -# self._axis_mpl, data, xx=x, perc=ci, **kwargs -# ) -# -# tracked_dict = {"plot_df": plot_df} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, plot_df -# -# def stx_median_iqr( -# self, -# data: ArrayLike, -# *, -# x: Optional[ArrayLike] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", pd.DataFrame]: -# """Plot median line with interquartile range shading. -# -# Parameters -# ---------- -# data : array-like -# 2D array where each row is an observation and columns are time points. -# x : array-like, optional -# X values. If None, uses integer indices. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the plot function. -# -# Returns -# ------- -# tuple -# (Axes, DataFrame) - The axes and a DataFrame with median, Q1, Q3. -# -# See Also -# -------- -# stx_mean_std : Mean with standard deviation. -# stx_mean_ci : Mean with confidence interval. -# -# Examples -# -------- -# >>> ax.stx_median_iqr(data_2d, label='Median±IQR') -# """ -# method_name = "stx_median_iqr" -# -# with self._no_tracking(): -# self._axis_mpl, plot_df = self._get_ax_module().stx_median_iqr( -# self._axis_mpl, data, xx=x, **kwargs -# ) -# -# tracked_dict = {"plot_df": plot_df} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, plot_df -# -# def stx_shaded_line( -# self, -# x: ArrayLike, -# y_lower: ArrayLike, -# y_middle: ArrayLike, -# y_upper: ArrayLike, -# *, -# color: Optional[Union[str, List[str]]] = None, -# label: Optional[Union[str, List[str]]] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> Tuple["Axes", pd.DataFrame]: -# """Plot a line with shaded area between lower and upper bounds. -# -# Parameters -# ---------- -# x : array-like -# X coordinates. -# y_lower : array-like -# Lower bound of the shaded region. -# y_middle : array-like -# Center line values. -# y_upper : array-like -# Upper bound of the shaded region. -# color : str or list of str, optional -# Color(s) for the line and shading. -# label : str or list of str, optional -# Label(s) for the legend. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the plot function. -# -# Returns -# ------- -# tuple -# (Axes, DataFrame) - The axes and a DataFrame with the plotted data. -# -# See Also -# -------- -# stx_mean_std : Mean with standard deviation. -# stx_fill_between : Simple fill between curves. -# -# Examples -# -------- -# >>> ax.stx_shaded_line(x, lower, mean, upper, color='blue', label='Result') -# """ -# method_name = "stx_shaded_line" -# -# with self._no_tracking(): -# self._axis_mpl, plot_df = self._get_ax_module().stx_shaded_line( -# self._axis_mpl, -# x, -# y_lower, -# y_middle, -# y_upper, -# color=color, -# label=label, -# **kwargs, -# ) -# -# tracked_dict = {"plot_df": plot_df} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name) -# -# return self._axis_mpl, plot_df -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_statistical.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__stx_aliases.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__stx_aliases.py deleted file mode 100644 index 8e4a9d74a..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/test__stx_aliases.py +++ /dev/null @@ -1,543 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_stx_aliases.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _stx_aliases.py - stx_ aliases for standard matplotlib methods -# -# """stx_ prefixed aliases for standard matplotlib methods with tracking support.""" -# -# import os -# from typing import List, Optional, Sequence, Union -# -# import numpy as np -# import pandas as pd -# from matplotlib.container import BarContainer -# from matplotlib.collections import PathCollection -# from matplotlib.contour import QuadContourSet -# from matplotlib.image import AxesImage -# -# from scitex.types import ArrayLike -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# class StxAliasesMixin: -# """Mixin providing stx_ aliases for standard matplotlib methods. -# -# These methods wrap standard matplotlib plotting functions with: -# - SciTeX styling applied automatically -# - Data tracking for reproducibility -# - Sample size annotations in labels -# """ -# -# def stx_bar( -# self, -# x: ArrayLike, -# height: ArrayLike, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> BarContainer: -# """Create a bar plot with SciTeX styling and tracking. -# -# Parameters -# ---------- -# x : array-like -# X coordinates of the bars. -# height : array-like -# Heights of the bars. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.bar`. -# -# Returns -# ------- -# BarContainer -# Container with all the bars. -# -# See Also -# -------- -# stx_barh : Horizontal bar plot. -# mpl_bar : Raw matplotlib bar without styling. -# -# Examples -# -------- -# >>> ax.stx_bar([1, 2, 3], [4, 5, 6]) -# >>> ax.stx_bar(x, height, label="Group A", color="blue") -# """ -# method_name = "stx_bar" -# -# if kwargs.get("label"): -# n_samples = len(x) -# kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" -# -# with self._no_tracking(): -# result = self._axis_mpl.bar(x, height, **kwargs) -# -# tracked_dict = {"bar_df": pd.DataFrame({"x": x, "height": height})} -# self._track(track, id, method_name, tracked_dict, None) -# -# from scitex.plt.ax import style_barplot -# -# style_barplot(result) -# -# self._apply_scitex_postprocess(method_name, result) -# -# return result -# -# def stx_barh( -# self, -# y: ArrayLike, -# width: ArrayLike, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> BarContainer: -# """Create a horizontal bar plot with SciTeX styling and tracking. -# -# Parameters -# ---------- -# y : array-like -# Y coordinates of the bars. -# width : array-like -# Widths of the bars. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.barh`. -# -# Returns -# ------- -# BarContainer -# Container with all the bars. -# -# See Also -# -------- -# stx_bar : Vertical bar plot. -# mpl_barh : Raw matplotlib barh without styling. -# -# Examples -# -------- -# >>> ax.stx_barh([1, 2, 3], [4, 5, 6]) -# """ -# method_name = "stx_barh" -# -# if kwargs.get("label"): -# n_samples = len(y) -# kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" -# -# with self._no_tracking(): -# result = self._axis_mpl.barh(y, width, **kwargs) -# -# tracked_dict = {"barh_df": pd.DataFrame({"y": y, "width": width})} -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name, result) -# -# return result -# -# def stx_scatter( -# self, -# x: ArrayLike, -# y: ArrayLike, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> PathCollection: -# """Create a scatter plot with SciTeX styling and tracking. -# -# Parameters -# ---------- -# x : array-like -# X coordinates of the data points. -# y : array-like -# Y coordinates of the data points. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.scatter`. -# -# Returns -# ------- -# PathCollection -# Collection of scatter points. -# -# See Also -# -------- -# sns_scatterplot : DataFrame-based scatter plot. -# mpl_scatter : Raw matplotlib scatter without styling. -# -# Examples -# -------- -# >>> ax.stx_scatter(x, y, label="Data", s=50) -# """ -# method_name = "stx_scatter" -# -# if kwargs.get("label"): -# n_samples = len(x) -# kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" -# -# with self._no_tracking(): -# result = self._axis_mpl.scatter(x, y, **kwargs) -# -# tracked_dict = {"scatter_df": pd.DataFrame({"x": x, "y": y})} -# self._track(track, id, method_name, tracked_dict, None) -# -# from scitex.plt.ax import style_scatter -# -# style_scatter(result) -# -# self._apply_scitex_postprocess(method_name, result) -# -# return result -# -# def stx_errorbar( -# self, -# x: ArrayLike, -# y: ArrayLike, -# *, -# yerr: Optional[ArrayLike] = None, -# xerr: Optional[ArrayLike] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create an error bar plot with SciTeX styling and tracking. -# -# Parameters -# ---------- -# x : array-like -# X coordinates of the data points. -# y : array-like -# Y coordinates of the data points. -# yerr : array-like, optional -# Error values for y-axis (symmetric or asymmetric). -# xerr : array-like, optional -# Error values for x-axis (symmetric or asymmetric). -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.errorbar`. -# -# Returns -# ------- -# ErrorbarContainer -# Container with the plotted errorbar lines. -# -# See Also -# -------- -# stx_mean_std : Mean line with standard deviation shading. -# stx_mean_ci : Mean line with confidence interval shading. -# -# Examples -# -------- -# >>> ax.stx_errorbar(x, y, yerr=std, fmt='o-') -# """ -# method_name = "stx_errorbar" -# -# if kwargs.get("label"): -# n_samples = len(x) -# kwargs["label"] = f"{kwargs['label']} ($n$={n_samples})" -# -# with self._no_tracking(): -# result = self._axis_mpl.errorbar(x, y, yerr=yerr, xerr=xerr, **kwargs) -# -# df_dict = {"x": x, "y": y} -# if yerr is not None: -# df_dict["yerr"] = yerr -# if xerr is not None: -# df_dict["xerr"] = xerr -# tracked_dict = {"errorbar_df": pd.DataFrame(df_dict)} -# self._track(track, id, method_name, tracked_dict, None) -# -# from scitex.plt.ax import style_errorbar -# -# style_errorbar(result) -# -# self._apply_scitex_postprocess(method_name, result) -# -# return result -# -# def stx_fill_between( -# self, -# x: ArrayLike, -# y1: ArrayLike, -# y2: Union[float, ArrayLike] = 0, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Fill the area between two curves with SciTeX styling and tracking. -# -# Parameters -# ---------- -# x : array-like -# X coordinates for the fill region. -# y1 : array-like -# First y-boundary curve. -# y2 : float or array-like, default 0 -# Second y-boundary curve. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.fill_between`. -# -# Returns -# ------- -# PolyCollection -# Collection representing the filled area. -# -# See Also -# -------- -# stx_shaded_line : Line plot with shaded confidence region. -# -# Examples -# -------- -# >>> ax.stx_fill_between(x, y_lower, y_upper, alpha=0.3) -# """ -# method_name = "stx_fill_between" -# -# with self._no_tracking(): -# result = self._axis_mpl.fill_between(x, y1, y2, **kwargs) -# -# tracked_dict = { -# "fill_between_df": pd.DataFrame( -# { -# "x": x, -# "y1": y1, -# "y2": y2 if hasattr(y2, "__len__") else [y2] * len(x), -# } -# ) -# } -# self._track(track, id, method_name, tracked_dict, None) -# self._apply_scitex_postprocess(method_name, result) -# -# return result -# -# def stx_contour( -# self, -# X: ArrayLike, -# Y: ArrayLike, -# Z: ArrayLike, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> QuadContourSet: -# """Create a contour plot with SciTeX styling and tracking. -# -# Parameters -# ---------- -# X : array-like -# X coordinates of the grid (2D array or 1D for meshgrid). -# Y : array-like -# Y coordinates of the grid (2D array or 1D for meshgrid). -# Z : array-like -# Values at each grid point (2D array). -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.contour`. -# -# Returns -# ------- -# QuadContourSet -# The contour set object. -# -# See Also -# -------- -# stx_imshow : Display data as an image. -# mpl_contour : Raw matplotlib contour without styling. -# -# Examples -# -------- -# >>> ax.stx_contour(X, Y, Z, levels=10) -# """ -# method_name = "stx_contour" -# -# with self._no_tracking(): -# result = self._axis_mpl.contour(X, Y, Z, **kwargs) -# -# tracked_dict = { -# "contour_df": pd.DataFrame( -# {"X": np.ravel(X), "Y": np.ravel(Y), "Z": np.ravel(Z)} -# ) -# } -# self._track(track, id, method_name, tracked_dict, None) -# -# self._apply_scitex_postprocess(method_name, result) -# -# return result -# -# def stx_imshow( -# self, -# data: ArrayLike, -# *, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> AxesImage: -# """Display data as an image with SciTeX styling and tracking. -# -# Parameters -# ---------- -# data : array-like -# Image data (2D or 3D array for RGB/RGBA). -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.imshow`. -# -# Returns -# ------- -# AxesImage -# The image object. -# -# See Also -# -------- -# stx_image : Scientific image display with colorbar. -# mpl_imshow : Raw matplotlib imshow without styling. -# -# Examples -# -------- -# >>> ax.stx_imshow(image_array, cmap='viridis') -# """ -# method_name = "stx_imshow" -# -# with self._no_tracking(): -# result = self._axis_mpl.imshow(data, **kwargs) -# -# if hasattr(data, "shape") and len(data.shape) == 2: -# n_rows, n_cols = data.shape -# df = pd.DataFrame(data, columns=[f"col_{i}" for i in range(n_cols)]) -# else: -# df = pd.DataFrame(data) -# tracked_dict = {"imshow_df": df} -# self._track(track, id, method_name, tracked_dict, None) -# -# self._apply_scitex_postprocess(method_name, result) -# -# return result -# -# def stx_boxplot( -# self, -# data: Union[ArrayLike, Sequence[ArrayLike]], -# *, -# colors: Optional[List[str]] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> dict: -# """Create a boxplot with SciTeX styling and tracking. -# -# This is an alias for :meth:`stx_box`. -# -# Parameters -# ---------- -# data : array-like or sequence of array-like -# Data for the boxplot. Can be a single array or list of arrays. -# colors : list of str, optional -# Colors for each box. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `matplotlib.axes.Axes.boxplot`. -# -# Returns -# ------- -# dict -# Dictionary mapping component names to artists. -# -# See Also -# -------- -# stx_box : Primary boxplot method. -# sns_boxplot : DataFrame-based boxplot. -# stx_violin : Violin plot alternative. -# -# Examples -# -------- -# >>> ax.stx_boxplot([data1, data2, data3], labels=['A', 'B', 'C']) -# """ -# return self.stx_box(data, colors=colors, track=track, id=id, **kwargs) -# -# def stx_violinplot( -# self, -# data: Union[ArrayLike, Sequence[ArrayLike]], -# *, -# colors: Optional[List[str]] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ) -> "Axes": -# """Create a violin plot with SciTeX styling and tracking. -# -# This is an alias for :meth:`stx_violin`. -# -# Parameters -# ---------- -# data : array-like or sequence of array-like -# Data for the violin plot. Can be a single array or list of arrays. -# colors : list of str, optional -# Colors for each violin. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the violin plot function. -# -# Returns -# ------- -# Axes -# The axes with the violin plot. -# -# See Also -# -------- -# stx_violin : Primary violin plot method. -# sns_violinplot : DataFrame-based violin plot. -# stx_box : Boxplot alternative. -# -# Examples -# -------- -# >>> ax.stx_violinplot([data1, data2], labels=['A', 'B']) -# """ -# return self.stx_violin(data, colors=colors, track=track, id=id, **kwargs) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_MatplotlibPlotMixin/_stx_aliases.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/test__base.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/test__base.py deleted file mode 100644 index 9cb232b16..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/test__base.py +++ /dev/null @@ -1,168 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_base.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _base.py - Base seaborn functionality -# -# """Base seaborn mixin with helper methods for tracking and data preparation.""" -# -# import os -# from functools import wraps -# -# import scitex -# import numpy as np -# import pandas as pd -# import seaborn as sns -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# def sns_copy_doc(func): -# """Decorator to copy docstring from seaborn function.""" -# @wraps(func) -# def wrapper(self, *args, **kwargs): -# return func(self, *args, **kwargs) -# -# sns_method_name = func.__name__.split("sns_")[-1] -# wrapper.__doc__ = getattr(sns, sns_method_name).__doc__ -# return wrapper -# -# -# class SeabornBaseMixin: -# """Base mixin for seaborn integration with tracking support.""" -# -# def _sns_base( -# self, method_name, *args, track=True, track_obj=None, id=None, **kwargs -# ): -# """Execute seaborn plot method with tracking support.""" -# sns_method_name = method_name.split("sns_")[-1] -# -# with self._no_tracking(): -# sns_plot_fn = getattr(sns, sns_method_name) -# -# if kwargs.get("hue_colors"): -# kwargs = scitex.gen.alternate_kwarg( -# kwargs, primary_key="palette", alternate_key="hue_colors" -# ) -# -# import warnings -# from scitex import logging -# -# mpl_logger = logging.getLogger("matplotlib") -# original_level = mpl_logger.level -# mpl_logger.setLevel(logging.WARNING) -# -# try: -# with warnings.catch_warnings(): -# warnings.filterwarnings( -# "ignore", -# message=".*categorical units.*parsable as floats or dates.*", -# category=UserWarning, -# ) -# warnings.filterwarnings( -# "ignore", -# message=".*Using categorical units.*", -# module="matplotlib.*", -# ) -# warnings.simplefilter("ignore", UserWarning) -# -# self._axis_mpl = sns_plot_fn(ax=self._axis_mpl, *args, **kwargs) -# finally: -# mpl_logger.setLevel(original_level) -# -# # Post-processing for histplot with kde=True -# if sns_method_name == "histplot" and kwargs.get("kde", False): -# from scitex.plt.utils import mm_to_pt -# kde_lw = mm_to_pt(0.2) -# for line in self._axis_mpl.get_lines(): -# line.set_linewidth(kde_lw) -# line.set_color("black") -# line.set_linestyle("--") -# -# # Post-processing for histplot alpha -# if sns_method_name == "histplot" and "alpha" not in kwargs: -# for patch in self._axis_mpl.patches: -# patch.set_alpha(1.0) -# -# track_obj = track_obj if track_obj is not None else args -# tracked_dict = { -# "data": track_obj, -# "args": args, -# } -# self._track(track, id, method_name, tracked_dict, kwargs) -# -# def _sns_base_xyhue(self, method_name, *args, track=True, id=None, **kwargs): -# """Execute seaborn plot with x/y/hue data preparation.""" -# df = kwargs.get("data") -# x, y, hue = kwargs.get("x"), kwargs.get("y"), kwargs.get("hue") -# -# track_obj = self._sns_prepare_xyhue(df, x, y, hue) if df is not None else None -# self._sns_base( -# method_name, -# *args, -# track=track, -# track_obj=track_obj, -# id=id, -# **kwargs, -# ) -# -# def _sns_prepare_xyhue(self, data=None, x=None, y=None, hue=None, **kwargs): -# """Prepare data for tracking based on x/y/hue configuration.""" -# data = data.reset_index() -# -# if hue is not None: -# if x is None and y is None: -# return data -# elif x is None: -# agg_dict = {} -# for hh in data[hue].unique(): -# agg_dict[hh] = data.loc[data[hue] == hh, y] -# df = scitex.pd.force_df(agg_dict) -# return df -# elif y is None: -# df = pd.concat( -# [data.loc[data[hue] == hh, x] for hh in data[hue].unique()], -# axis=1, -# ) -# return df -# else: -# pivoted_data = data.pivot_table( -# values=y, -# index=data.index, -# columns=[x, hue], -# aggfunc="first", -# ) -# pivoted_data.columns = [ -# f"{col[0]}-{col[1]}" for col in pivoted_data.columns -# ] -# return pivoted_data -# else: -# if x is None and y is None: -# return data -# elif x is None: -# return data[[y]] -# elif y is None: -# return data[[x]] -# else: -# return data.pivot_table( -# values=y, index=data.index, columns=x, aggfunc="first" -# ) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_base.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/test__wrappers.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/test__wrappers.py deleted file mode 100644 index 1a874f187..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/test__wrappers.py +++ /dev/null @@ -1,616 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: _wrappers.py - Seaborn plot wrappers -# -# """Seaborn plot wrappers with SciTeX integration.""" -# -# import os -# from typing import Optional, Union -# -# import numpy as np -# import pandas as pd -# import seaborn as sns -# -# from scitex.types import ArrayLike -# -# from ._base import sns_copy_doc -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# class SeabornWrappersMixin: -# """Mixin providing sns_ prefixed seaborn wrappers. -# -# All methods use the seaborn DataFrame-centric interface: -# - data: DataFrame containing the data -# - x, y: Column names for axes -# - hue: Column name for color grouping -# -# These methods integrate with SciTeX tracking and styling. -# """ -# -# def _get_ax_module(self): -# """Lazy import ax module to avoid circular imports.""" -# from .....plt import ax as ax_module -# -# return ax_module -# -# @sns_copy_doc -# def sns_barplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a bar plot showing point estimates and error bars. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis categories. -# y : str, optional -# Column name for y-axis values. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.barplot`. -# -# See Also -# -------- -# stx_bar : Array-based bar plot. -# """ -# self._sns_base_xyhue( -# "sns_barplot", data=data, x=x, y=y, track=track, id=id, **kwargs -# ) -# -# @sns_copy_doc -# def sns_boxplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# strip: bool = False, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a box plot showing distributions with quartiles. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis grouping. -# y : str, optional -# Column name for y-axis values. -# strip : bool, default False -# If True, overlay a stripplot showing individual points. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.boxplot`. -# -# See Also -# -------- -# stx_box : Array-based boxplot. -# sns_violinplot : Violin plot alternative. -# """ -# self._sns_base_xyhue( -# "sns_boxplot", data=data, x=x, y=y, track=track, id=id, **kwargs -# ) -# -# # Post-processing: Style boxplot with black medians -# from scitex.plt.utils import mm_to_pt -# -# lw_pt = mm_to_pt(0.2) -# -# for line in self._axis_mpl.get_lines(): -# line.set_linewidth(lw_pt) -# xdata = line.get_xdata() -# ydata = line.get_ydata() -# if len(xdata) == 2 and len(ydata) == 2: -# if ydata[0] == ydata[1]: -# x_span = abs(xdata[1] - xdata[0]) -# if x_span < 0.4: -# line.set_color("black") -# -# if strip: -# strip_kwargs = kwargs.copy() -# strip_kwargs.pop("notch", None) -# strip_kwargs.pop("whis", None) -# self.sns_stripplot( -# data=data, -# x=x, -# y=y, -# track=False, -# id=f"{id}_strip", -# **strip_kwargs, -# ) -# -# @sns_copy_doc -# def sns_heatmap( -# self, -# data: Union[pd.DataFrame, ArrayLike], -# *, -# xyz: bool = False, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a heatmap from rectangular data. -# -# Parameters -# ---------- -# data : DataFrame or array-like -# 2D dataset for the heatmap. -# xyz : bool, default False -# If True, convert data to XYZ format before plotting. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.heatmap`. -# -# See Also -# -------- -# stx_heatmap : Array-based annotated heatmap. -# stx_image : Simple image display. -# """ -# import scitex -# -# method_name = "sns_heatmap" -# df = data -# if xyz: -# df = scitex.pd.to_xyz(df) -# self._sns_base(method_name, df, track=track, track_obj=df, id=id, **kwargs) -# -# @sns_copy_doc -# def sns_histplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# bins: int = 10, -# align_bins: bool = True, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a histogram with optional kernel density estimate. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis values. -# y : str, optional -# Column name for y-axis values. -# bins : int, default 10 -# Number of histogram bins. -# align_bins : bool, default True -# Align bins across multiple histograms on same axes. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.histplot`. -# Common options: kde, stat, element, hue. -# -# See Also -# -------- -# hist : Array-based histogram. -# stx_kde : Kernel density estimate. -# """ -# method_name = "sns_histplot" -# -# plot_data = None -# if data is not None and x is not None: -# plot_data = ( -# data[x].values -# if hasattr(data, "columns") and x in data.columns -# else None -# ) -# -# axis_id = str(hash(self._axis_mpl)) -# hist_id = id if id is not None else str(self.id) -# range_value = kwargs.get("binrange", None) -# -# if align_bins and plot_data is not None: -# from .....plt.utils import histogram_bin_manager -# -# bins_val, range_val = histogram_bin_manager.register_histogram( -# axis_id, hist_id, plot_data, bins, range_value -# ) -# kwargs["bins"] = bins_val -# if range_value is not None: -# kwargs["binrange"] = range_val -# -# with self._no_tracking(): -# sns_plot = sns.histplot(data=data, x=x, y=y, ax=self._axis_mpl, **kwargs) -# -# hist_result = None -# if hasattr(sns_plot, "patches") and sns_plot.patches: -# patches = sns_plot.patches -# if patches: -# counts = np.array([p.get_height() for p in patches]) -# bin_edges = [] -# for p in patches: -# bin_edges.append(p.get_x()) -# if patches: -# bin_edges.append(patches[-1].get_x() + patches[-1].get_width()) -# hist_result = (counts, np.array(bin_edges)) -# -# track_obj = self._sns_prepare_xyhue(data, x, y, kwargs.get("hue")) -# tracked_dict = { -# "data": track_obj, -# "args": (data, x, y), -# "hist_result": hist_result, -# } -# self._track(track, id, method_name, tracked_dict, kwargs) -# -# return sns_plot -# -# @sns_copy_doc -# def sns_kdeplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# xlim: Optional[tuple] = None, -# ylim: Optional[tuple] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a kernel density estimate plot. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis values. -# y : str, optional -# Column name for y-axis values. -# xlim : tuple, optional -# Limits for x-axis KDE range. -# ylim : tuple, optional -# Limits for y-axis KDE range. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to the KDE function. -# -# See Also -# -------- -# stx_kde : Array-based KDE plot. -# sns_histplot : Histogram with optional KDE. -# """ -# hue_col = kwargs.pop("hue", None) -# -# if hue_col: -# hues = data[hue_col] -# if x is not None: -# lim = xlim -# for hue in np.unique(hues): -# _data = data.loc[hues == hue, x] -# self.stx_kde(_data, xlim=lim, label=hue, id=hue, **kwargs) -# if y is not None: -# lim = ylim -# for hue in np.unique(hues): -# _data = data.loc[hues == hue, y] -# self.stx_kde(_data, xlim=lim, label=hue, id=hue, **kwargs) -# else: -# if x is not None: -# _data, lim = data[x], xlim -# if y is not None: -# _data, lim = data[y], ylim -# self.stx_kde(_data, xlim=lim, **kwargs) -# -# @sns_copy_doc -# def sns_pairplot( -# self, -# *args, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a grid of pairwise relationships in a dataset. -# -# Parameters -# ---------- -# *args -# Positional arguments passed to `seaborn.pairplot`. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.pairplot`. -# """ -# self._sns_base("sns_pairplot", *args, track=track, id=id, **kwargs) -# -# @sns_copy_doc -# def sns_scatterplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a scatter plot with semantic mappings. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis values. -# y : str, optional -# Column name for y-axis values. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.scatterplot`. -# Common options: hue, size, style. -# -# See Also -# -------- -# stx_scatter : Array-based scatter plot. -# """ -# self._sns_base_xyhue( -# "sns_scatterplot", -# data=data, -# x=x, -# y=y, -# track=track, -# id=id, -# **kwargs, -# ) -# -# @sns_copy_doc -# def sns_lineplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a line plot with semantic mappings. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis values. -# y : str, optional -# Column name for y-axis values. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.lineplot`. -# Common options: hue, size, style, estimator. -# -# See Also -# -------- -# stx_line : Array-based line plot. -# stx_mean_std : Line with uncertainty shading. -# """ -# self._sns_base_xyhue( -# "sns_lineplot", -# data=data, -# x=x, -# y=y, -# track=track, -# id=id, -# **kwargs, -# ) -# -# @sns_copy_doc -# def sns_swarmplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a categorical scatter plot with non-overlapping points. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis grouping. -# y : str, optional -# Column name for y-axis values. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.swarmplot`. -# -# See Also -# -------- -# sns_stripplot : Jittered categorical scatter. -# sns_boxplot : Box plot for distributions. -# """ -# self._sns_base_xyhue( -# "sns_swarmplot", data=data, x=x, y=y, track=track, id=id, **kwargs -# ) -# -# @sns_copy_doc -# def sns_stripplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a categorical scatter plot with jittered points. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis grouping. -# y : str, optional -# Column name for y-axis values. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.stripplot`. -# -# See Also -# -------- -# sns_swarmplot : Non-overlapping categorical scatter. -# sns_boxplot : Often combined with stripplot. -# """ -# self._sns_base_xyhue( -# "sns_stripplot", data=data, x=x, y=y, track=track, id=id, **kwargs -# ) -# -# @sns_copy_doc -# def sns_violinplot( -# self, -# data: Optional[pd.DataFrame] = None, -# *, -# x: Optional[str] = None, -# y: Optional[str] = None, -# half: bool = False, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a violin plot combining box plot with kernel density. -# -# Parameters -# ---------- -# data : DataFrame, optional -# Input data structure. -# x : str, optional -# Column name for x-axis grouping. -# y : str, optional -# Column name for y-axis values. -# half : bool, default False -# If True, draw half-violins (useful for paired comparisons). -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.violinplot`. -# -# See Also -# -------- -# stx_violin : Array-based violin plot. -# sns_boxplot : Box plot alternative. -# """ -# if half: -# with self._no_tracking(): -# self._axis_mpl = self._get_ax_module().plot_half_violin( -# self._axis_mpl, data=data, x=x, y=y, **kwargs -# ) -# else: -# self._sns_base_xyhue( -# "sns_violinplot", -# data=data, -# x=x, -# y=y, -# track=track, -# id=id, -# **kwargs, -# ) -# -# track_obj = self._sns_prepare_xyhue(data, x, y, kwargs.get("hue")) -# self._track(track, id, "sns_violinplot", track_obj, kwargs) -# -# return self._axis_mpl -# -# @sns_copy_doc -# def sns_jointplot( -# self, -# *args, -# track: bool = True, -# id: Optional[str] = None, -# **kwargs, -# ): -# """Create a figure with joint and marginal distributions. -# -# Parameters -# ---------- -# *args -# Positional arguments passed to `seaborn.jointplot`. -# track : bool, default True -# Enable data tracking for reproducibility. -# id : str, optional -# Unique identifier for this plot element. -# **kwargs -# Additional arguments passed to `seaborn.jointplot`. -# -# See Also -# -------- -# stx_scatter_hist : Array-based scatter with marginal histograms. -# """ -# self._sns_base("sns_jointplot", *args, track=track, id=id, **kwargs) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__RawMatplotlibMixin.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__RawMatplotlibMixin.py deleted file mode 100644 index f54480ef8..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__RawMatplotlibMixin.py +++ /dev/null @@ -1,337 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_RawMatplotlibMixin.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_RawMatplotlibMixin.py -# -# """ -# Matplotlib aliases (mpl_xxx) for explicit matplotlib-style API. -# -# Provides consistent naming convention: -# - stx_xxx: scitex-specific methods (ArrayLike input, tracked) -# - sns_xxx: seaborn wrappers (DataFrame input, tracked) -# - mpl_xxx: matplotlib methods (matplotlib-style input, tracked) -# -# All three API layers track data for reproducibility. -# -# Usage: -# ax.stx_line(y) # ArrayLike input -# ax.sns_boxplot(data=df, x="group", y="value") # DataFrame input -# ax.mpl_plot(x, y) # matplotlib-style input -# ax.plot(x, y) # Same as mpl_plot -# """ -# -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# -# -# class RawMatplotlibMixin: -# """Mixin providing mpl_xxx aliases for matplotlib-style API. -# -# These methods are identical to calling ax.plot(), ax.scatter(), etc. -# They go through SciTeX's __getattr__ wrapper and are fully tracked. -# -# The mpl_* prefix provides: -# - Explicit naming convention (mpl_* vs stx_* vs sns_*) -# - Programmatic access via MPL_METHODS registry -# - Same tracking and styling as regular matplotlib calls -# """ -# -# # ========================================================================= -# # Helper to call through __getattr__ wrapper (enables tracking) -# # ========================================================================= -# def _mpl_call(self, method_name, *args, **kwargs): -# """Call matplotlib method through __getattr__ wrapper for tracking.""" -# # Use object.__getattribute__ to get the __getattr__ from AxisWrapper -# # Then call it with the method name to get the tracked wrapper -# wrapper_class = type(self) -# # Walk up MRO to find __getattr__ in AxisWrapper -# for cls in wrapper_class.__mro__: -# if "__getattr__" in cls.__dict__: -# return cls.__getattr__(self, method_name)(*args, **kwargs) -# # Fallback to direct call if no __getattr__ found -# return getattr(self._axes_mpl, method_name)(*args, **kwargs) -# -# # ========================================================================= -# # Line plots -# # ========================================================================= -# def mpl_plot(self, *args, **kwargs): -# """Matplotlib plot() - tracked, identical to ax.plot().""" -# return self._mpl_call("plot", *args, **kwargs) -# -# def mpl_step(self, *args, **kwargs): -# """Matplotlib step() - tracked, identical to ax.step().""" -# return self._mpl_call("step", *args, **kwargs) -# -# def mpl_stem(self, *args, **kwargs): -# """Matplotlib stem() - tracked, identical to ax.stem().""" -# return self._mpl_call("stem", *args, **kwargs) -# -# # ========================================================================= -# # Scatter plots -# # ========================================================================= -# def mpl_scatter(self, *args, **kwargs): -# """Matplotlib scatter() - tracked, identical to ax.scatter().""" -# return self._mpl_call("scatter", *args, **kwargs) -# -# # ========================================================================= -# # Bar plots -# # ========================================================================= -# def mpl_bar(self, *args, **kwargs): -# """Matplotlib bar() - tracked, identical to ax.bar().""" -# return self._mpl_call("bar", *args, **kwargs) -# -# def mpl_barh(self, *args, **kwargs): -# """Matplotlib barh() - tracked, identical to ax.barh().""" -# return self._mpl_call("barh", *args, **kwargs) -# -# def mpl_bar3d(self, *args, **kwargs): -# """Matplotlib bar3d() (3D axes) - tracked.""" -# return self._mpl_call("bar3d", *args, **kwargs) -# -# # ========================================================================= -# # Histograms -# # ========================================================================= -# def mpl_hist(self, *args, **kwargs): -# """Matplotlib hist() - tracked, identical to ax.hist().""" -# return self._mpl_call("hist", *args, **kwargs) -# -# def mpl_hist2d(self, *args, **kwargs): -# """Matplotlib hist2d() - tracked, identical to ax.hist2d().""" -# return self._mpl_call("hist2d", *args, **kwargs) -# -# def mpl_hexbin(self, *args, **kwargs): -# """Matplotlib hexbin() - tracked, identical to ax.hexbin().""" -# return self._mpl_call("hexbin", *args, **kwargs) -# -# # ========================================================================= -# # Statistical plots -# # ========================================================================= -# def mpl_boxplot(self, *args, **kwargs): -# """Matplotlib boxplot() - tracked, identical to ax.boxplot().""" -# return self._mpl_call("boxplot", *args, **kwargs) -# -# def mpl_violinplot(self, *args, **kwargs): -# """Matplotlib violinplot() - tracked, identical to ax.violinplot().""" -# return self._mpl_call("violinplot", *args, **kwargs) -# -# def mpl_errorbar(self, *args, **kwargs): -# """Matplotlib errorbar() - tracked, identical to ax.errorbar().""" -# return self._mpl_call("errorbar", *args, **kwargs) -# -# def mpl_eventplot(self, *args, **kwargs): -# """Matplotlib eventplot() - tracked, identical to ax.eventplot().""" -# return self._mpl_call("eventplot", *args, **kwargs) -# -# # ========================================================================= -# # Fill and area plots -# # ========================================================================= -# def mpl_fill(self, *args, **kwargs): -# """Matplotlib fill() - tracked, identical to ax.fill().""" -# return self._mpl_call("fill", *args, **kwargs) -# -# def mpl_fill_between(self, *args, **kwargs): -# """Matplotlib fill_between() - tracked, identical to ax.fill_between().""" -# return self._mpl_call("fill_between", *args, **kwargs) -# -# def mpl_fill_betweenx(self, *args, **kwargs): -# """Matplotlib fill_betweenx() - tracked, identical to ax.fill_betweenx().""" -# return self._mpl_call("fill_betweenx", *args, **kwargs) -# -# def mpl_stackplot(self, *args, **kwargs): -# """Matplotlib stackplot() - tracked, identical to ax.stackplot().""" -# return self._mpl_call("stackplot", *args, **kwargs) -# -# # ========================================================================= -# # Contour and heatmap plots -# # ========================================================================= -# def mpl_contour(self, *args, **kwargs): -# """Matplotlib contour() - tracked, identical to ax.contour().""" -# return self._mpl_call("contour", *args, **kwargs) -# -# def mpl_contourf(self, *args, **kwargs): -# """Matplotlib contourf() - tracked, identical to ax.contourf().""" -# return self._mpl_call("contourf", *args, **kwargs) -# -# def mpl_imshow(self, *args, **kwargs): -# """Matplotlib imshow() - tracked, identical to ax.imshow().""" -# return self._mpl_call("imshow", *args, **kwargs) -# -# def mpl_pcolormesh(self, *args, **kwargs): -# """Matplotlib pcolormesh() - tracked, identical to ax.pcolormesh().""" -# return self._mpl_call("pcolormesh", *args, **kwargs) -# -# def mpl_pcolor(self, *args, **kwargs): -# """Matplotlib pcolor() - tracked, identical to ax.pcolor().""" -# return self._mpl_call("pcolor", *args, **kwargs) -# -# def mpl_matshow(self, *args, **kwargs): -# """Matplotlib matshow() - tracked, identical to ax.matshow().""" -# return self._mpl_call("matshow", *args, **kwargs) -# -# # ========================================================================= -# # Vector field plots -# # ========================================================================= -# def mpl_quiver(self, *args, **kwargs): -# """Matplotlib quiver() - tracked, identical to ax.quiver().""" -# return self._mpl_call("quiver", *args, **kwargs) -# -# def mpl_streamplot(self, *args, **kwargs): -# """Matplotlib streamplot() - tracked, identical to ax.streamplot().""" -# return self._mpl_call("streamplot", *args, **kwargs) -# -# def mpl_barbs(self, *args, **kwargs): -# """Matplotlib barbs() - tracked, identical to ax.barbs().""" -# return self._mpl_call("barbs", *args, **kwargs) -# -# # ========================================================================= -# # Pie and polar plots -# # ========================================================================= -# def mpl_pie(self, *args, **kwargs): -# """Matplotlib pie() - tracked, identical to ax.pie().""" -# return self._mpl_call("pie", *args, **kwargs) -# -# # ========================================================================= -# # Text and annotations -# # ========================================================================= -# def mpl_text(self, *args, **kwargs): -# """Matplotlib text() - tracked, identical to ax.text().""" -# return self._mpl_call("text", *args, **kwargs) -# -# def mpl_annotate(self, *args, **kwargs): -# """Matplotlib annotate() - tracked, identical to ax.annotate().""" -# return self._mpl_call("annotate", *args, **kwargs) -# -# # ========================================================================= -# # Lines and spans -# # ========================================================================= -# def mpl_axhline(self, *args, **kwargs): -# """Matplotlib axhline() - tracked, identical to ax.axhline().""" -# return self._mpl_call("axhline", *args, **kwargs) -# -# def mpl_axvline(self, *args, **kwargs): -# """Matplotlib axvline() - tracked, identical to ax.axvline().""" -# return self._mpl_call("axvline", *args, **kwargs) -# -# def mpl_axhspan(self, *args, **kwargs): -# """Matplotlib axhspan() - tracked, identical to ax.axhspan().""" -# return self._mpl_call("axhspan", *args, **kwargs) -# -# def mpl_axvspan(self, *args, **kwargs): -# """Matplotlib axvspan() - tracked, identical to ax.axvspan().""" -# return self._mpl_call("axvspan", *args, **kwargs) -# -# # ========================================================================= -# # Patches and shapes -# # ========================================================================= -# def mpl_add_patch(self, patch, **kwargs): -# """Matplotlib add_patch() - tracked, identical to ax.add_patch().""" -# return self._mpl_call("add_patch", patch, **kwargs) -# -# def mpl_add_artist(self, artist, **kwargs): -# """Matplotlib add_artist() - tracked, identical to ax.add_artist().""" -# return self._mpl_call("add_artist", artist, **kwargs) -# -# def mpl_add_collection(self, collection, **kwargs): -# """Matplotlib add_collection() - tracked, identical to ax.add_collection().""" -# return self._mpl_call("add_collection", collection, **kwargs) -# -# # ========================================================================= -# # 3D plotting (if available) -# # ========================================================================= -# def mpl_plot_surface(self, *args, **kwargs): -# """Matplotlib plot_surface() (3D axes) - tracked.""" -# return self._mpl_call("plot_surface", *args, **kwargs) -# -# def mpl_plot_wireframe(self, *args, **kwargs): -# """Matplotlib plot_wireframe() (3D axes) - tracked.""" -# return self._mpl_call("plot_wireframe", *args, **kwargs) -# -# def mpl_contour3D(self, *args, **kwargs): -# """Matplotlib contour3D() (3D axes) - tracked.""" -# return self._mpl_call("contour3D", *args, **kwargs) -# -# def mpl_scatter3D(self, *args, **kwargs): -# """Matplotlib scatter3D() (3D axes) - tracked.""" -# return self._mpl_call("scatter3D", *args, **kwargs) -# -# # ========================================================================= -# # Utility method to get raw axes -# # ========================================================================= -# @property -# def mpl_axes(self): -# """Direct access to underlying matplotlib axes object.""" -# return self._axes_mpl -# -# def mpl_raw(self, method_name, *args, **kwargs): -# """Call any matplotlib method by name without scitex processing. -# -# Parameters -# ---------- -# method_name : str -# Name of matplotlib axes method to call -# *args, **kwargs -# Arguments to pass to the method -# -# Returns -# ------- -# result -# Result from matplotlib method -# -# Example -# ------- -# >>> ax.mpl_raw("tricontour", x, y, z, levels=10) -# """ -# method = getattr(self._axes_mpl, method_name) -# return method(*args, **kwargs) -# -# -# # Registry of mpl_xxx methods for programmatic access -# MPL_METHODS = [ -# # Line plots -# "mpl_plot", "mpl_step", "mpl_stem", -# # Scatter -# "mpl_scatter", -# # Bar -# "mpl_bar", "mpl_barh", "mpl_bar3d", -# # Histograms -# "mpl_hist", "mpl_hist2d", "mpl_hexbin", -# # Statistical -# "mpl_boxplot", "mpl_violinplot", "mpl_errorbar", "mpl_eventplot", -# # Fill/area -# "mpl_fill", "mpl_fill_between", "mpl_fill_betweenx", "mpl_stackplot", -# # Contour/heatmap -# "mpl_contour", "mpl_contourf", "mpl_imshow", "mpl_pcolormesh", "mpl_pcolor", "mpl_matshow", -# # Vector fields -# "mpl_quiver", "mpl_streamplot", "mpl_barbs", -# # Pie -# "mpl_pie", -# # Text/annotations -# "mpl_text", "mpl_annotate", -# # Lines/spans -# "mpl_axhline", "mpl_axvline", "mpl_axhspan", "mpl_axvspan", -# # Patches -# "mpl_add_patch", "mpl_add_artist", "mpl_add_collection", -# # 3D -# "mpl_plot_surface", "mpl_plot_wireframe", "mpl_contour3D", "mpl_scatter3D", -# ] -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_RawMatplotlibMixin.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__TrackingMixin.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__TrackingMixin.py deleted file mode 100644 index a480e4540..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__TrackingMixin.py +++ /dev/null @@ -1,419 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-05-03 12:35:08 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__TrackingMixin.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/_subplots/_AxisWrapperMixins/test__TrackingMixin.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from unittest.mock import MagicMock - -import pytest - - -@pytest.fixture -def tracking_mixin_instance(): - """Fixture that creates a simple TrackingMixin instance for testing.""" - from scitex.plt._subplots._AxisWrapperMixins import TrackingMixin - - class TestTrackingMixin(TrackingMixin): - def __init__(self): - self.axis = MagicMock() - self.track = True - self.id = 0 - self._ax_history = {} - - return TestTrackingMixin() - - -def test_track_method_with_tracking_enabled(tracking_mixin_instance): - """Tests that _track method correctly stores history when tracking is enabled.""" - # Setup - instance = tracking_mixin_instance - instance.track = True - method_name = "test_method" - args = ([1, 2, 3], [4, 5, 6]) - kwargs = {"color": "red", "marker": "o"} - plot_id = "test_plot" - # Expected key includes ax position prefix (defaults to ax_00_) - expected_key = "ax_00_test_plot" - - # Execute - instance._track(True, plot_id, method_name, args, kwargs) - - # Verify - key is prefixed with ax position - assert expected_key in instance._ax_history - assert instance._ax_history[expected_key] == ( - expected_key, - method_name, - args, - kwargs, - ) - - -def test_track_method_with_tracking_disabled(tracking_mixin_instance): - """Tests that _track method does not store history when tracking is disabled.""" - # Setup - instance = tracking_mixin_instance - instance.track = False - method_name = "test_method" - args = ([1, 2, 3], [4, 5, 6]) - kwargs = {"color": "red"} - plot_id = "test_plot" - - # Execute - instance._track(False, plot_id, method_name, args, kwargs) - - # Verify - assert plot_id not in instance._ax_history - - -def test_track_method_with_id_from_kwargs(tracking_mixin_instance): - """Tests that _track method extracts id from kwargs if present.""" - # Setup - instance = tracking_mixin_instance - method_name = "test_method" - args = ([1, 2, 3],) - kwargs = {"color": "blue", "id": "kwargs_id"} - # Expected key includes ax position prefix (defaults to ax_00_) - expected_key = "ax_00_kwargs_id" - - # Execute - instance._track(True, None, method_name, args, kwargs) - - # Verify - key is prefixed with ax position - assert expected_key in instance._ax_history - assert "id" not in kwargs # id should have been removed from kwargs - - -def test_no_tracking_context_manager(tracking_mixin_instance): - """Tests that _no_tracking context manager temporarily disables tracking.""" - # Setup - instance = tracking_mixin_instance - instance.track = True - - # Execute - with instance._no_tracking(): - tracking_during = instance.track - tracking_after = instance.track - - # Verify - assert tracking_during is False - assert tracking_after is True - - -def test_history_property(tracking_mixin_instance): - """Tests that history property returns the correct dictionary.""" - # Setup - instance = tracking_mixin_instance - instance._ax_history = { - "plot1": ("plot1", "method1", ([1, 2], [3, 4]), {}), - "plot2": ("plot2", "method2", ([5, 6], [7, 8]), {"color": "red"}), - } - - # Execute - history = instance.history - - # Verify - assert history == instance._ax_history - assert "plot1" in history - assert "plot2" in history - - -def test_reset_history(tracking_mixin_instance): - """Tests that reset_history clears the history.""" - # Setup - instance = tracking_mixin_instance - instance._ax_history = { - "plot1": ("plot1", "method1", ([1, 2], [3, 4]), {}), - } - - # Execute - instance.reset_history() - - # Verify - assert instance._ax_history == {} - - -# def test_export_as_csv_method(tracking_mixin_instance): -# """Tests that export_as_csv method correctly converts history to DataFrame.""" -# # Setup -# instance = tracking_mixin_instance -# history_data = { -# "plot1": ("plot1", "plot", ([1, 2, 3], [4, 5, 6]), {}), -# } -# instance._ax_history = history_data - -# # Mock the _export_as_csv function -# expected_df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) - -# with patch( -# "scitex.plt._subplots._AxisWrapperMixins._TrackingMixin.TrackingMixin.export_as_csv", -# return_value=expected_df, -# ) as mock_export_as_csv: -# # Execute -# result = instance.export_as_csv() - -# # Verify -# mock_export_as_csv.assert_called_once_with(history_data) -# pd.testing.assert_frame_equal(result, expected_df) - - -# def test_flat_property_with_single_axis(tracking_mixin_instance): -# """Tests that flat property returns a list with a single axis.""" -# # Setup -# instance = tracking_mixin_instance -# instance.axis = MagicMock() - -# # Execute -# flat_result = instance.flat - -# # Verify -# assert isinstance(flat_result, list) -# assert len(flat_result) == 1 -# assert flat_result[0] == instance.axis - - -# def test_flat_property_with_multiple_axes(tracking_mixin_instance): -# """Tests that flat property returns the axis list when axis is already a list.""" -# # Setup -# instance = tracking_mixin_instance -# axis_list = [MagicMock(), MagicMock()] -# instance.axis = axis_list - -# # Execute -# flat_result = instance.flat - -# # Verify -# assert flat_result is axis_list - - -# def test_export_as_csv_with_none_result(tracking_mixin_instance): -# """Tests that export_as_csv returns empty DataFrame when _export_as_csv returns None.""" -# # Setup -# instance = tracking_mixin_instance - -# with patch( -# "scitex.plt._subplots._AxisWrapperMixins._export_as_csv.export_as_csv", -# return_value=None, -# ) as mock_export_as_csv: -# # Execute -# result = instance.export_as_csv() - -# # Verify -# assert isinstance(result, pd.DataFrame) -# assert result.empty - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-04-30 18:40:59 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# """ -# Functionality: -# * Handles tracking and history management for matplotlib plot operations -# Input: -# * Plot method calls, their arguments, and tracking configuration -# Output: -# * Tracked plotting history and DataFrame export for analysis -# Prerequisites: -# * pandas, matplotlib -# """ -# -# from contextlib import contextmanager -# -# import pandas as pd -# -# from .._export_as_csv import export_as_csv as _export_as_csv -# -# -# class TrackingMixin: -# """Mixin class for tracking matplotlib plotting operations. -# -# Example -# ------- -# >>> fig, ax = plt.subplots() -# >>> ax.track = True -# >>> ax.id = 0 -# >>> ax._ax_history = OrderedDict() -# >>> ax.plot([1, 2, 3], [4, 5, 6], id="plot1") -# >>> print(ax.history) -# {'plot1': ('plot1', 'plot', {'plot_df': DataFrame, ...}, {})} -# """ -# -# def _track(self, track, id, method_name, tracked_dict, kwargs=None): -# """Track plotting operation with auto-generated IDs. -# -# Args: -# track: Whether to track this operation -# id: Identifier for the plot (can be None) -# method_name: Name of the plotting method -# tracked_dict: Dictionary of tracked data -# kwargs: Original keyword arguments -# """ -# # Extract id from kwargs and remove it before passing to matplotlib -# if kwargs is not None and hasattr(kwargs, "get") and "id" in kwargs: -# id = kwargs.pop("id") -# -# # Default kwargs to empty dict if None -# if kwargs is None: -# kwargs = {} -# -# if track is None: -# track = self.track -# -# if track: -# # Get axes position from _scitex_metadata if available -# ax_row, ax_col = 0, 0 -# if hasattr(self, "_axis_mpl") and hasattr(self._axis_mpl, "_scitex_metadata"): -# meta = self._axis_mpl._scitex_metadata -# if "position_in_grid" in meta: -# ax_row, ax_col = meta["position_in_grid"] -# -# # If no ID was provided, generate one using method_name + counter -# if id is None: -# # Initialize method counters if not exist -# if not hasattr(self, "_method_counters"): -# self._method_counters = {} -# -# # Get current counter value for this method and increment it -# counter = self._method_counters.get(method_name, 0) -# self._method_counters[method_name] = counter + 1 -# -# # Format ID with axes position: ax_RC_method_counter -# # e.g., ax_00_plot_0, ax_01_bar_1, ax_10_scatter_2 -# id = f"ax_{ax_row}{ax_col}_{method_name}_{counter}" -# else: -# # User-provided ID - prepend axes position -# # e.g., ax_00_sine, ax_01_my-data -# id = f"ax_{ax_row}{ax_col}_{id}" -# -# # For backward compatibility -# self.id += 1 -# -# # Store the tracking record -# self._ax_history[id] = (id, method_name, tracked_dict, kwargs) -# -# @contextmanager -# def _no_tracking(self): -# """Context manager to temporarily disable tracking.""" -# original_track = self.track -# self.track = False -# try: -# yield -# finally: -# self.track = original_track -# -# @property -# def history(self): -# return {k: self._ax_history[k] for k in self._ax_history} -# -# @property -# def flat(self): -# if isinstance(self._axis_mpl, list): -# return self._axis_mpl -# else: -# return [self._axis_mpl] -# -# def reset_history(self): -# self._ax_history = {} -# -# def export_as_csv(self): -# """ -# Export tracked plotting data to a DataFrame. -# """ -# df = _export_as_csv(self.history) -# -# return df if df is not None else pd.DataFrame() -# -# def export_as_csv_for_sigmaplot(self, include_visual_params=True): -# """ -# Export tracked plotting data to a DataFrame in SigmaPlot format. -# -# Parameters -# ---------- -# include_visual_params : bool, optional -# Whether to include visual parameters (xlabel, ylabel, scales, etc.) -# at the top of the CSV. Default is True. -# -# Returns -# ------- -# pandas.DataFrame -# DataFrame containing the plotted data formatted for SigmaPlot. -# -# Examples -# -------- -# >>> fig, ax = scitex.plt.subplots() -# >>> ax.plot([1, 2, 3], [4, 5, 6]) -# >>> ax.scatter([1, 2, 3], [7, 8, 9]) -# >>> df = ax.export_as_csv_for_sigmaplot() -# >>> df.to_csv('for_sigmaplot.csv', index=False) -# """ -# df = _export_as_csv(self.history) -# -# return df if df is not None else pd.DataFrame() -# -# # def _track( -# # self, -# # track: Optional[bool], -# # plot_id: Optional[str], -# # method_name: str, -# # tracked_dict: Any, -# # kwargs: Dict[str, Any] -# # ) -> None: -# # """Tracks plotting operation if tracking is enabled.""" -# # if track is None: -# # track = self.track -# # if track: -# # plot_id = plot_id if plot_id is not None else self.id -# # self.id += 1 -# # self._ax_history[plot_id] = (plot_id, method_name, tracked_dict, kwargs) -# -# # @contextmanager -# # def _no_tracking(self) -> None: -# # """Temporarily disables tracking within a context.""" -# # original_track = self.track -# # self.track = False -# # try: -# # yield -# # finally: -# # self.track = original_track -# -# # @property -# # def history(self) -> Dict[str, Tuple]: -# # """Returns the plotting history.""" -# # return dict(self._ax_history) -# -# # def reset_history(self) -> None: -# # """Clears the plotting history.""" -# # self._ax_history = OrderedDict() -# -# # def export_as_csv(self) -> pd.DataFrame: -# # """Converts plotting history to a SigmaPlot-compatible DataFrame.""" -# # df = _export_as_csv(self.history) -# # return df if df is not None else pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_TrackingMixin.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__UnitAwareMixin.py b/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__UnitAwareMixin.py deleted file mode 100644 index 73ca28eeb..000000000 --- a/tests/scitex/plt/_subplots/_AxisWrapperMixins/test__UnitAwareMixin.py +++ /dev/null @@ -1,456 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-08-01 10:35:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/SciTeX-Code/src/scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py -# # ---------------------------------------- -# -# """ -# Unit-Aware Plotting Mixin -# ========================= -# -# This mixin adds unit handling capabilities to the AxisWrapper class, -# ensuring scientific validity in plots. -# -# Features: -# - Automatic unit tracking for axes -# - Unit validation for data compatibility -# - Automatic unit conversion -# - Unit-aware axis labels -# """ -# -# from typing import Optional, Dict, Tuple, Union, Any -# import re -# import numpy as np -# from scitex.units import Unit, Q, Units -# from scitex.logging import SciTeXError -# import scitex.logging as logging -# from scitex.logging import UnitWarning, warn as _warn -# -# # Valid dimensionless/special unit markers -# _VALID_DIMENSIONLESS = {"[-]", "[a.u.]", "[arb. units]", "[dimensionless]", "[1]", "[A.U.]"} -# -# -# def _convert_to_negative_exponent(unit: str) -> str: -# """Convert unit with / to negative exponent format. -# -# Examples: -# m/s -> m·s⁻¹ -# kg/m^2 -> kg·m⁻² -# W/m^2/K -> W·m⁻²·K⁻¹ -# """ -# superscript = str.maketrans("0123456789-", "⁰¹²³⁴⁵⁶⁷⁸⁹⁻") -# -# parts = unit.split("/") -# if len(parts) < 2: -# return unit -# -# result = parts[0] -# for part in parts[1:]: -# exp_match = re.match(r"([a-zA-Z]+)\^?(\d+)?", part) -# if exp_match: -# base = exp_match.group(1) -# exp = exp_match.group(2) or "1" -# neg_exp = f"-{exp}".translate(superscript) -# result += f"·{base}{neg_exp}" -# else: -# result += f"·{part}⁻¹" -# -# return result -# -# -# def validate_axis_label(label: str, axis_name: str = "axis") -> str: -# """Validate and warn about axis label units (educational for scientific standards). -# -# Checks for: -# - Missing units -# - Non-standard format (prefer [] over ()) -# - Suggests ^-1 format over / -# -# Parameters -# ---------- -# label : str -# Axis label to validate -# axis_name : str -# Name for warning messages (e.g., "X axis", "Y axis") -# -# Returns -# ------- -# str -# Original label (warnings are educational, not auto-correcting) -# """ -# if not label: -# return label -# -# # Check for units in brackets [] or parentheses () -# has_square_brackets = bool(re.search(r"\[.*?\]", label)) -# has_parentheses = bool(re.search(r"\(.*?\)", label)) -# -# unit_match_square = re.search(r"\[(.*?)\]", label) -# unit_match_paren = re.search(r"\((.*?)\)", label) -# -# if not has_square_brackets and not has_parentheses: -# _warn( -# f"{axis_name} label '{label}' has no units. " -# f"Consider: '{label} [unit]' or '{label} [-]' for dimensionless", -# UnitWarning, -# stacklevel=3, -# ) -# return label -# -# if has_parentheses and not has_square_brackets: -# unit = unit_match_paren.group(1) if unit_match_paren else "" -# suggested = re.sub(r"\((.*?)\)", f"[{unit}]", label) -# _warn( -# f"{axis_name} label '{label}' uses parentheses. " -# f"SI convention prefers: '{suggested}'", -# UnitWarning, -# stacklevel=3, -# ) -# -# unit_content = None -# if unit_match_square: -# unit_content = unit_match_square.group(1) -# elif unit_match_paren: -# unit_content = unit_match_paren.group(1) -# -# if unit_content and "/" in unit_content: -# suggested_unit = _convert_to_negative_exponent(unit_content) -# if suggested_unit != unit_content: -# suggested_label = label.replace(f"[{unit_content}]", f"[{suggested_unit}]") -# _warn( -# f"{axis_name} uses '/' in units. Consider: '{suggested_label}'", -# UnitWarning, -# stacklevel=3, -# ) -# -# return label -# -# -# class UnitMismatchError(SciTeXError): -# """Raised when units are incompatible for an operation.""" -# -# pass -# -# -# class UnitAwareMixin: -# """Mixin that adds unit awareness to plotting operations.""" -# -# def __init__(self): -# """Initialize unit tracking.""" -# self._x_unit: Optional[Unit] = None -# self._y_unit: Optional[Unit] = None -# self._z_unit: Optional[Unit] = None -# self._unit_validation_enabled: bool = True -# -# def set_unit_validation(self, enabled: bool) -> None: -# """Enable or disable unit validation.""" -# self._unit_validation_enabled = enabled -# -# def set_x_unit(self, unit: Union[str, Unit]) -> None: -# """Set the unit for the x-axis.""" -# if isinstance(unit, str): -# unit_obj = getattr(Units, unit, None) -# if unit_obj is None: -# raise ValueError(f"Unknown unit: {unit}") -# unit = unit_obj -# self._x_unit = unit -# self._update_xlabel_with_unit() -# -# def set_y_unit(self, unit: Union[str, Unit]) -> None: -# """Set the unit for the y-axis.""" -# if isinstance(unit, str): -# unit_obj = getattr(Units, unit, None) -# if unit_obj is None: -# raise ValueError(f"Unknown unit: {unit}") -# unit = unit_obj -# self._y_unit = unit -# self._update_ylabel_with_unit() -# -# def set_z_unit(self, unit: Union[str, Unit]) -> None: -# """Set the unit for the z-axis (for 3D plots).""" -# if isinstance(unit, str): -# unit_obj = getattr(Units, unit, None) -# if unit_obj is None: -# raise ValueError(f"Unknown unit: {unit}") -# unit = unit_obj -# self._z_unit = unit -# self._update_zlabel_with_unit() -# -# def get_x_unit(self) -> Optional[Unit]: -# """Get the current x-axis unit.""" -# return self._x_unit -# -# def get_y_unit(self) -> Optional[Unit]: -# """Get the current y-axis unit.""" -# return self._y_unit -# -# def get_z_unit(self) -> Optional[Unit]: -# """Get the current z-axis unit.""" -# return self._z_unit -# -# def _update_xlabel_with_unit(self) -> None: -# """Update x-axis label to include unit.""" -# if self._x_unit and hasattr(self, "_axes_mpl"): -# current_label = self._axes_mpl.get_xlabel() -# # Remove existing unit if present -# if "[" in current_label and "]" in current_label: -# current_label = current_label.split("[")[0].strip() -# if current_label: -# self._axes_mpl.set_xlabel(f"{current_label} [{self._x_unit.symbol}]") -# -# def _update_ylabel_with_unit(self) -> None: -# """Update y-axis label to include unit.""" -# if self._y_unit and hasattr(self, "_axes_mpl"): -# current_label = self._axes_mpl.get_ylabel() -# # Remove existing unit if present -# if "[" in current_label and "]" in current_label: -# current_label = current_label.split("[")[0].strip() -# if current_label: -# self._axes_mpl.set_ylabel(f"{current_label} [{self._y_unit.symbol}]") -# -# def _update_zlabel_with_unit(self) -> None: -# """Update z-axis label to include unit (for 3D plots).""" -# if ( -# self._z_unit -# and hasattr(self, "_axes_mpl") -# and hasattr(self._axes_mpl, "set_zlabel") -# ): -# current_label = self._axes_mpl.get_zlabel() -# # Remove existing unit if present -# if "[" in current_label and "]" in current_label: -# current_label = current_label.split("[")[0].strip() -# if current_label: -# self._axes_mpl.set_zlabel(f"{current_label} [{self._z_unit.symbol}]") -# -# def plot_with_units(self, x, y, x_unit=None, y_unit=None, **kwargs): -# """Plot with automatic unit handling. -# -# Parameters -# ---------- -# x : array-like or Quantity -# X-axis data -# y : array-like or Quantity -# Y-axis data -# x_unit : str or Unit, optional -# Unit for x-axis (overrides detected unit) -# y_unit : str or Unit, optional -# Unit for y-axis (overrides detected unit) -# **kwargs : dict -# Additional plotting parameters -# -# Returns -# ------- -# lines : list of Line2D -# The plotted lines -# """ -# # Extract values and units from Quantity objects -# x_val, x_detected_unit = self._extract_value_and_unit(x) -# y_val, y_detected_unit = self._extract_value_and_unit(y) -# -# # Use provided units or detected units -# if x_unit: -# self.set_x_unit(x_unit) -# elif x_detected_unit and not self._x_unit: -# self.set_x_unit(x_detected_unit) -# -# if y_unit: -# self.set_y_unit(y_unit) -# elif y_detected_unit and not self._y_unit: -# self.set_y_unit(y_detected_unit) -# -# # Validate units if enabled -# if self._unit_validation_enabled: -# self._validate_unit_compatibility(x_detected_unit, self._x_unit, "x") -# self._validate_unit_compatibility(y_detected_unit, self._y_unit, "y") -# -# # Plot using the standard method -# return self.plot(x_val, y_val, **kwargs) -# -# def _extract_value_and_unit(self, data) -> Tuple[np.ndarray, Optional[Unit]]: -# """Extract numerical value and unit from data.""" -# if hasattr(data, "value") and hasattr(data, "unit"): -# # It's a Quantity object -# return data.value, data.unit -# else: -# # Regular array -# return np.asarray(data), None -# -# def _validate_unit_compatibility( -# self, data_unit: Optional[Unit], axis_unit: Optional[Unit], axis_name: str -# ) -> None: -# """Validate that data unit is compatible with axis unit.""" -# if not self._unit_validation_enabled: -# return -# -# if data_unit and axis_unit: -# # Check if units have same dimensions -# if data_unit.dimensions != axis_unit.dimensions: -# raise UnitMismatchError( -# f"Unit mismatch on {axis_name}-axis: " -# f"data has unit {data_unit.symbol} {data_unit.dimensions}, " -# f"but axis expects {axis_unit.symbol} {axis_unit.dimensions}" -# ) -# -# def convert_x_units( -# self, new_unit: Union[str, Unit], update_data: bool = True -# ) -> float: -# """Convert x-axis to new units. -# -# Parameters -# ---------- -# new_unit : str or Unit -# Target unit -# update_data : bool -# Whether to update plotted data -# -# Returns -# ------- -# float -# Conversion factor applied -# """ -# if isinstance(new_unit, str): -# new_unit = getattr(Units, new_unit) -# -# if not self._x_unit: -# raise ValueError("No x-axis unit set") -# -# # Calculate conversion factor -# factor = self._x_unit.scale / new_unit.scale -# -# if update_data and hasattr(self, "_axes_mpl"): -# # Update all line data -# for line in self._axes_mpl.lines: -# xdata = line.get_xdata() -# line.set_xdata(xdata * factor) -# -# # Update x-axis limits -# xlim = self._axes_mpl.get_xlim() -# self._axes_mpl.set_xlim([x * factor for x in xlim]) -# -# # Update unit -# self.set_x_unit(new_unit) -# -# return factor -# -# def convert_y_units( -# self, new_unit: Union[str, Unit], update_data: bool = True -# ) -> float: -# """Convert y-axis to new units. -# -# Parameters -# ---------- -# new_unit : str or Unit -# Target unit -# update_data : bool -# Whether to update plotted data -# -# Returns -# ------- -# float -# Conversion factor applied -# """ -# if isinstance(new_unit, str): -# new_unit = getattr(Units, new_unit) -# -# if not self._y_unit: -# raise ValueError("No y-axis unit set") -# -# # Calculate conversion factor -# factor = self._y_unit.scale / new_unit.scale -# -# if update_data and hasattr(self, "_axes_mpl"): -# # Update all line data -# for line in self._axes_mpl.lines: -# ydata = line.get_ydata() -# line.set_ydata(ydata * factor) -# -# # Update y-axis limits -# ylim = self._axes_mpl.get_ylim() -# self._axes_mpl.set_ylim([y * factor for y in ylim]) -# -# # Update unit -# self.set_y_unit(new_unit) -# -# return factor -# -# def set_xlabel(self, label: str, unit: Optional[Union[str, Unit]] = None) -> None: -# """Set x-axis label with optional unit. -# -# Parameters -# ---------- -# label : str -# Axis label text -# unit : str or Unit, optional -# Unit to display -# """ -# if unit: -# self.set_x_unit(unit) -# -# if self._x_unit: -# label = f"{label} [{self._x_unit.symbol}]" -# -# # Validate units (educational warnings for scientific standards) -# validate_axis_label(label, "X axis") -# -# self._axes_mpl.set_xlabel(label) -# -# def set_ylabel(self, label: str, unit: Optional[Union[str, Unit]] = None) -> None: -# """Set y-axis label with optional unit. -# -# Parameters -# ---------- -# label : str -# Axis label text -# unit : str or Unit, optional -# Unit to display -# """ -# if unit: -# self.set_y_unit(unit) -# -# if self._y_unit: -# label = f"{label} [{self._y_unit.symbol}]" -# -# # Validate units (educational warnings for scientific standards) -# validate_axis_label(label, "Y axis") -# -# self._axes_mpl.set_ylabel(label) -# -# def set_zlabel(self, label: str, unit: Optional[Union[str, Unit]] = None) -> None: -# """Set z-axis label with optional unit (for 3D plots). -# -# Parameters -# ---------- -# label : str -# Axis label text -# unit : str or Unit, optional -# Unit to display -# """ -# if not hasattr(self._axes_mpl, "set_zlabel"): -# raise ValueError("Z-axis labels only available for 3D plots") -# -# if unit: -# self.set_z_unit(unit) -# -# if self._z_unit: -# label = f"{label} [{self._z_unit.symbol}]" -# -# # Validate units (educational warnings for scientific standards) -# validate_axis_label(label, "Z axis") -# -# self._axes_mpl.set_zlabel(label) - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapperMixins/_UnitAwareMixin.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_annotate.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_annotate.py deleted file mode 100644 index b29030805..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_annotate.py +++ /dev/null @@ -1,88 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-10-04 02:30:00 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/SciTeX-Code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py -# # ---------------------------------------- -# from __future__ import annotations -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_annotate(id, tracked_dict, kwargs): -# """Format data from an annotate call. -# -# matplotlib annotate signature: annotate(text, xy, xytext=None, **kwargs) -# - text: The text of the annotation -# - xy: The point (x, y) to annotate -# - xytext: The position (x, y) to place the text at (optional) -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Extract text and xy coordinates if available -# if len(args) >= 2: -# text_content = args[0] -# xy = args[1] -# -# # xy should be a tuple (x, y) -# if hasattr(xy, "__len__") and len(xy) >= 2: -# x, y = xy[0], xy[1] -# else: -# return pd.DataFrame() -# -# data = { -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): [x], -# get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): [y], -# get_csv_column_name("content", ax_row, ax_col, trace_id=trace_id): [text_content], -# } -# -# # Check if xytext was provided (either as third arg or in kwargs) -# xytext = None -# if len(args) >= 3: -# xytext = args[2] -# elif "xytext" in kwargs: -# xytext = kwargs["xytext"] -# -# if xytext is not None and hasattr(xytext, "__len__") and len(xytext) >= 2: -# data[get_csv_column_name("text_x", ax_row, ax_col, trace_id=trace_id)] = [xytext[0]] -# data[get_csv_column_name("text_y", ax_row, ax_col, trace_id=trace_id)] = [xytext[1]] -# -# # Create DataFrame with proper column names (use dict with list values) -# df = pd.DataFrame(data) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_annotate.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_bar.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_bar.py deleted file mode 100644 index 222baed4a..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_bar.py +++ /dev/null @@ -1,154 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-19 15:45:51 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# import numpy as np -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_bar(id, tracked_dict, kwargs): -# """Format data from a bar call for CSV export. -# -# Includes x, y values and optional yerr for error bars. -# -# Args: -# id: The identifier for the plot -# tracked_dict: Dictionary of tracked data -# kwargs: Original keyword arguments (may contain yerr) -# -# Returns: -# pd.DataFrame: Formatted data ready for CSV export with x, y, and optional yerr -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get structured column names -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# col_yerr = get_csv_column_name("yerr", ax_row, ax_col, trace_id=trace_id) -# -# # Extract yerr from kwargs if present -# yerr = kwargs.get("yerr") if kwargs else None -# -# # Check if we have the newer format with bar_data -# if "bar_data" in tracked_dict and isinstance( -# tracked_dict["bar_data"], pd.DataFrame -# ): -# # Use the pre-formatted DataFrame but keep only x and height (y) -# df = tracked_dict["bar_data"].copy() -# -# # Keep only essential columns -# essential_cols = [col for col in df.columns if col in ["x", "height"]] -# if essential_cols: -# df = df[essential_cols] -# -# # Rename using structured naming -# rename_map = {} -# if "x" in df.columns: -# rename_map["x"] = col_x -# if "height" in df.columns: -# rename_map["height"] = col_y -# -# df = df.rename(columns=rename_map) -# -# # Add yerr if present -# if yerr is not None: -# try: -# yerr_array = np.asarray(yerr) -# if len(yerr_array) == len(df): -# df[col_yerr] = yerr_array -# except (TypeError, ValueError): -# pass -# -# return df -# -# # Legacy format - get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Extract x and y data if available -# if len(args) >= 2: -# x, y = args[0], args[1] -# -# # Convert to arrays if possible for consistent handling -# try: -# x_array = np.asarray(x) -# y_array = np.asarray(y) -# -# # Create DataFrame with structured column names -# data = { -# col_x: x_array, -# col_y: y_array, -# } -# -# # Add yerr if present -# if yerr is not None: -# try: -# yerr_array = np.asarray(yerr) -# if len(yerr_array) == len(x_array): -# data[col_yerr] = yerr_array -# except (TypeError, ValueError): -# pass -# -# return pd.DataFrame(data) -# -# except (TypeError, ValueError): -# # Fall back to direct values if conversion fails -# result = {col_x: x, col_y: y} -# if yerr is not None: -# result[col_yerr] = yerr -# return pd.DataFrame(result) -# -# # If we have tracked data in another format (like our MatplotlibPlotMixin bar method) -# result = {} -# -# # Check for x position (might be in different keys) -# for x_key in ["x", "xs", "positions"]: -# if x_key in tracked_dict: -# result[col_x] = tracked_dict[x_key] -# break -# -# # Check for y values (might be in different keys) -# for y_key in ["y", "ys", "height", "heights", "values"]: -# if y_key in tracked_dict: -# result[col_y] = tracked_dict[y_key] -# break -# -# # Add yerr if present in kwargs -# if yerr is not None and result: -# try: -# yerr_array = np.asarray(yerr) -# result[col_yerr] = yerr_array -# except (TypeError, ValueError): -# pass -# -# return pd.DataFrame(result) if result else pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_bar.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_barh.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_barh.py deleted file mode 100644 index 3f3c2f0f8..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_barh.py +++ /dev/null @@ -1,74 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_barh(id, tracked_dict, kwargs): -# """Format data from a barh call.""" -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Extract x and y data if available -# if len(args) >= 2: -# # Note: in barh, first arg is y positions, second is widths (x values) -# y_pos, x_width = args[0], args[1] -# -# # Get xerr from kwargs -# xerr = kwargs.get("xerr") -# -# # Convert single values to Series -# if isinstance(y_pos, (int, float)): -# y_pos = pd.Series(y_pos, name="y") -# if isinstance(x_width, (int, float)): -# x_width = pd.Series(x_width, name="x") -# else: -# # Not enough arguments -# return pd.DataFrame() -# -# # Use structured column naming: ax-row-{row}-col-{col}_trace-id-{id}_variable-{var} -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# -# df = pd.DataFrame({col_y: y_pos, col_x: x_width}) -# -# if xerr is not None: -# if isinstance(xerr, (int, float)): -# xerr = pd.Series(xerr, name="xerr") -# col_xerr = get_csv_column_name("xerr", ax_row, ax_col, trace_id=trace_id) -# df[col_xerr] = xerr -# return df - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_barh.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_boxplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_boxplot.py deleted file mode 100644 index be8daa88b..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_boxplot.py +++ /dev/null @@ -1,90 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_boxplot(id, tracked_dict, kwargs): -# """Format data from a boxplot call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to boxplot -# -# Returns: -# pd.DataFrame: Formatted data from boxplot -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", []) -# call_kwargs = tracked_dict.get("kwargs", {}) -# -# # Get labels if provided (for consistent naming with stats) -# labels = call_kwargs.get("labels", None) -# -# if len(args) >= 1: -# x = args[0] -# -# # One box plot -# from scitex.types import is_listed_X as scitex_types_is_listed_X -# -# if isinstance(x, np.ndarray) or scitex_types_is_listed_X(x, [float, int]): -# df = pd.DataFrame(x) -# # Use label if single box and labels provided -# if labels and len(labels) == 1: -# col_name = get_csv_column_name(labels[0], ax_row, ax_col, trace_id=trace_id) -# else: -# col_name = get_csv_column_name("data-0", ax_row, ax_col, trace_id=trace_id) -# df.columns = [col_name] -# else: -# # Multiple boxes -# import scitex.pd -# -# df = scitex.pd.force_df({i_x: _x for i_x, _x in enumerate(x)}) -# -# # Use labels if provided, otherwise use numeric indices -# if labels and len(labels) == len(df.columns): -# df.columns = [ -# get_csv_column_name(label, ax_row, ax_col, trace_id=trace_id) -# for label in labels -# ] -# else: -# df.columns = [ -# get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# for col in range(len(df.columns)) -# ] -# -# df = df.apply(lambda col: col.dropna().reset_index(drop=True)) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_boxplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_contour.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_contour.py deleted file mode 100644 index 26e474bab..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_contour.py +++ /dev/null @@ -1,66 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_contour(id, tracked_dict, kwargs): -# """Format data from a contour call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to contour -# -# Returns: -# pd.DataFrame: Formatted data from contour plot (flattened X, Y, Z grids) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", []) -# -# # Typical args: X, Y, Z where X and Y are 2D coordinate arrays and Z is the height array -# if len(args) >= 3: -# X, Y, Z = args[:3] -# X_flat = np.asarray(X).flatten() -# Y_flat = np.asarray(Y).flatten() -# Z_flat = np.asarray(Z).flatten() -# -# # Get column names from single source of truth -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# col_z = get_csv_column_name("z", ax_row, ax_col, trace_id=trace_id) -# -# df = pd.DataFrame({col_x: X_flat, col_y: Y_flat, col_z: Z_flat}) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contour.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_contourf.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_contourf.py deleted file mode 100644 index 06535bfcb..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_contourf.py +++ /dev/null @@ -1,78 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_contourf(id, tracked_dict, kwargs): -# """Format data from a filled contour plot call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to contourf -# -# Returns: -# pd.DataFrame: Formatted data from contourf (flattened X, Y, Z grids) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple): -# # contourf can be called as: -# # contourf(Z) - Z is 2D -# # contourf(X, Y, Z) - X, Y are 1D or 2D, Z is 2D -# if len(args) == 1: -# Z = np.asarray(args[0]) -# X, Y = np.meshgrid(np.arange(Z.shape[1]), np.arange(Z.shape[0])) -# elif len(args) >= 3: -# X = np.asarray(args[0]) -# Y = np.asarray(args[1]) -# Z = np.asarray(args[2]) -# # If X, Y are 1D, create meshgrid -# if X.ndim == 1 and Y.ndim == 1: -# X, Y = np.meshgrid(X, Y) -# else: -# return pd.DataFrame() -# -# # Get column names from single source of truth -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# col_z = get_csv_column_name("z", ax_row, ax_col, trace_id=trace_id) -# -# df = pd.DataFrame( -# {col_x: X.flatten(), col_y: Y.flatten(), col_z: Z.flatten()} -# ) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_contourf.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_errorbar.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_errorbar.py deleted file mode 100644 index ddc2fc0cd..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_errorbar.py +++ /dev/null @@ -1,104 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_errorbar(id, tracked_dict, kwargs): -# """Format data from an errorbar call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to errorbar -# -# Returns: -# pd.DataFrame: Formatted data from errorbar plot -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", []) -# -# if len(args) >= 2: -# x, y = args[:2] -# xerr = kwargs.get("xerr") -# yerr = kwargs.get("yerr") -# -# # Get column names from single source of truth -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# -# data = {col_x: x, col_y: y} -# -# if xerr is not None: -# if isinstance(xerr, (list, tuple)) and len(xerr) == 2: -# col_xerr_neg = get_csv_column_name("xerr-neg", ax_row, ax_col, trace_id=trace_id) -# col_xerr_pos = get_csv_column_name("xerr-pos", ax_row, ax_col, trace_id=trace_id) -# data[col_xerr_neg] = xerr[0] -# data[col_xerr_pos] = xerr[1] -# else: -# col_xerr = get_csv_column_name("xerr", ax_row, ax_col, trace_id=trace_id) -# data[col_xerr] = xerr -# -# if yerr is not None: -# if isinstance(yerr, (list, tuple)) and len(yerr) == 2: -# col_yerr_neg = get_csv_column_name("yerr-neg", ax_row, ax_col, trace_id=trace_id) -# col_yerr_pos = get_csv_column_name("yerr-pos", ax_row, ax_col, trace_id=trace_id) -# data[col_yerr_neg] = yerr[0] -# data[col_yerr_pos] = yerr[1] -# else: -# col_yerr = get_csv_column_name("yerr", ax_row, ax_col, trace_id=trace_id) -# data[col_yerr] = yerr -# -# # Handle different length arrays by padding -# max_len = max( -# len(arr) if hasattr(arr, "__len__") else 1 -# for arr in data.values() -# if arr is not None -# ) -# -# for key, value in list(data.items()): -# if value is None: -# continue -# if not hasattr(value, "__len__"): -# data[key] = [value] * max_len -# elif len(value) < max_len: -# data[key] = np.pad( -# np.asarray(value), -# (0, max_len - len(value)), -# mode="constant", -# constant_values=np.nan, -# ) -# -# return pd.DataFrame(data) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_errorbar.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_eventplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_eventplot.py deleted file mode 100644 index aa186f242..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_eventplot.py +++ /dev/null @@ -1,99 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# import scitex -# -# from scitex import logging -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# logger = logging.getLogger(__name__) -# -# -# def _format_eventplot(id, tracked_dict, kwargs): -# """Format data from an eventplot call.""" -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Eventplot displays multiple sets of events as parallel lines -# if len(args) >= 1: -# positions = args[0] -# -# try: -# # Try using scitex.pd.force_df if available -# try: -# import scitex.pd -# -# # If positions is a single array -# if isinstance(positions, (list, np.ndarray)) and not isinstance( -# positions[0], (list, np.ndarray) -# ): -# col_name = get_csv_column_name("eventplot-events", ax_row, ax_col, trace_id=trace_id) -# return pd.DataFrame({col_name: positions}) -# -# # If positions is a list of arrays (multiple event sets) -# elif isinstance(positions, (list, np.ndarray)): -# data = {} -# for i, events in enumerate(positions): -# col_name = get_csv_column_name(f"eventplot-events{i:02d}", ax_row, ax_col, trace_id=f"{trace_id}-{i}") -# data[col_name] = events -# -# # Use force_df to handle different length arrays -# return scitex.pd.force_df(data) -# -# except (ImportError, AttributeError): -# # Fall back to pandas with manual Series creation -# # If positions is a single array -# if isinstance(positions, (list, np.ndarray)) and not isinstance( -# positions[0], (list, np.ndarray) -# ): -# col_name = get_csv_column_name("eventplot-events", ax_row, ax_col, trace_id=trace_id) -# return pd.DataFrame({col_name: positions}) -# -# # If positions is a list of arrays (multiple event sets) -# elif isinstance(positions, (list, np.ndarray)): -# # Create a DataFrame where each column is a Series that can handle varying lengths -# df = pd.DataFrame() -# for i, events in enumerate(positions): -# col_name = get_csv_column_name(f"eventplot-events{i:02d}", ax_row, ax_col, trace_id=f"{trace_id}-{i}") -# df[col_name] = pd.Series(events) -# return df -# except Exception as e: -# # If all else fails, return an empty DataFrame -# logger.warning(f"Error formatting eventplot data: {str(e)}") -# return pd.DataFrame() -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_eventplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_fill.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_fill.py deleted file mode 100644 index c543e0ef7..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_fill.py +++ /dev/null @@ -1,62 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_fill(id, tracked_dict, kwargs): -# """Format data from a fill call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to fill -# -# Returns: -# pd.DataFrame: Formatted data from fill plot -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", []) -# -# # Fill creates a polygon based on points -# if len(args) >= 2: -# x = args[0] -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# data = {col_x: x} -# -# for i, y in enumerate(args[1:]): -# col_y = get_csv_column_name(f"y{i:02d}", ax_row, ax_col, trace_id=trace_id) -# data[col_y] = y -# -# return pd.DataFrame(data) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_fill_between.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_fill_between.py deleted file mode 100644 index 0833be952..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_fill_between.py +++ /dev/null @@ -1,62 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_fill_between(id, tracked_dict, kwargs): -# """Format data from a fill_between call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to fill_between -# -# Returns: -# pd.DataFrame: Formatted data from fill_between plot -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", []) -# -# # Typical args: x, y1, y2 -# if len(args) >= 3: -# x, y1, y2 = args[:3] -# -# # Get column names from single source of truth -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y1 = get_csv_column_name("y1", ax_row, ax_col, trace_id=trace_id) -# col_y2 = get_csv_column_name("y2", ax_row, ax_col, trace_id=trace_id) -# -# df = pd.DataFrame({col_x: x, col_y1: y1, col_y2: y2}) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_fill_between.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hexbin.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hexbin.py deleted file mode 100644 index fe4519600..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hexbin.py +++ /dev/null @@ -1,67 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_hexbin(id, tracked_dict, kwargs): -# """Format data from a hexbin call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to hexbin -# -# Returns: -# pd.DataFrame: Formatted data from hexbin (input x, y data) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) >= 2: -# x = np.asarray(args[0]).flatten() -# y = np.asarray(args[1]).flatten() -# -# # Ensure same length -# min_len = min(len(x), len(y)) -# x = x[:min_len] -# y = y[:min_len] -# -# # Get column names from single source of truth -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# -# df = pd.DataFrame({col_x: x, col_y: y}) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hexbin.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hist.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hist.py deleted file mode 100644 index 1c6754b6b..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hist.py +++ /dev/null @@ -1,107 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_hist(id, tracked_dict, kwargs): -# """ -# Format data from a hist call as a bar plot representation. -# -# This formatter extracts both the raw data and the binned data from histogram plots, -# returning them in a format that can be visualized as a bar plot. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to hist -# -# Returns: -# pd.DataFrame: DataFrame containing both raw data and bin information -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Check if histogram result (bin counts and edges) is available in tracked_dict -# hist_result = tracked_dict.get("hist_result", None) -# -# columns = {} -# -# # Extract raw data if available -# if len(args) >= 1: -# x = args[0] -# col_raw = get_csv_column_name("raw-data", ax_row, ax_col, trace_id=trace_id) -# columns[col_raw] = x -# -# # If we have histogram result (counts and bin edges) -# if hist_result is not None: -# counts, bin_edges = hist_result -# -# # Calculate bin centers for bar plot representation -# bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) -# bin_widths = bin_edges[1:] - bin_edges[:-1] -# -# # Use structured column naming -# col_centers = get_csv_column_name("bin-centers", ax_row, ax_col, trace_id=trace_id) -# col_counts = get_csv_column_name("bin-counts", ax_row, ax_col, trace_id=trace_id) -# col_widths = get_csv_column_name("bin-widths", ax_row, ax_col, trace_id=trace_id) -# col_left = get_csv_column_name("bin-edges-left", ax_row, ax_col, trace_id=trace_id) -# col_right = get_csv_column_name("bin-edges-right", ax_row, ax_col, trace_id=trace_id) -# -# # Add bin information to DataFrame -# columns[col_centers] = bin_centers -# columns[col_counts] = counts -# columns[col_widths] = bin_widths -# columns[col_left] = bin_edges[:-1] -# columns[col_right] = bin_edges[1:] -# -# # Create DataFrame with aligned length -# max_length = max(len(value) for value in columns.values()) -# for key, value in list(columns.items()): -# if len(value) < max_length: -# # Pad with NaN if needed - convert to float first for NaN support -# arr = np.asarray(value, dtype=float) -# padded = np.full(max_length, np.nan) -# padded[:len(arr)] = arr -# columns[key] = padded -# -# # Return DataFrame or empty DataFrame if no data -# if columns: -# return pd.DataFrame(columns) -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hist2d.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hist2d.py deleted file mode 100644 index 66be405a9..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_hist2d.py +++ /dev/null @@ -1,67 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_hist2d(id, tracked_dict, kwargs): -# """Format data from a 2D histogram call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to hist2d -# -# Returns: -# pd.DataFrame: Formatted data from 2D histogram (input x, y data) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) >= 2: -# x = np.asarray(args[0]).flatten() -# y = np.asarray(args[1]).flatten() -# -# # Ensure same length -# min_len = min(len(x), len(y)) -# x = x[:min_len] -# y = y[:min_len] -# -# # Get column names from single source of truth -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# -# df = pd.DataFrame({col_x: x, col_y: y}) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_hist2d.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow.py deleted file mode 100755 index c8c6c1e81..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow.py +++ /dev/null @@ -1,224 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-12-01 13:25:00 (ywatanabe)" -# File: tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow.py - -"""Tests for _format_imshow CSV formatter.""" - -import numpy as np -import pandas as pd -import pytest - -pytest.importorskip("zarr") - -from scitex.plt._subplots._export_as_csv_formatters._format_imshow import ( # noqa: E402 - _format_imshow, -) - - -class TestFormatImshow: - """Tests for _format_imshow function.""" - - def test_empty_tracked_dict_returns_empty_df(self): - """Empty tracked_dict should return empty DataFrame.""" - result = _format_imshow("test", {}, {}) - assert isinstance(result, pd.DataFrame) - assert result.empty - - def test_none_tracked_dict_returns_empty_df(self): - """None tracked_dict should return empty DataFrame.""" - result = _format_imshow("test", None, {}) - assert isinstance(result, pd.DataFrame) - assert result.empty - - def test_image_df_key_returns_directly(self): - """When image_df key exists, should return it directly.""" - image_df = pd.DataFrame({"row": [0, 1], "col": [0, 1], "value": [0.5, 0.8]}) - tracked_dict = {"image_df": image_df} - - result = _format_imshow("ax_00", tracked_dict, {}) - - pd.testing.assert_frame_equal(result, image_df) - - def test_args_2d_grayscale_image(self): - """2D grayscale image should be flattened with row/col indices.""" - img = np.array([[0.1, 0.2], [0.3, 0.4]]) - tracked_dict = {"args": (img,)} - - result = _format_imshow("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0_row" in result.columns - assert "r0c0_plot-0_col" in result.columns - assert "r0c0_plot-0_value" in result.columns - assert len(result) == 4 # 2x2 = 4 pixels - - # Check values are correct - expected_values = [0.1, 0.2, 0.3, 0.4] - np.testing.assert_array_almost_equal( - result["r0c0_plot-0_value"].values, expected_values - ) - - def test_args_3d_rgb_image(self): - """3D RGB image should have R, G, B columns.""" - img = np.zeros((2, 2, 3)) - img[:, :, 0] = [[255, 0], [0, 255]] # R - img[:, :, 1] = [[0, 255], [255, 0]] # G - img[:, :, 2] = [[128, 128], [128, 128]] # B - tracked_dict = {"args": (img,)} - - result = _format_imshow("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0_row" in result.columns - assert "r0c0_plot-0_col" in result.columns - assert "r0c0_plot-0_r" in result.columns - assert "r0c0_plot-0_g" in result.columns - assert "r0c0_plot-0_b" in result.columns - assert len(result) == 4 # 2x2 = 4 pixels - - def test_args_3d_rgba_image(self): - """3D RGBA image should have R, G, B, A columns.""" - img = np.zeros((2, 2, 4)) - img[:, :, 0] = 1.0 # R - img[:, :, 1] = 0.5 # G - img[:, :, 2] = 0.0 # B - img[:, :, 3] = 0.8 # A - tracked_dict = {"args": (img,)} - - result = _format_imshow("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0_r" in result.columns - assert "r0c0_plot-0_g" in result.columns - assert "r0c0_plot-0_b" in result.columns - assert "r0c0_plot-0_a" in result.columns - - def test_row_col_indices_correct(self): - """Row and column indices should be meshgrid-like.""" - img = np.array([[1, 2, 3], [4, 5, 6]]) # 2 rows, 3 cols - tracked_dict = {"args": (img,)} - - result = _format_imshow("ax_00", tracked_dict, {}) - - # Check row indices - expected_rows = [0, 0, 0, 1, 1, 1] - expected_cols = [0, 1, 2, 0, 1, 2] - np.testing.assert_array_equal(result["r0c0_plot-0_row"].values, expected_rows) - np.testing.assert_array_equal(result["r0c0_plot-0_col"].values, expected_cols) - - def test_id_prefix_applied_correctly(self): - """ID prefix should be correctly applied to all columns.""" - img = np.array([[1, 2], [3, 4]]) - tracked_dict = {"args": (img,)} - - result = _format_imshow("custom_prefix", tracked_dict, {}) - - for col in result.columns: - assert col.startswith("r0c0_plot-custom-prefix_") - - def test_large_image_handling(self): - """Should handle larger images without issues.""" - img = np.random.rand(100, 100) - tracked_dict = {"args": (img,)} - - result = _format_imshow("ax_00", tracked_dict, {}) - - assert len(result) == 10000 # 100x100 - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_imshow(id, tracked_dict, kwargs): -# """Format data from an imshow call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to imshow -# -# Returns: -# pd.DataFrame: Formatted data from imshow (flattened image with row, col indices) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Check for pre-formatted image_df (from plot_imshow wrapper) -# if tracked_dict.get("image_df") is not None: -# return tracked_dict.get("image_df") -# -# # Handle raw args from __getattr__ proxied calls -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) > 0: -# img = np.asarray(args[0]) -# -# # Handle 2D grayscale image -# if img.ndim == 2: -# rows, cols = img.shape -# row_indices, col_indices = np.meshgrid( -# range(rows), range(cols), indexing="ij" -# ) -# -# # Get column names from single source of truth -# col_row = get_csv_column_name("row", ax_row, ax_col, trace_id=trace_id) -# col_col = get_csv_column_name("col", ax_row, ax_col, trace_id=trace_id) -# col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# -# df = pd.DataFrame( -# { -# col_row: row_indices.flatten(), -# col_col: col_indices.flatten(), -# col_value: img.flatten(), -# } -# ) -# return df -# -# # Handle RGB/RGBA images (3D array) -# elif img.ndim == 3: -# rows, cols, channels = img.shape -# row_indices, col_indices = np.meshgrid( -# range(rows), range(cols), indexing="ij" -# ) -# -# # Get column names from single source of truth -# col_row = get_csv_column_name("row", ax_row, ax_col, trace_id=trace_id) -# col_col = get_csv_column_name("col", ax_row, ax_col, trace_id=trace_id) -# -# data = { -# col_row: row_indices.flatten(), -# col_col: col_indices.flatten(), -# } -# -# # Add channel data (R, G, B, A) -# channel_names = ["r", "g", "b", "a"][:channels] -# for c, name in enumerate(channel_names): -# col_channel = get_csv_column_name(name, ax_row, ax_col, trace_id=trace_id) -# data[col_channel] = img[:, :, c].flatten() -# -# return pd.DataFrame(data) -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow2d.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow2d.py deleted file mode 100644 index 0a59e436f..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_imshow2d.py +++ /dev/null @@ -1,64 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_imshow2d(id, tracked_dict, kwargs): -# """Format data from an imshow2d call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to imshow2d -# -# Returns: -# pd.DataFrame: Formatted data from imshow2d -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Extract data if available -# if len(args) >= 1 and isinstance(args[0], pd.DataFrame): -# df = args[0].copy() -# # Rename columns using the single source of truth -# renamed_cols = {} -# for col in df.columns: -# renamed_cols[col] = get_csv_column_name( -# f"imshow2d_{col}", ax_row, ax_col, trace_id=trace_id -# ) -# df = df.rename(columns=renamed_cols) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_imshow2d.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_matshow.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_matshow.py deleted file mode 100644 index 749c74b96..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_matshow.py +++ /dev/null @@ -1,66 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_matshow(id, tracked_dict, kwargs): -# """Format data from a matshow call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to matshow -# -# Returns: -# pd.DataFrame: Formatted data from matshow (flattened matrix with row, col indices) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) > 0: -# Z = np.asarray(args[0]) -# -# # Create row/col indices -# rows, cols = np.indices(Z.shape) -# -# df = pd.DataFrame( -# { -# get_csv_column_name("row", ax_row, ax_col, trace_id=trace_id): rows.flatten(), -# get_csv_column_name("col", ax_row, ax_col, trace_id=trace_id): cols.flatten(), -# get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id): Z.flatten(), -# } -# ) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_matshow.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_pcolormesh.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_pcolormesh.py deleted file mode 100644 index 7f2fb50f3..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_pcolormesh.py +++ /dev/null @@ -1,82 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pcolormesh.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-21 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pcolormesh.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_pcolormesh(id, tracked_dict, kwargs): -# """Format data from a pcolormesh call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to pcolormesh -# -# Returns: -# pd.DataFrame: Formatted data from pcolormesh (x, y, value columns) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", ()) -# -# if len(args) == 0: -# return pd.DataFrame() -# -# # pcolormesh can be called as: -# # pcolormesh(C) - just color values -# # pcolormesh(X, Y, C) - with coordinates -# if len(args) == 1: -# # Just C provided -# C = np.asarray(args[0]) -# rows, cols = C.shape -# Y, X = np.meshgrid(range(rows), range(cols), indexing="ij") -# elif len(args) >= 3: -# # X, Y, C provided -# X = np.asarray(args[0]) -# Y = np.asarray(args[1]) -# C = np.asarray(args[2]) -# else: -# return pd.DataFrame() -# -# # Get column names -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# -# # Flatten for CSV format -# df = pd.DataFrame({ -# col_x: X.flatten(), -# col_y: Y.flatten(), -# col_value: C.flatten(), -# }) -# -# return df -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pcolormesh.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_pie.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_pie.py deleted file mode 100644 index 613d23cf2..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_pie.py +++ /dev/null @@ -1,67 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_pie(id, tracked_dict, kwargs): -# """Format data from a pie chart call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to pie -# -# Returns: -# pd.DataFrame: Formatted data from pie chart -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) > 0: -# x = np.asarray(args[0]) -# -# # Get column names from single source of truth -# col_values = get_csv_column_name("values", ax_row, ax_col, trace_id=trace_id) -# data = {col_values: x} -# -# # Add labels if provided -# labels = kwargs.get("labels", None) -# if labels is not None: -# col_labels = get_csv_column_name("labels", ax_row, ax_col, trace_id=trace_id) -# data[col_labels] = labels -# -# df = pd.DataFrame(data) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_pie.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot.py deleted file mode 100755 index 434c3b41e..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot.py +++ /dev/null @@ -1,359 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-12-01 13:20:00 (ywatanabe)" -# File: tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot.py - -"""Tests for _format_plot CSV formatter.""" - -import numpy as np -import pandas as pd -import pytest - -pytest.importorskip("zarr") - -from scitex.plt._subplots._export_as_csv_formatters._format_plot import ( # noqa: E402 - _format_plot, -) - - -class TestFormatPlot: - """Tests for _format_plot function.""" - - def test_empty_tracked_dict_returns_empty_df(self): - """Empty tracked_dict should return empty DataFrame.""" - result = _format_plot("test", {}, {}) - assert isinstance(result, pd.DataFrame) - assert result.empty - - def test_none_tracked_dict_returns_empty_df(self): - """None tracked_dict should return empty DataFrame.""" - result = _format_plot("test", None, {}) - assert isinstance(result, pd.DataFrame) - assert result.empty - - def test_plot_df_key_returns_prefixed_df(self): - """When plot_df key exists, should return prefixed DataFrame.""" - plot_df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) - tracked_dict = {"plot_df": plot_df} - - result = _format_plot("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0_x" in result.columns - assert "r0c0_plot-0_y" in result.columns - assert list(result["r0c0_plot-0_x"]) == [1, 2, 3] - - def test_args_single_1d_array(self): - """Single 1D array arg should generate x from indices.""" - y = np.array([1.0, 2.0, 3.0, 4.0]) - tracked_dict = {"args": (y,)} - - result = _format_plot("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0_x" in result.columns - assert "r0c0_plot-0_y" in result.columns - assert list(result["r0c0_plot-0_x"]) == [0, 1, 2, 3] - assert list(result["r0c0_plot-0_y"]) == [1.0, 2.0, 3.0, 4.0] - - def test_args_single_2d_array(self): - """Single 2D array arg should extract x and y columns.""" - data = np.array([[0, 1], [1, 4], [2, 9]]) - tracked_dict = {"args": (data,)} - - result = _format_plot("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0_x" in result.columns - assert "r0c0_plot-0_y" in result.columns - assert list(result["r0c0_plot-0_x"]) == [0, 1, 2] - assert list(result["r0c0_plot-0_y"]) == [1, 4, 9] - - def test_args_two_1d_arrays(self): - """Two 1D array args should use first as x, second as y.""" - x = np.array([0.0, 1.0, 2.0]) - y = np.array([0.0, 1.0, 4.0]) - tracked_dict = {"args": (x, y)} - - result = _format_plot("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0_x" in result.columns - assert "r0c0_plot-0_y" in result.columns - np.testing.assert_array_equal(result["r0c0_plot-0_x"], x) - np.testing.assert_array_equal(result["r0c0_plot-0_y"], y) - - def test_args_x_and_2d_y(self): - """X and 2D Y arrays should create multiple y columns.""" - x = np.array([0.0, 1.0, 2.0]) - y = np.array([[1, 2], [3, 4], [5, 6]]) - tracked_dict = {"args": (x, y)} - - result = _format_plot("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0-0_x00" in result.columns - assert "r0c0_plot-0-0_y00" in result.columns - assert "r0c0_plot-0-1_x01" in result.columns - assert "r0c0_plot-0-1_y01" in result.columns - - def test_args_x_and_dataframe_y(self): - """X and DataFrame Y should handle column iteration with indexed columns.""" - x = np.array([0.0, 1.0, 2.0]) - y = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - tracked_dict = {"args": (x, y)} - - result = _format_plot("ax_00", tracked_dict, {}) - - assert "r0c0_plot-0-0_x00" in result.columns - assert "r0c0_plot-0-0_y00" in result.columns - assert "r0c0_plot-0-1_x01" in result.columns - assert "r0c0_plot-0-1_y01" in result.columns - - def test_id_prefix_applied_correctly(self): - """ID prefix should be correctly applied to all columns. - - When a non-standard ID is given, _parse_tracking_id defaults to - ax_row=0, ax_col=0, trace_index=0, so get_csv_column_name produces - columns starting with 'ax_00_plot_0_'. - """ - x = np.array([1.0, 2.0]) - y = np.array([3.0, 4.0]) - tracked_dict = {"args": (x, y)} - - result = _format_plot("custom_prefix", tracked_dict, {}) - - for col in result.columns: - assert col.startswith("r0c0_plot-custom-prefix_") - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-08 18:45:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py -# -# """CSV formatter for matplotlib plot() calls.""" -# -# from collections import OrderedDict -# from typing import Any, Dict, Optional -# -# import numpy as np -# import pandas as pd -# import xarray as xr -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# -# def _parse_tracking_id(id: str, record_index: int = 0) -> tuple: -# """Parse tracking ID to extract axes position and trace ID. -# -# Parameters -# ---------- -# id : str -# Tracking ID like "ax_00_plot_0", "ax_00_stim-box", "plot_0", -# or user-provided like "sine" -# record_index : int -# Index of this record in the history (fallback for trace_id) -# -# Returns -# ------- -# tuple -# (ax_row, ax_col, trace_id) -# trace_id is a string - either the user-provided ID (e.g., "sine") -# or the record_index as string (e.g., "0") -# -# Note -# ---- -# When user provides a custom ID like "sine", that ID is preserved in the -# column names for clarity and traceability. -# -# Examples -# -------- -# >>> _parse_tracking_id("ax_00_plot_0") -# (0, 0, 'plot_0') -# >>> _parse_tracking_id("ax_00_stim-box") -# (0, 0, 'stim-box') -# >>> _parse_tracking_id("ax_12_text_0") -# (1, 2, 'text_0') -# >>> _parse_tracking_id("ax_10_violin") -# (1, 0, 'violin') -# """ -# ax_row, ax_col = 0, 0 -# trace_id = str(record_index) # Default to record_index as string -# -# if id.startswith("ax_"): -# parts = id.split("_") -# if len(parts) >= 2: -# ax_pos = parts[1] -# if len(ax_pos) >= 2: -# try: -# ax_row = int(ax_pos[0]) -# ax_col = int(ax_pos[1]) -# except ValueError: -# pass -# # Extract trace ID from parts[2:] (everything after "ax_XX_") -# # e.g., "ax_00_stim-box" -> parts = ["ax", "00", "stim-box"] -> trace_id = "stim-box" -# # e.g., "ax_00_plot_0" -> parts = ["ax", "00", "plot", "0"] -> trace_id = "plot_0" -# # e.g., "ax_12_text_0" -> parts = ["ax", "12", "text", "0"] -> trace_id = "text_0" -# if len(parts) >= 3: -# trace_id = "_".join(parts[2:]) -# elif id.startswith("plot_"): -# # Extract everything after "plot_" as the trace_id -# trace_id = id[5:] if len(id) > 5 else str(record_index) -# else: -# # User-provided ID like "sine", "cosine" - use it directly -# trace_id = id -# -# return ax_row, ax_col, trace_id -# -# -# def _format_plot( -# id: str, -# tracked_dict: Optional[Dict[str, Any]], -# kwargs: Dict[str, Any], -# ) -> pd.DataFrame: -# """Format data from a plot() call for CSV export. -# -# Handles various input formats including: -# - Pre-formatted plot_df from scitex wrappers -# - Raw args from __getattr__ proxied matplotlib calls -# - Single array: plot(y) generates x from indices -# - Two arrays: plot(x, y) -# - 2D arrays: creates multiple x/y column pairs -# -# Parameters -# ---------- -# id : str -# Identifier prefix for the output columns (e.g., "ax_00_plot_0"). -# tracked_dict : dict or None -# Dictionary containing tracked data. May include: -# - 'plot_df': Pre-formatted DataFrame from wrapper -# - 'args': Raw positional arguments (x, y) from plot() -# kwargs : dict -# Keyword arguments passed to plot (currently unused). -# -# Returns -# ------- -# pd.DataFrame -# Formatted data with columns using single source of truth naming. -# Format: ax-row_0_ax-col_0_trace-id_sine_variable_x -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # For stx_line, we expect a 'plot_df' key -# if "plot_df" in tracked_dict: -# plot_df = tracked_dict["plot_df"] -# if isinstance(plot_df, pd.DataFrame): -# # Rename columns using single source of truth -# renamed = {} -# for col in plot_df.columns: -# if col == "plot_x": -# renamed[col] = get_csv_column_name( -# "x", ax_row, ax_col, trace_id=trace_id -# ) -# elif col == "plot_y": -# renamed[col] = get_csv_column_name( -# "y", ax_row, ax_col, trace_id=trace_id -# ) -# else: -# # For other columns, use simplified naming -# renamed[col] = get_csv_column_name( -# col, ax_row, ax_col, trace_id=trace_id -# ) -# return plot_df.rename(columns=renamed) -# -# # Handle raw args from __getattr__ proxied calls -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) > 0: -# # Get column names from single source of truth -# x_col = get_csv_column_name( -# "x", ax_row, ax_col, trace_id=trace_id -# ) -# y_col = get_csv_column_name( -# "y", ax_row, ax_col, trace_id=trace_id -# ) -# -# # Handle single argument: plot(y) or plot(data_2d) -# if len(args) == 1: -# args_value = args[0] -# -# # Convert to numpy for consistent handling -# if hasattr(args_value, "values"): # pandas Series/DataFrame -# args_value = args_value.values -# args_value = np.asarray(args_value) -# -# # 2D array: extract x and y columns -# if hasattr(args_value, "ndim") and args_value.ndim == 2: -# x, y = args_value[:, 0], args_value[:, 1] -# df = pd.DataFrame({x_col: x, y_col: y}) -# return df -# -# # 1D array: generate x from indices (common case: plot(y)) -# elif hasattr(args_value, "ndim") and args_value.ndim == 1: -# x = np.arange(len(args_value)) -# y = args_value -# df = pd.DataFrame({x_col: x, y_col: y}) -# return df -# -# # Handle two arguments: plot(x, y) -# elif len(args) >= 2: -# x_arg, y_arg = args[0], args[1] -# -# # Convert to numpy -# x = np.asarray(x_arg.values if hasattr(x_arg, "values") else x_arg) -# y = np.asarray(y_arg.values if hasattr(y_arg, "values") else y_arg) -# -# # Handle 2D y array (multiple lines) -# if hasattr(y, "ndim") and y.ndim == 2: -# out = OrderedDict() -# for ii in range(y.shape[1]): -# x_col_i = get_csv_column_name( -# f"x{ii:02d}", ax_row, ax_col, trace_id=f"{trace_id}-{ii}" -# ) -# y_col_i = get_csv_column_name( -# f"y{ii:02d}", ax_row, ax_col, trace_id=f"{trace_id}-{ii}" -# ) -# out[x_col_i] = x -# out[y_col_i] = y[:, ii] -# df = pd.DataFrame(out) -# return df -# -# # Handle DataFrame y -# if isinstance(y_arg, pd.DataFrame): -# result = {x_col: x} -# for ii, col in enumerate(y_arg.columns): -# y_col_i = get_csv_column_name( -# f"y{ii:02d}", ax_row, ax_col, trace_id=f"{trace_id}-{ii}" -# ) -# result[y_col_i] = np.array(y_arg[col]) -# df = pd.DataFrame(result) -# return df -# -# # Handle 1D arrays (most common case: plot(x, y)) -# if hasattr(y, "ndim") and y.ndim == 1: -# # Flatten x if needed -# x_flat = np.ravel(x) -# y_flat = np.ravel(y) -# df = pd.DataFrame({x_col: x_flat, y_col: y_flat}) -# return df -# -# # Fallback for list-like y -# df = pd.DataFrame({x_col: np.ravel(x), y_col: np.ravel(y)}) -# return df -# -# # Default empty DataFrame if we can't process the input -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_box.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_box.py deleted file mode 100644 index e2be55d83..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_box.py +++ /dev/null @@ -1,114 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py -# -# """CSV formatter for stx_box() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_box(id, tracked_dict, kwargs): -# """Format data from a stx_box call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_box -# -# Returns: -# pd.DataFrame: Formatted box plot data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # First try to get data directly from tracked_dict -# data = tracked_dict.get("data") -# -# # If no data key, get from args -# if data is None: -# args = tracked_dict.get("args", []) -# if len(args) >= 1: -# data = args[0] -# else: -# return pd.DataFrame() -# -# # If data is a simple array or list of values -# if isinstance(data, (np.ndarray, list)) and len(data) > 0: -# try: -# # Check if it's a simple list of values or a list of lists -# if isinstance(data[0], (int, float, np.number)): -# col_name = get_csv_column_name( -# "data", ax_row, ax_col, trace_id=trace_id -# ) -# return pd.DataFrame({col_name: data}) -# -# # If data is a list of arrays (multiple box plots) -# elif isinstance(data, (list, tuple)) and all( -# isinstance(x, (list, np.ndarray)) for x in data -# ): -# result = pd.DataFrame() -# for i, values in enumerate(data): -# try: -# col_name = get_csv_column_name( -# f"data-{i}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = pd.Series(values) -# except Exception: -# pass -# return result -# except (IndexError, TypeError): -# pass -# -# # If data is a dictionary -# elif isinstance(data, dict): -# result = pd.DataFrame() -# for label, values in data.items(): -# try: -# col_name = get_csv_column_name( -# f"data-{label}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = pd.Series(values) -# except Exception: -# pass -# return result -# -# # If data is a DataFrame -# elif isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# for col in data.columns: -# col_name = get_csv_column_name( -# f"data-{col}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = data[col] -# return result -# -# # Default case: return empty DataFrame if nothing could be processed -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_box.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_imshow.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_imshow.py deleted file mode 100644 index d9d76a442..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_imshow.py +++ /dev/null @@ -1,69 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_imshow.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-11-18 11:40:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_imshow.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_imshow(id, tracked_dict, kwargs): -# """Format data from a plot_imshow call. -# -# Args: -# id: Plot identifier -# tracked_dict: Dictionary containing tracked data with key "imshow_df" -# kwargs: Additional keyword arguments -# -# Returns: -# pd.DataFrame: Formatted image data for CSV export -# """ -# # Check for imshow_df in tracked_dict -# if tracked_dict.get("imshow_df") is not None: -# df = tracked_dict["imshow_df"] -# -# # Add id prefix to column names if id is provided -# if id is not None: -# # Parse tracking ID to extract axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Use standardized column naming for each column -# df = df.copy() -# renamed_cols = {} -# for col in df.columns: -# # Create column name like "plot_imshow_row" or "plot_imshow_col" -# renamed_cols[col] = get_csv_column_name( -# f"plot_imshow_{col}", ax_row, ax_col, trace_id=trace_id -# ) -# df.rename(columns=renamed_cols, inplace=True) -# -# return df -# -# # Fallback: return empty DataFrame -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_imshow.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_kde.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_kde.py deleted file mode 100644 index c7e74373f..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_kde.py +++ /dev/null @@ -1,75 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# from scitex.pd import force_df -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_kde(id, tracked_dict, kwargs): -# """Format data from a stx_kde call. -# -# Processes kernel density estimation plot data. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing 'x', 'kde', and 'n' keys -# kwargs (dict): Keyword arguments passed to stx_kde -# -# Returns: -# pd.DataFrame: Formatted KDE data -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# x = tracked_dict.get("x") -# kde = tracked_dict.get("kde") -# n = tracked_dict.get("n") -# -# if x is None or kde is None: -# return pd.DataFrame() -# -# # Parse tracking ID to extract axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Use standardized column naming -# x_col = get_csv_column_name("kde_x", ax_row, ax_col, trace_id=trace_id) -# density_col = get_csv_column_name("kde_density", ax_row, ax_col, trace_id=trace_id) -# -# df = pd.DataFrame({x_col: x, density_col: kde}) -# -# # Add sample count if available -# if n is not None: -# # If n is a scalar, create a list with the same length as x -# if not hasattr(n, "__len__"): -# n = [n] * len(x) -# n_col = get_csv_column_name("kde_n", ax_row, ax_col, trace_id=trace_id) -# df[n_col] = n -# -# return df - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_kde.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_scatter.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_scatter.py deleted file mode 100644 index 7a8376d95..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_plot_scatter.py +++ /dev/null @@ -1,60 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-10-03 02:47:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_scatter(id, tracked_dict, kwargs): -# """Format data from a plot_scatter call. -# -# The plot_scatter method stores data as: -# {"scatter_df": pd.DataFrame({"x": args[0], "y": args[1]})} -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Get the scatter_df from tracked_dict -# scatter_df = tracked_dict.get("scatter_df") -# -# if scatter_df is not None and isinstance(scatter_df, pd.DataFrame): -# # Parse tracking ID to extract axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Use standardized column naming -# x_col = get_csv_column_name("scatter_x", ax_row, ax_col, trace_id=trace_id) -# y_col = get_csv_column_name("scatter_y", ax_row, ax_col, trace_id=trace_id) -# -# # Rename columns to include the id -# return scatter_df.rename(columns={"x": x_col, "y": y_col}) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_quiver.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_quiver.py deleted file mode 100644 index 8158fd7f3..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_quiver.py +++ /dev/null @@ -1,77 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 12:20:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_quiver(id, tracked_dict, kwargs): -# """Format data from a quiver (vector field) call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to quiver -# -# Returns: -# pd.DataFrame: Formatted data from quiver (X, Y positions and U, V vectors) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple): -# # quiver can be called as: -# # quiver(U, V) - positions auto-generated -# # quiver(X, Y, U, V) - explicit positions -# if len(args) == 2: -# U = np.asarray(args[0]) -# V = np.asarray(args[1]) -# X, Y = np.meshgrid(np.arange(U.shape[1]), np.arange(U.shape[0])) -# elif len(args) >= 4: -# X = np.asarray(args[0]) -# Y = np.asarray(args[1]) -# U = np.asarray(args[2]) -# V = np.asarray(args[3]) -# else: -# return pd.DataFrame() -# -# df = pd.DataFrame( -# { -# get_csv_column_name("quiver-x", ax_row, ax_col, trace_id=trace_id): X.flatten(), -# get_csv_column_name("quiver-y", ax_row, ax_col, trace_id=trace_id): Y.flatten(), -# get_csv_column_name("quiver-u", ax_row, ax_col, trace_id=trace_id): U.flatten(), -# get_csv_column_name("quiver-v", ax_row, ax_col, trace_id=trace_id): V.flatten(), -# } -# ) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_quiver.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_scatter.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_scatter.py deleted file mode 100644 index 583b1e5d3..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_scatter(id, tracked_dict, kwargs): -# """Format data from a scatter call (matplotlib ax.scatter or seaborn scatter). -# -# Note: For plot_scatter (wrapper method), use _format_plot_scatter instead. -# This formatter expects data in args format: tracked_dict['args'] = (x, y). -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Extract x and y data if available -# if len(args) >= 2: -# x, y = args[0], args[1] -# # Use structured column naming: ax-row-{row}-col-{col}_trace-id-{id}_variable-{var} -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# df = pd.DataFrame({col_x: x, col_y: y}) -# return df -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_scatter.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_barplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_barplot.py deleted file mode 100644 index 98fe04b81..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_barplot.py +++ /dev/null @@ -1,82 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py -# -# """CSV formatter for sns.barplot() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_barplot(id, tracked_dict, kwargs): -# """Format data from a sns_barplot call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to sns_barplot -# -# Returns: -# pd.DataFrame: Formatted data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # If 'data' key is in tracked_dict, use it -# if "data" in tracked_dict: -# df = tracked_dict["data"] -# if isinstance(df, pd.DataFrame): -# result = pd.DataFrame() -# for col in df.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = df[col] -# return result -# -# # Legacy handling for args -# if "args" in tracked_dict: -# df = tracked_dict["args"] -# if isinstance(df, pd.DataFrame): -# try: -# processed_df = pd.DataFrame( -# pd.Series(np.array(df).diagonal(), index=df.columns) -# ).T -# result = pd.DataFrame() -# for col in processed_df.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = processed_df[col] -# return result -# except (ValueError, TypeError, IndexError): -# result = pd.DataFrame() -# for col in df.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = df[col] -# return result -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_barplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_boxplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_boxplot.py deleted file mode 100644 index 32e82288a..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_boxplot.py +++ /dev/null @@ -1,128 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py -# -# """CSV formatter for sns.boxplot() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_boxplot(id, tracked_dict, kwargs): -# """Format data from a sns_boxplot call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to sns_boxplot -# -# Returns: -# pd.DataFrame: Formatted boxplot data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict: -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # If tracked_dict is a dictionary, try to extract the data from it -# if isinstance(tracked_dict, dict): -# # First try to get 'data' key which is used in seaborn functions -# if "data" in tracked_dict: -# data = tracked_dict["data"] -# if isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# for col in data.columns: -# col_name = get_csv_column_name( -# f"data-{col}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = data[col] -# return result -# -# # If no 'data' key, try to get data from args -# args = tracked_dict.get("args", []) -# if len(args) > 0: -# data = args[0] -# if isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# for col in data.columns: -# col_name = get_csv_column_name( -# f"data-{col}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = data[col] -# return result -# -# # Handle list or array data -# elif isinstance(data, (list, np.ndarray)): -# try: -# if all(isinstance(item, (list, np.ndarray)) for item in data): -# result = pd.DataFrame() -# for i, group_data in enumerate(data): -# col_name = get_csv_column_name( -# f"data-{i}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = pd.Series(group_data) -# return result -# else: -# col_name = get_csv_column_name( -# "data", ax_row, ax_col, trace_id=trace_id -# ) -# return pd.DataFrame({col_name: data}) -# except Exception: -# pass -# -# # If tracked_dict is a DataFrame already, use it directly -# elif isinstance(tracked_dict, pd.DataFrame): -# result = pd.DataFrame() -# for col in tracked_dict.columns: -# col_name = get_csv_column_name( -# f"data-{col}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = tracked_dict[col] -# return result -# -# # If tracked_dict is list-like, try to convert it to a DataFrame -# elif hasattr(tracked_dict, "__iter__") and not isinstance(tracked_dict, str): -# try: -# if all(isinstance(item, (list, np.ndarray)) for item in tracked_dict): -# result = pd.DataFrame() -# for i, group_data in enumerate(tracked_dict): -# col_name = get_csv_column_name( -# f"data-{i}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = pd.Series(group_data) -# return result -# else: -# col_name = get_csv_column_name( -# "data", ax_row, ax_col, trace_id=trace_id -# ) -# return pd.DataFrame({col_name: tracked_dict}) -# except Exception: -# pass -# -# # Return empty DataFrame if we couldn't extract useful data -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_boxplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_heatmap.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_heatmap.py deleted file mode 100644 index 33336b201..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_heatmap.py +++ /dev/null @@ -1,96 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py -# -# """CSV formatter for sns.heatmap() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_heatmap(id, tracked_dict, kwargs): -# """Format data from a sns_heatmap call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to sns_heatmap -# -# Returns: -# pd.DataFrame: Formatted data with standard column names -# """ -# # Check if tracked_dict is empty -# if not tracked_dict: -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# def _format_dataframe(df): -# result = pd.DataFrame() -# for col in df.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = df[col] -# return result -# -# def _format_array(arr): -# rows, cols = arr.shape if len(arr.shape) >= 2 else (arr.shape[0], 1) -# result = pd.DataFrame() -# for i in range(cols): -# col_data = arr[:, i] if len(arr.shape) >= 2 else arr -# col_name = get_csv_column_name(f"data-col-{i}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = col_data -# return result -# -# # If tracked_dict is a dictionary -# if isinstance(tracked_dict, dict): -# if "data" in tracked_dict: -# data = tracked_dict["data"] -# -# if isinstance(data, pd.DataFrame): -# return _format_dataframe(data) -# elif isinstance(data, np.ndarray): -# return _format_array(data) -# -# # Legacy handling for args -# args = tracked_dict.get("args", []) -# if len(args) > 0: -# data = args[0] -# -# if isinstance(data, pd.DataFrame): -# return _format_dataframe(data) -# elif isinstance(data, np.ndarray): -# return _format_array(data) -# -# # If tracked_dict is a DataFrame directly -# elif isinstance(tracked_dict, pd.DataFrame): -# return _format_dataframe(tracked_dict) -# -# # If tracked_dict is a numpy array directly -# elif isinstance(tracked_dict, np.ndarray): -# return _format_array(tracked_dict) -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_heatmap.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_histplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_histplot.py deleted file mode 100644 index 84d18bdea..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_histplot.py +++ /dev/null @@ -1,111 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py -# -# """CSV formatter for sns.histplot() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_histplot(id, tracked_dict, kwargs): -# """Format data from a sns_histplot call as a bar plot representation. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to sns_histplot -# -# Returns: -# pd.DataFrame: Formatted data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# columns = {} -# -# # Check if histogram result is available in tracked_dict -# hist_result = tracked_dict.get("hist_result", None) -# -# # If we have histogram result (counts and bin edges) -# if hist_result is not None: -# counts, bin_edges = hist_result -# -# # Calculate bin centers for bar plot representation -# bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) -# bin_widths = bin_edges[1:] - bin_edges[:-1] -# -# # Add bin information with standard naming -# columns[get_csv_column_name("bin-centers", ax_row, ax_col, trace_id=trace_id)] = bin_centers -# columns[get_csv_column_name("bin-counts", ax_row, ax_col, trace_id=trace_id)] = counts -# columns[get_csv_column_name("bin-widths", ax_row, ax_col, trace_id=trace_id)] = bin_widths -# columns[get_csv_column_name("bin-edges-left", ax_row, ax_col, trace_id=trace_id)] = bin_edges[:-1] -# columns[get_csv_column_name("bin-edges-right", ax_row, ax_col, trace_id=trace_id)] = bin_edges[1:] -# -# # Get raw data if available -# if "data" in tracked_dict: -# df = tracked_dict["data"] -# if isinstance(df, pd.DataFrame): -# x_col = kwargs.get("x") -# if x_col and x_col in df.columns: -# columns[get_csv_column_name("raw-data", ax_row, ax_col, trace_id=trace_id)] = df[x_col].values -# -# # Legacy handling for args -# elif "args" in tracked_dict: -# args = tracked_dict["args"] -# if len(args) >= 1: -# x = args[0] -# if hasattr(x, "values"): -# columns[get_csv_column_name("raw-data", ax_row, ax_col, trace_id=trace_id)] = x.values -# else: -# columns[get_csv_column_name("raw-data", ax_row, ax_col, trace_id=trace_id)] = x -# -# # If we have data to return -# if columns: -# # Ensure all arrays are the same length by padding with NaN -# max_length = max( -# len(value) for value in columns.values() if hasattr(value, "__len__") -# ) -# for key, value in list(columns.items()): -# if hasattr(value, "__len__") and len(value) < max_length: -# if isinstance(value, np.ndarray): -# columns[key] = np.pad( -# value, -# (0, max_length - len(value)), -# mode="constant", -# constant_values=np.nan, -# ) -# else: -# padded = list(value) + [np.nan] * (max_length - len(value)) -# columns[key] = np.array(padded) -# -# return pd.DataFrame(columns) -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_histplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_jointplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_jointplot.py deleted file mode 100644 index 823c9a908..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_jointplot.py +++ /dev/null @@ -1,92 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_jointplot(id, tracked_dict, kwargs): -# """Format data from a sns_jointplot call.""" -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to extract axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Joint distribution plot in seaborn -# if len(args) >= 1: -# data = args[0] -# -# # Get x and y variables from kwargs -# x_var = kwargs.get("x") -# y_var = kwargs.get("y") -# -# # Handle DataFrame input -# if isinstance(data, pd.DataFrame) and x_var and y_var: -# # Extract the relevant columns -# x_data = data[x_var] -# y_data = data[y_var] -# -# result = pd.DataFrame( -# { -# get_csv_column_name(f"joint_{x_var}", ax_row, ax_col, trace_id=trace_id): x_data, -# get_csv_column_name(f"joint_{y_var}", ax_row, ax_col, trace_id=trace_id): y_data, -# } -# ) -# return result -# -# # Handle direct x, y data arrays -# elif isinstance(data, pd.DataFrame): -# # If no x, y specified, return the whole dataframe -# result = data.copy() -# if id is not None: -# result.columns = [ -# get_csv_column_name(f"joint_{col}", ax_row, ax_col, trace_id=trace_id) -# for col in result.columns -# ] -# return result -# -# # Handle numpy arrays directly -# elif ( -# all(arg in args for arg in range(2)) -# and isinstance(args[0], (np.ndarray, list)) -# and isinstance(args[1], (np.ndarray, list)) -# ): -# x_data, y_data = args[0], args[1] -# return pd.DataFrame({ -# get_csv_column_name("joint_x", ax_row, ax_col, trace_id=trace_id): x_data, -# get_csv_column_name("joint_y", ax_row, ax_col, trace_id=trace_id): y_data, -# }) -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_jointplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_kdeplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_kdeplot.py deleted file mode 100644 index 901adb428..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_kdeplot.py +++ /dev/null @@ -1,107 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py -# -# """CSV formatter for sns.kdeplot() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_kdeplot(id, tracked_dict, kwargs): -# """Format data from a sns_kdeplot call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to sns_kdeplot -# -# Returns: -# pd.DataFrame: Formatted data with standard column names -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get args from tracked_dict -# args = tracked_dict.get("args", []) -# x_var = kwargs.get("x") if kwargs else None -# y_var = kwargs.get("y") if kwargs else None -# -# if len(args) >= 1: -# data = args[0] -# -# # Handle DataFrame input with x, y variables -# if isinstance(data, pd.DataFrame) and x_var: -# if y_var and y_var in data.columns: # Bivariate KDE -# return pd.DataFrame({ -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): data[x_var], -# get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): data[y_var], -# }) -# elif x_var in data.columns: # Univariate KDE -# return pd.DataFrame({ -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): data[x_var] -# }) -# -# # Handle direct data array input -# elif isinstance(data, (np.ndarray, list)): -# y_data = ( -# args[1] -# if len(args) > 1 and isinstance(args[1], (np.ndarray, list)) -# else None -# ) -# -# if y_data is not None: # Bivariate KDE -# return pd.DataFrame({ -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): data, -# get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): y_data, -# }) -# else: # Univariate KDE -# return pd.DataFrame({ -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): data -# }) -# -# # Handle DataFrame input without x, y specified -# elif isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# for col in data.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = data[col] -# return result -# -# # Also check for 'data' key directly -# if "data" in tracked_dict: -# data = tracked_dict["data"] -# if isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# for col in data.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = data[col] -# return result -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_kdeplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_lineplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_lineplot.py deleted file mode 100644 index eb08c6062..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_lineplot.py +++ /dev/null @@ -1,80 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_lineplot(id, tracked_dict, kwargs): -# """Format data from a sns_lineplot call.""" -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get data from tracked_dict - can be in "data" (from _sns_base_xyhue) or "args" -# data = tracked_dict.get("data") -# args = tracked_dict.get("args", []) -# -# # If data is None, try to get it from args -# if data is None and len(args) >= 1: -# data = args[0] -# -# x_var = kwargs.get("x") -# y_var = kwargs.get("y") -# -# # Handle DataFrame input with x, y variables -# if isinstance(data, pd.DataFrame): -# # If data has been pre-processed by _sns_prepare_xyhue, it may be pivoted -# # Just export all columns with proper naming -# if data.empty: -# return pd.DataFrame() -# -# result = {} -# for col in data.columns: -# col_name = str(col) if not isinstance(col, str) else col -# result[get_csv_column_name(col_name, ax_row, ax_col, trace_id=trace_id)] = data[col].values -# return pd.DataFrame(result) -# -# # Handle direct x, y data arrays from args -# elif ( -# len(args) > 1 -# and isinstance(args[0], (np.ndarray, list)) -# and isinstance(args[1], (np.ndarray, list)) -# ): -# x_data, y_data = args[0], args[1] -# return pd.DataFrame({ -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): x_data, -# get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): y_data -# }) -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_lineplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_pairplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_pairplot.py deleted file mode 100644 index 905e155da..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_pairplot.py +++ /dev/null @@ -1,73 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_pairplot(id, tracked_dict, kwargs): -# """Format data from a sns_pairplot call.""" -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to extract axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Grid of plots showing pairwise relationships -# if len(args) >= 1: -# data = args[0] -# -# # Handle DataFrame input -# if isinstance(data, pd.DataFrame): -# # For pairplot, just return the full DataFrame since it uses all variables -# result = data.copy() -# if id is not None: -# result.columns = [ -# get_csv_column_name(f"pair_{col}", ax_row, ax_col, trace_id=trace_id) -# for col in result.columns -# ] -# -# # Add vars or hue columns if specified -# vars_list = kwargs.get("vars") -# if vars_list and all(var in data.columns for var in vars_list): -# # Keep only the specified columns -# result = pd.DataFrame( -# { -# get_csv_column_name(f"pair_{col}", ax_row, ax_col, trace_id=trace_id): data[col] -# for col in vars_list -# } -# ) -# -# return result -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_pairplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_scatterplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_scatterplot.py deleted file mode 100644 index 19ba32224..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_scatterplot.py +++ /dev/null @@ -1,88 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py -# -# """CSV formatter for sns.scatterplot() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_scatterplot(id, tracked_dict, kwargs=None): -# """Format data from a sns_scatterplot call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Tracked data dictionary -# kwargs (dict): Keyword arguments from the record tuple -# -# Returns: -# pd.DataFrame: Formatted data with standard column names -# """ -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Look for the DataFrame in the kwargs dictionary if provided -# if kwargs and isinstance(kwargs, dict) and "data" in kwargs: -# data = kwargs["data"] -# if isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# -# # If x and y variables are specified in kwargs, use them -# x_var = kwargs.get("x") -# y_var = kwargs.get("y") -# -# if x_var and y_var and x_var in data.columns and y_var in data.columns: -# result[get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id)] = data[x_var] -# result[get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id)] = data[y_var] -# -# # Also extract hue, size, style if specified -# for extra_var in ["hue", "size", "style"]: -# var_name = kwargs.get(extra_var) -# if var_name and var_name in data.columns: -# result[get_csv_column_name(extra_var, ax_row, ax_col, trace_id=trace_id)] = data[var_name] -# -# return result -# else: -# # If columns aren't specified, include all columns -# for col in data.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = data[col] -# return result -# -# # Alternative: try to find a DataFrame in tracked_dict -# if tracked_dict and isinstance(tracked_dict, dict): -# if "data" in tracked_dict and isinstance(tracked_dict["data"], pd.DataFrame): -# data = tracked_dict["data"] -# result = pd.DataFrame() -# -# for col in data.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = data[col] -# -# return result -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_scatterplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_stripplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_stripplot.py deleted file mode 100644 index fc3231d90..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_stripplot.py +++ /dev/null @@ -1,96 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py -# -# """CSV formatter for sns.stripplot() calls - uses standard column naming.""" -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_stripplot(id, tracked_dict, kwargs): -# """Format data from a sns_stripplot call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to sns_stripplot -# -# Returns: -# pd.DataFrame: Formatted data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # If 'data' key is in tracked_dict, use it -# if "data" in tracked_dict: -# data = tracked_dict["data"] -# -# if isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# -# # Extract variables from kwargs -# x_var = kwargs.get("x") if kwargs else None -# y_var = kwargs.get("y") if kwargs else None -# hue_var = kwargs.get("hue") if kwargs else None -# -# # Add x variable if specified -# if x_var and x_var in data.columns: -# result[get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id)] = data[x_var] -# -# # Add y variable if specified -# if y_var and y_var in data.columns: -# result[get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id)] = data[y_var] -# -# # Add grouping variable if present -# if hue_var and hue_var in data.columns: -# result[get_csv_column_name("hue", ax_row, ax_col, trace_id=trace_id)] = data[hue_var] -# -# # If we've added columns, return the result -# if not result.empty: -# return result -# -# # If no columns were explicitly specified, return all columns -# for col in data.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = data[col] -# return result -# -# # Legacy handling for args -# if "args" in tracked_dict and len(tracked_dict["args"]) >= 1: -# data = tracked_dict["args"][0] -# -# if isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# for col in data.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = data[col] -# return result -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_stripplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_swarmplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_swarmplot.py deleted file mode 100644 index a844a0b5e..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_swarmplot.py +++ /dev/null @@ -1,96 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py -# -# """CSV formatter for sns.swarmplot() calls - uses standard column naming.""" -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_swarmplot(id, tracked_dict, kwargs): -# """Format data from a sns_swarmplot call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to sns_swarmplot -# -# Returns: -# pd.DataFrame: Formatted data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # If 'data' key is in tracked_dict, use it -# if "data" in tracked_dict: -# data = tracked_dict["data"] -# -# if isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# -# # Extract variables from kwargs -# x_var = kwargs.get("x") if kwargs else None -# y_var = kwargs.get("y") if kwargs else None -# hue_var = kwargs.get("hue") if kwargs else None -# -# # Add x variable if specified -# if x_var and x_var in data.columns: -# result[get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id)] = data[x_var] -# -# # Add y variable if specified -# if y_var and y_var in data.columns: -# result[get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id)] = data[y_var] -# -# # Add grouping variable if present -# if hue_var and hue_var in data.columns: -# result[get_csv_column_name("hue", ax_row, ax_col, trace_id=trace_id)] = data[hue_var] -# -# # If we've added columns, return the result -# if not result.empty: -# return result -# -# # If no columns were explicitly specified, return all columns -# for col in data.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = data[col] -# return result -# -# # Legacy handling for args -# if "args" in tracked_dict and len(tracked_dict["args"]) >= 1: -# data = tracked_dict["args"][0] -# -# if isinstance(data, pd.DataFrame): -# result = pd.DataFrame() -# for col in data.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = data[col] -# return result -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_swarmplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_violinplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_violinplot.py deleted file mode 100644 index 1abf990b7..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_sns_violinplot.py +++ /dev/null @@ -1,150 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:30:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py -# -# """CSV formatter for sns.violinplot() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_sns_violinplot(id, tracked_dict, kwargs): -# """Format data from a sns_violinplot call. -# -# Uses standard column naming: ax-row-{r}-col-{c}_trace-id-{id}_variable-{var} -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to sns_violinplot -# -# Returns: -# pd.DataFrame: Formatted data with standard column names -# """ -# # Check if tracked_dict is empty -# if not tracked_dict: -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# def _format_dataframe(df): -# result = pd.DataFrame() -# for col in df.columns: -# col_name = get_csv_column_name(f"data-{col}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = df[col] -# return result -# -# def _format_list_of_arrays(data): -# result = pd.DataFrame() -# for i, group_data in enumerate(data): -# col_name = get_csv_column_name(f"data-{i}", ax_row, ax_col, trace_id=trace_id) -# result[col_name] = pd.Series(group_data) -# return result -# -# # If tracked_dict is a dictionary -# if isinstance(tracked_dict, dict): -# if "data" in tracked_dict: -# data = tracked_dict["data"] -# -# if isinstance(data, pd.DataFrame): -# try: -# return _format_dataframe(data) -# except Exception: -# try: -# x_var = kwargs.get("x") if kwargs else None -# y_var = kwargs.get("y") if kwargs else None -# -# if x_var and y_var and x_var in data.columns and y_var in data.columns: -# return pd.DataFrame({ -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): data[x_var], -# get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): data[y_var], -# }) -# elif len(data.columns) > 0: -# first_col = data.columns[0] -# return pd.DataFrame({ -# get_csv_column_name("data", ax_row, ax_col, trace_id=trace_id): data[first_col] -# }) -# except Exception: -# return pd.DataFrame() -# -# elif isinstance(data, (list, np.ndarray)): -# try: -# if isinstance(data, list) and len(data) > 0 and all( -# isinstance(item, (list, np.ndarray)) for item in data -# ): -# return _format_list_of_arrays(data) -# else: -# return pd.DataFrame({ -# get_csv_column_name("data", ax_row, ax_col, trace_id=trace_id): data -# }) -# except Exception: -# return pd.DataFrame() -# -# # Legacy handling for args -# args = tracked_dict.get("args", []) -# if len(args) > 0: -# data = args[0] -# -# if isinstance(data, pd.DataFrame): -# return _format_dataframe(data) -# -# elif isinstance(data, (list, np.ndarray)): -# try: -# if all(isinstance(item, (list, np.ndarray)) for item in data): -# return _format_list_of_arrays(data) -# else: -# return pd.DataFrame({ -# get_csv_column_name("data", ax_row, ax_col, trace_id=trace_id): data -# }) -# except Exception: -# return pd.DataFrame() -# -# # If tracked_dict is a DataFrame directly -# elif isinstance(tracked_dict, pd.DataFrame): -# try: -# return _format_dataframe(tracked_dict) -# except Exception: -# try: -# if len(tracked_dict.columns) > 0: -# first_col = tracked_dict.columns[0] -# return pd.DataFrame({ -# get_csv_column_name("data", ax_row, ax_col, trace_id=trace_id): tracked_dict[first_col] -# }) -# except Exception: -# return pd.DataFrame() -# -# # If tracked_dict is a list or numpy array directly -# elif isinstance(tracked_dict, (list, np.ndarray)): -# try: -# if all(isinstance(item, (list, np.ndarray)) for item in tracked_dict): -# return _format_list_of_arrays(tracked_dict) -# else: -# return pd.DataFrame({ -# get_csv_column_name("data", ax_row, ax_col, trace_id=trace_id): tracked_dict -# }) -# except Exception: -# return pd.DataFrame() -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_sns_violinplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stackplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stackplot.py deleted file mode 100644 index 942080b57..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stackplot.py +++ /dev/null @@ -1,78 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stackplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-21 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stackplot.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_stackplot(id, tracked_dict, kwargs): -# """Format data from a stackplot call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stackplot -# -# Returns: -# pd.DataFrame: Formatted data from stackplot (x and multiple y columns) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", ()) -# -# # stackplot(x, y1, y2, y3, ...) or stackplot(x, [y1, y2, y3], ...) -# if len(args) < 2: -# return pd.DataFrame() -# -# x = np.asarray(args[0]) -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# data = {col_x: x} -# -# # Get labels from kwargs if available -# labels = kwargs.get("labels", []) -# -# # Handle remaining args as y arrays -# y_arrays = args[1:] -# -# # If first y arg is a 2D array, treat rows as separate series -# if len(y_arrays) == 1 and hasattr(y_arrays[0], "ndim"): -# y_data = np.asarray(y_arrays[0]) -# if y_data.ndim == 2: -# y_arrays = [y_data[i] for i in range(y_data.shape[0])] -# -# for i, y in enumerate(y_arrays): -# y = np.asarray(y) -# # Use label if available, otherwise use index -# label = labels[i] if i < len(labels) else f"y{i:02d}" -# col_y = get_csv_column_name(label, ax_row, ax_col, trace_id=trace_id) -# data[col_y] = y -# -# return pd.DataFrame(data) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stackplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stem.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stem.py deleted file mode 100644 index c8ffe561b..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stem.py +++ /dev/null @@ -1,67 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 12:20:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_stem(id, tracked_dict, kwargs): -# """Format data from a stem plot call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stem -# -# Returns: -# pd.DataFrame: Formatted data from stem plot -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) > 0: -# if len(args) == 1: -# y = np.asarray(args[0]) -# x = np.arange(len(y)) -# elif len(args) >= 2: -# x = np.asarray(args[0]) -# y = np.asarray(args[1]) -# else: -# return pd.DataFrame() -# -# # Use structured column naming: ax-row-{row}-col-{col}_trace-id-{id}_variable-{var} -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# df = pd.DataFrame({col_x: x, col_y: y}) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stem.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_step.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_step.py deleted file mode 100644 index 513c667fe..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_step.py +++ /dev/null @@ -1,67 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 12:20:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_step(id, tracked_dict, kwargs): -# """Format data from a step plot call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to step -# -# Returns: -# pd.DataFrame: Formatted data from step plot -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) > 0: -# if len(args) == 1: -# y = np.asarray(args[0]) -# x = np.arange(len(y)) -# elif len(args) >= 2: -# x = np.asarray(args[0]) -# y = np.asarray(args[1]) -# else: -# return pd.DataFrame() -# -# # Use structured column naming: ax-row-{row}-col-{col}_trace-id-{id}_variable-{var} -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# df = pd.DataFrame({col_x: x, col_y: y}) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_step.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_streamplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_streamplot.py deleted file mode 100644 index a65456968..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_streamplot.py +++ /dev/null @@ -1,72 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_streamplot(id, tracked_dict, kwargs): -# """Format data from a streamplot call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to streamplot -# -# Returns: -# pd.DataFrame: Formatted data from streamplot (X, Y positions and U, V vectors) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# if "args" in tracked_dict: -# args = tracked_dict["args"] -# if isinstance(args, tuple) and len(args) >= 4: -# # streamplot(X, Y, U, V) - X, Y are 1D, U, V are 2D -# X = np.asarray(args[0]) -# Y = np.asarray(args[1]) -# U = np.asarray(args[2]) -# V = np.asarray(args[3]) -# -# # Create meshgrid if X, Y are 1D -# if X.ndim == 1 and Y.ndim == 1: -# X, Y = np.meshgrid(X, Y) -# -# df = pd.DataFrame( -# { -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): X.flatten(), -# get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): Y.flatten(), -# get_csv_column_name("u", ax_row, ax_col, trace_id=trace_id): U.flatten(), -# get_csv_column_name("v", ax_row, ax_col, trace_id=trace_id): V.flatten(), -# } -# ) -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_streamplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_bar.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_bar.py deleted file mode 100644 index 05752cfc4..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_bar.py +++ /dev/null @@ -1,100 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_bar.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# """CSV formatter for stx_bar() calls.""" -# -# import pandas as pd -# import numpy as np -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_stx_bar(id, tracked_dict, kwargs): -# """Format data from stx_bar call for CSV export. -# -# Parameters -# ---------- -# id : str -# Tracking identifier -# tracked_dict : dict -# Dictionary containing tracked data with 'bar_df' key -# kwargs : dict -# Additional keyword arguments (may contain yerr) -# -# Returns -# ------- -# pd.DataFrame -# Formatted bar data with standardized column names -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get bar_df from tracked data -# bar_df = tracked_dict.get("bar_df") -# if bar_df is not None and isinstance(bar_df, pd.DataFrame): -# result = bar_df.copy() -# renamed = {} -# # Map 'x' and 'height' to standardized column names -# for col in result.columns: -# if col == "x": -# renamed[col] = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# elif col == "height": -# renamed[col] = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# else: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# -# result = result.rename(columns=renamed) -# -# # Add yerr if present in kwargs -# yerr = kwargs.get("yerr") if kwargs else None -# if yerr is not None: -# try: -# yerr_array = np.asarray(yerr) -# if len(yerr_array) == len(result): -# col_yerr = get_csv_column_name("yerr", ax_row, ax_col, trace_id=trace_id) -# result[col_yerr] = yerr_array -# except (TypeError, ValueError): -# pass -# -# return result -# -# # Fallback to args if bar_df not found -# args = tracked_dict.get("args", []) -# if len(args) >= 2: -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# data = {col_x: args[0], col_y: args[1]} -# -# # Add yerr if present -# yerr = kwargs.get("yerr") if kwargs else None -# if yerr is not None: -# try: -# yerr_array = np.asarray(yerr) -# col_yerr = get_csv_column_name("yerr", ax_row, ax_col, trace_id=trace_id) -# data[col_yerr] = yerr_array -# except (TypeError, ValueError): -# pass -# -# return pd.DataFrame(data) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_bar.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_barh.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_barh.py deleted file mode 100644 index 5b18230d0..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_barh.py +++ /dev/null @@ -1,101 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_barh.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# """CSV formatter for stx_barh() calls.""" -# -# import pandas as pd -# import numpy as np -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_stx_barh(id, tracked_dict, kwargs): -# """Format data from stx_barh call for CSV export. -# -# Parameters -# ---------- -# id : str -# Tracking identifier -# tracked_dict : dict -# Dictionary containing tracked data with 'barh_df' key -# kwargs : dict -# Additional keyword arguments (may contain xerr) -# -# Returns -# ------- -# pd.DataFrame -# Formatted barh data with standardized column names -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get barh_df from tracked data -# barh_df = tracked_dict.get("barh_df") -# if barh_df is not None and isinstance(barh_df, pd.DataFrame): -# result = barh_df.copy() -# renamed = {} -# # Map 'y' and 'width' to standardized column names -# for col in result.columns: -# if col == "y": -# renamed[col] = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# elif col == "width": -# renamed[col] = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# else: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# -# result = result.rename(columns=renamed) -# -# # Add xerr if present in kwargs -# xerr = kwargs.get("xerr") if kwargs else None -# if xerr is not None: -# try: -# xerr_array = np.asarray(xerr) -# if len(xerr_array) == len(result): -# col_xerr = get_csv_column_name("xerr", ax_row, ax_col, trace_id=trace_id) -# result[col_xerr] = xerr_array -# except (TypeError, ValueError): -# pass -# -# return result -# -# # Fallback to args if barh_df not found -# args = tracked_dict.get("args", []) -# if len(args) >= 2: -# # Note: in barh, first arg is y positions, second is widths (x values) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# data = {col_y: args[0], col_x: args[1]} -# -# # Add xerr if present -# xerr = kwargs.get("xerr") if kwargs else None -# if xerr is not None: -# try: -# xerr_array = np.asarray(xerr) -# col_xerr = get_csv_column_name("xerr", ax_row, ax_col, trace_id=trace_id) -# data[col_xerr] = xerr_array -# except (TypeError, ValueError): -# pass -# -# return pd.DataFrame(data) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_barh.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_conf_mat.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_conf_mat.py deleted file mode 100644 index 5cadd2ddb..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_conf_mat.py +++ /dev/null @@ -1,91 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_conf_mat.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_conf_mat.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_conf_mat(id, tracked_dict, kwargs): -# """Format data from a stx_conf_mat call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_conf_mat -# -# Returns: -# pd.DataFrame: Formatted confusion matrix data -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Extract confusion matrix if available in args -# if len(args) >= 1 and isinstance(args[0], (np.ndarray, list)): -# conf_mat = np.array(args[0]) -# -# # Convert to DataFrame -# if conf_mat.ndim == 2: -# # Create column and index names -# n_classes = conf_mat.shape[0] -# columns = [f"Predicted_{i}" for i in range(n_classes)] -# index = [f"True_{i}" for i in range(n_classes)] -# -# # Create DataFrame with proper labels -# df = pd.DataFrame(conf_mat, columns=columns, index=index) -# -# # Reset index to make it a regular column -# df = df.reset_index().rename(columns={"index": "True_Class"}) -# -# # Add prefix to all columns using single source of truth -# df.columns = [ -# get_csv_column_name(f"conf-mat-{col}", ax_row, ax_col, trace_id=trace_id) -# for col in df.columns -# ] -# -# return df -# -# # Extract balanced accuracy if available as fallback -# bacc = tracked_dict.get("balanced_accuracy") -# -# # Create DataFrame with the balanced accuracy -# if bacc is not None: -# col_name = get_csv_column_name( -# "conf-mat-balanced-accuracy", ax_row, ax_col, trace_id=trace_id -# ) -# df = pd.DataFrame({col_name: [bacc]}) -# return df -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_conf_mat.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_contour.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_contour.py deleted file mode 100644 index 1fde3af50..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_contour.py +++ /dev/null @@ -1,70 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_contour.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# """CSV formatter for stx_contour() calls.""" -# -# import pandas as pd -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_stx_contour(id, tracked_dict, kwargs): -# """Format data from stx_contour call for CSV export. -# -# Parameters -# ---------- -# id : str -# Identifier for the plot -# tracked_dict : dict -# Dictionary containing tracked data with 'contour_df' -# kwargs : dict -# Keyword arguments passed to stx_contour -# -# Returns -# ------- -# pd.DataFrame -# Formatted contour data with X, Y, Z columns -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get contour_df from tracked_dict -# contour_df = tracked_dict.get("contour_df") -# if contour_df is not None and isinstance(contour_df, pd.DataFrame): -# result = contour_df.copy() -# -# # Rename columns using single source of truth -# renamed = {} -# for col in result.columns: -# if col == "X": -# renamed[col] = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# elif col == "Y": -# renamed[col] = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# elif col == "Z": -# renamed[col] = get_csv_column_name("z", ax_row, ax_col, trace_id=trace_id) -# else: -# renamed[col] = get_csv_column_name(col.lower(), ax_row, ax_col, trace_id=trace_id) -# -# return result.rename(columns=renamed) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_contour.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_ecdf.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_ecdf.py deleted file mode 100644 index 50b783494..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_ecdf.py +++ /dev/null @@ -1,71 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_ecdf.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_ecdf.py -# # ---------------------------------------- -# import os -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# -# def _format_plot_ecdf(id, tracked_dict, kwargs): -# """Format data from a stx_ecdf call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing 'ecdf_df' key with ECDF data -# kwargs (dict): Keyword arguments passed to stx_ecdf -# -# Returns: -# pd.DataFrame: Formatted ECDF data -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Get the ecdf_df from tracked_dict -# ecdf_df = tracked_dict.get("ecdf_df") -# -# if ecdf_df is None or not isinstance(ecdf_df, pd.DataFrame): -# return pd.DataFrame() -# -# # Create a copy to avoid modifying the original -# result = ecdf_df.copy() -# -# # Add prefix to column names if ID is provided -# if id is not None: -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Rename columns using single source of truth -# renamed = {} -# for col in result.columns: -# # Use the original column name as the variable (e.g., "ecdf_value", "ecdf_prob") -# renamed[col] = get_csv_column_name( -# f"ecdf_{col}", ax_row, ax_col, trace_id=trace_id -# ) -# result = result.rename(columns=renamed) -# -# return result - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_ecdf.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_errorbar.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_errorbar.py deleted file mode 100644 index 957718e14..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_errorbar.py +++ /dev/null @@ -1,136 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_errorbar.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# """CSV formatter for stx_errorbar() calls.""" -# -# import pandas as pd -# import numpy as np -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_stx_errorbar(id, tracked_dict, kwargs): -# """Format data from stx_errorbar call for CSV export. -# -# Parameters -# ---------- -# id : str -# Tracking identifier -# tracked_dict : dict -# Dictionary containing tracked data with 'errorbar_df' key -# kwargs : dict -# Additional keyword arguments (may contain yerr, xerr) -# -# Returns -# ------- -# pd.DataFrame -# Formatted errorbar data with standardized column names -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get errorbar_df from tracked data -# errorbar_df = tracked_dict.get("errorbar_df") -# if errorbar_df is not None and isinstance(errorbar_df, pd.DataFrame): -# result = errorbar_df.copy() -# renamed = {} -# -# # Map columns to standardized names -# for col in result.columns: -# if col == "x": -# renamed[col] = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# elif col == "y": -# renamed[col] = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# elif col == "yerr": -# # Check if yerr is asymmetric (tuple/list of 2) -# yerr_value = result[col].iloc[0] if len(result) > 0 else None -# if isinstance(yerr_value, (list, tuple)) and len(yerr_value) == 2: -# # Handle asymmetric yerr separately below -# continue -# else: -# renamed[col] = get_csv_column_name("yerr", ax_row, ax_col, trace_id=trace_id) -# elif col == "xerr": -# # Check if xerr is asymmetric (tuple/list of 2) -# xerr_value = result[col].iloc[0] if len(result) > 0 else None -# if isinstance(xerr_value, (list, tuple)) and len(xerr_value) == 2: -# # Handle asymmetric xerr separately below -# continue -# else: -# renamed[col] = get_csv_column_name("xerr", ax_row, ax_col, trace_id=trace_id) -# else: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# -# result = result.rename(columns=renamed) -# -# # Handle asymmetric error bars if needed from kwargs -# yerr = kwargs.get("yerr") if kwargs else None -# xerr = kwargs.get("xerr") if kwargs else None -# -# if yerr is not None and isinstance(yerr, (list, tuple)) and len(yerr) == 2: -# col_yerr_neg = get_csv_column_name("yerr-neg", ax_row, ax_col, trace_id=trace_id) -# col_yerr_pos = get_csv_column_name("yerr-pos", ax_row, ax_col, trace_id=trace_id) -# result[col_yerr_neg] = yerr[0] -# result[col_yerr_pos] = yerr[1] -# -# if xerr is not None and isinstance(xerr, (list, tuple)) and len(xerr) == 2: -# col_xerr_neg = get_csv_column_name("xerr-neg", ax_row, ax_col, trace_id=trace_id) -# col_xerr_pos = get_csv_column_name("xerr-pos", ax_row, ax_col, trace_id=trace_id) -# result[col_xerr_neg] = xerr[0] -# result[col_xerr_pos] = xerr[1] -# -# return result -# -# # Fallback to args if errorbar_df not found -# args = tracked_dict.get("args", []) -# if len(args) >= 2: -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# data = {col_x: args[0], col_y: args[1]} -# -# # Add error bars if present -# yerr = kwargs.get("yerr") if kwargs else None -# xerr = kwargs.get("xerr") if kwargs else None -# -# if yerr is not None: -# if isinstance(yerr, (list, tuple)) and len(yerr) == 2: -# col_yerr_neg = get_csv_column_name("yerr-neg", ax_row, ax_col, trace_id=trace_id) -# col_yerr_pos = get_csv_column_name("yerr-pos", ax_row, ax_col, trace_id=trace_id) -# data[col_yerr_neg] = yerr[0] -# data[col_yerr_pos] = yerr[1] -# else: -# col_yerr = get_csv_column_name("yerr", ax_row, ax_col, trace_id=trace_id) -# data[col_yerr] = yerr -# -# if xerr is not None: -# if isinstance(xerr, (list, tuple)) and len(xerr) == 2: -# col_xerr_neg = get_csv_column_name("xerr-neg", ax_row, ax_col, trace_id=trace_id) -# col_xerr_pos = get_csv_column_name("xerr-pos", ax_row, ax_col, trace_id=trace_id) -# data[col_xerr_neg] = xerr[0] -# data[col_xerr_pos] = xerr[1] -# else: -# col_xerr = get_csv_column_name("xerr", ax_row, ax_col, trace_id=trace_id) -# data[col_xerr] = xerr -# -# return pd.DataFrame(data) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_errorbar.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_fillv.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_fillv.py deleted file mode 100644 index fa6689728..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_fillv.py +++ /dev/null @@ -1,88 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_fillv.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 12:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_fillv.py -# -# """CSV formatter for stx_fillv() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_fillv(id, tracked_dict, kwargs): -# """Format data from a stx_fillv call. -# -# Formats data similar to line plot format for better compatibility. -# Uses standard column naming convention: -# (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_fillv -# -# Returns: -# pd.DataFrame: Formatted fillv data in a long-format dataframe -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Try to get starts/ends directly from tracked_dict first -# starts = tracked_dict.get("starts") -# ends = tracked_dict.get("ends") -# -# # If not found, get from args -# if starts is None or ends is None: -# args = tracked_dict.get("args", []) -# -# # Extract data if available from args -# if len(args) >= 2: -# starts, ends = args[0], args[1] -# -# # If we have valid starts and ends, create a DataFrame in a format similar to line plot -# if starts is not None and ends is not None: -# # Convert to numpy arrays if they're lists for better handling -# if isinstance(starts, list): -# starts = np.array(starts) -# if isinstance(ends, list): -# ends = np.array(ends) -# -# # Get standard column names -# x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# type_col = get_csv_column_name("type", ax_row, ax_col, trace_id=trace_id) -# -# # Create a DataFrame with x, y pairs for each fill span -# rows = [] -# for start, end in zip(starts, ends): -# rows.append({x_col: start, y_col: 0, type_col: "start"}) -# rows.append({x_col: end, y_col: 0, type_col: "end"}) -# -# if rows: -# return pd.DataFrame(rows) -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_fillv.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_heatmap.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_heatmap.py deleted file mode 100644 index 7c8f98c67..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_heatmap.py +++ /dev/null @@ -1,90 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_heatmap.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_heatmap.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_heatmap(id, tracked_dict, kwargs): -# """Format data from a stx_heatmap call. -# -# Exports heatmap data in xyz format (x, y, value) for better compatibility. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_heatmap -# -# Returns: -# pd.DataFrame: Formatted heatmap data in xyz format -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Extract data from tracked_dict -# data = tracked_dict.get("data") -# x_labels = tracked_dict.get("x_labels") -# y_labels = tracked_dict.get("y_labels") -# -# if data is not None and hasattr(data, "shape") and len(data.shape) == 2: -# rows, cols = data.shape -# row_indices, col_indices = np.meshgrid(range(rows), range(cols), indexing="ij") -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Format data in xyz format (x, y, value) using single source of truth -# df = pd.DataFrame( -# { -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): col_indices.flatten(), # x is column -# get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): row_indices.flatten(), # y is row -# get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id): data.flatten(), # z is intensity/value -# } -# ) -# -# # Add label information if available -# if x_labels is not None and len(x_labels) == cols: -# # Map column indices to x labels (columns are x) -# x_label_map = {i: label for i, label in enumerate(x_labels)} -# x_col_name = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# x_label_col_name = get_csv_column_name("x_label", ax_row, ax_col, trace_id=trace_id) -# df[x_label_col_name] = df[x_col_name].map(x_label_map) -# -# if y_labels is not None and len(y_labels) == rows: -# # Map row indices to y labels (rows are y) -# y_label_map = {i: label for i, label in enumerate(y_labels)} -# y_col_name = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# y_label_col_name = get_csv_column_name("y_label", ax_row, ax_col, trace_id=trace_id) -# df[y_label_col_name] = df[y_col_name].map(y_label_map) -# -# return df -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_heatmap.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_image.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_image.py deleted file mode 100644 index 0f01369cf..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_image.py +++ /dev/null @@ -1,125 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_image.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_image.py -# # ---------------------------------------- -# import os -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# -# def _format_plot_image(id, tracked_dict, kwargs): -# """Format data from a stx_image call. -# -# Exports image data in long-format xyz format for better compatibility. -# Also saves channel data for RGB/RGBA images. -# -# Args: -# id (str or int): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_image -# -# Returns: -# pd.DataFrame: Formatted image data in xyz format -# """ -# # Check if tracked_dict is not a dictionary or is empty -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse the tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Check if image_df is available and use it if present -# if "image_df" in tracked_dict: -# image_df = tracked_dict.get("image_df") -# if isinstance(image_df, pd.DataFrame): -# # Add prefix if ID is provided -# if id is not None: -# image_df = image_df.copy() -# # Rename columns using single source of truth -# renamed = {} -# for col in image_df.columns: -# # Convert to string to handle integer column names -# col_str = str(col) -# renamed[col] = get_csv_column_name( -# col_str, ax_row, ax_col, trace_id=trace_id -# ) -# image_df = image_df.rename(columns=renamed) -# return image_df -# -# # If we have image data -# if "image" in tracked_dict: -# img = tracked_dict["image"] -# -# # Handle 2D grayscale images - create xyz format (x, y, value) -# if isinstance(img, np.ndarray) and img.ndim == 2: -# rows, cols = img.shape -# row_indices, col_indices = np.meshgrid( -# range(rows), range(cols), indexing="ij" -# ) -# -# # Create xyz format using single source of truth -# df = pd.DataFrame( -# { -# get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id): col_indices.flatten(), # x is column -# get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id): row_indices.flatten(), # y is row -# get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id): img.flatten(), # z is intensity -# } -# ) -# return df -# -# # Handle RGB/RGBA images - create xyz format with additional channel information -# elif isinstance(img, np.ndarray) and img.ndim == 3: -# rows, cols, channels = img.shape -# -# # Create a list to hold rows for a long-format DataFrame -# data_rows = [] -# channel_names = ["r", "g", "b", "a"] -# -# # Get column names using single source of truth -# x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# channel_col = get_csv_column_name("channel", ax_row, ax_col, trace_id=trace_id) -# value_col = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# -# # Create long-format data (x, y, channel, value) -# for r in range(rows): -# for c in range(cols): -# for ch in range(min(channels, len(channel_names))): -# data_rows.append( -# { -# x_col: c, # x is column -# y_col: r, # y is row -# channel_col: channel_names[ch], # channel name -# value_col: img[r, c, ch], # channel value -# } -# ) -# -# # Return long-format DataFrame -# return pd.DataFrame(data_rows) -# -# # Skip CSV export if no suitable data format found -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_image.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_imshow.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_imshow.py deleted file mode 100644 index b816877dc..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_imshow.py +++ /dev/null @@ -1,79 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_imshow.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# """CSV formatter for stx_imshow() calls.""" -# -# import numpy as np -# import pandas as pd -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_stx_imshow(id, tracked_dict, kwargs): -# """Format data from stx_imshow call for CSV export. -# -# Parameters -# ---------- -# id : str -# Identifier for the plot -# tracked_dict : dict -# Dictionary containing tracked data with 'imshow_df' -# kwargs : dict -# Keyword arguments passed to stx_imshow -# -# Returns -# ------- -# pd.DataFrame -# Formatted imshow data in row, col, value format (or row, col, R, G, B for RGB) -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get imshow_df from tracked_dict -# imshow_df = tracked_dict.get("imshow_df") -# if imshow_df is not None and isinstance(imshow_df, pd.DataFrame): -# # Convert from 2D DataFrame format (with col_0, col_1, ... columns) -# # to row, col, value format for easier analysis -# n_rows, n_cols = imshow_df.shape -# -# # Create row and column indices -# row_indices = np.repeat(np.arange(n_rows), n_cols) -# col_indices = np.tile(np.arange(n_cols), n_rows) -# -# # Get column names from single source of truth -# col_row = get_csv_column_name("row", ax_row, ax_col, trace_id=trace_id) -# col_col = get_csv_column_name("col", ax_row, ax_col, trace_id=trace_id) -# col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# -# # Flatten the DataFrame values -# values = imshow_df.values.flatten() -# -# result = pd.DataFrame({ -# col_row: row_indices, -# col_col: col_indices, -# col_value: values -# }) -# -# return result -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_imshow.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_joyplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_joyplot.py deleted file mode 100644 index 8dc7b1700..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_joyplot.py +++ /dev/null @@ -1,100 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_joyplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_joyplot.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# from scitex.pd import force_df -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_joyplot(id, tracked_dict, kwargs): -# """Format data from a stx_joyplot call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing 'joyplot_data' key with joyplot data -# kwargs (dict): Keyword arguments passed to stx_joyplot -# -# Returns: -# pd.DataFrame: Formatted joyplot data -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get joyplot_data from tracked_dict -# data = tracked_dict.get("joyplot_data") -# -# if data is None: -# return pd.DataFrame() -# -# # Handle different data types -# if isinstance(data, pd.DataFrame): -# # Make a copy to avoid modifying original -# result = data.copy() -# # Add prefix to column names using single source of truth -# if id is not None: -# result.columns = [ -# get_csv_column_name(f"joyplot-{col}", ax_row, ax_col, trace_id=trace_id) -# for col in result.columns -# ] -# return result -# -# elif isinstance(data, dict): -# # Convert dictionary to DataFrame -# result = pd.DataFrame() -# for group, values in data.items(): -# col_name = get_csv_column_name( -# f"joyplot-{group}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = pd.Series(values) -# return result -# -# elif isinstance(data, (list, tuple)) and all( -# isinstance(x, (np.ndarray, list)) for x in data -# ): -# # Convert list of arrays to DataFrame -# result = pd.DataFrame() -# for i, values in enumerate(data): -# col_name = get_csv_column_name( -# f"joyplot-group{i:02d}", ax_row, ax_col, trace_id=trace_id -# ) -# result[col_name] = pd.Series(values) -# return result -# -# # Try to force to DataFrame as a last resort -# try: -# col_name = get_csv_column_name( -# "joyplot-data", ax_row, ax_col, trace_id=trace_id -# ) -# return force_df({col_name: data}) -# except: -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_joyplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_line.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_line.py deleted file mode 100644 index c5ca27e99..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_line.py +++ /dev/null @@ -1,71 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_line.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 02:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_line.py -# -# """CSV formatter for stx_line() calls - uses standard column naming.""" -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_line(id, tracked_dict, kwargs): -# """Format data from a stx_line call. -# -# Processes stx_line data for CSV export using standard column naming -# (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing 'plot_df' key with plot data -# kwargs (dict): Keyword arguments passed to stx_line -# -# Returns: -# pd.DataFrame: Formatted line plot data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Get the plot_df from tracked_dict -# plot_df = tracked_dict.get("plot_df") -# -# if plot_df is None or not isinstance(plot_df, pd.DataFrame): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Create a copy to avoid modifying the original -# result = plot_df.copy() -# -# # Rename columns using standard naming convention -# renamed = {} -# for col in result.columns: -# if col == "x": -# renamed[col] = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# elif col == "y": -# renamed[col] = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# else: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# -# return result.rename(columns=renamed) - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_line.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_mean_ci.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_mean_ci.py deleted file mode 100644 index a5be9be8a..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_mean_ci.py +++ /dev/null @@ -1,66 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_ci.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 02:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_ci.py -# -# """CSV formatter for stx_mean_ci() calls - uses standard column naming.""" -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_mean_ci(id, tracked_dict, kwargs): -# """Format data from a stx_mean_ci call. -# -# Processes mean with confidence interval band plot data for CSV export using -# standard column naming (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Contains 'plot_df' (pandas DataFrame with mean and CI data) -# kwargs (dict): Keyword arguments passed to stx_mean_ci -# -# Returns: -# pd.DataFrame: Formatted mean and CI data with standard column names -# """ -# # Mean-CI plot data is passed in the tracked_dict -# if not tracked_dict: -# return pd.DataFrame() -# -# # Get the plot_df from tracked_dict -# plot_df = tracked_dict.get("plot_df") -# -# if plot_df is None or not isinstance(plot_df, pd.DataFrame): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Create a copy to avoid modifying the original -# result = plot_df.copy() -# -# # Rename columns using standard naming convention -# renamed = {} -# for col in result.columns: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# -# return result.rename(columns=renamed) - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_ci.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_mean_std.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_mean_std.py deleted file mode 100644 index 35307ed50..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_mean_std.py +++ /dev/null @@ -1,66 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_std.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 02:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_std.py -# -# """CSV formatter for stx_mean_std() calls - uses standard column naming.""" -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_mean_std(id, tracked_dict, kwargs): -# """Format data from a stx_mean_std call. -# -# Processes mean with standard deviation band plot data for CSV export using -# standard column naming (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing 'plot_df' key with mean and std data -# kwargs (dict): Keyword arguments passed to stx_mean_std -# -# Returns: -# pd.DataFrame: Formatted mean and std data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Get the plot_df from tracked_dict -# plot_df = tracked_dict.get("plot_df") -# -# if plot_df is None or not isinstance(plot_df, pd.DataFrame): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Create a copy to avoid modifying the original -# result = plot_df.copy() -# -# # Rename columns using standard naming convention -# renamed = {} -# for col in result.columns: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# -# return result.rename(columns=renamed) - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_mean_std.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_median_iqr.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_median_iqr.py deleted file mode 100644 index ecb9792eb..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_median_iqr.py +++ /dev/null @@ -1,66 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_median_iqr.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-13 02:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_median_iqr.py -# -# """CSV formatter for stx_median_iqr() calls - uses standard column naming.""" -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_median_iqr(id, tracked_dict, kwargs): -# """Format data from a stx_median_iqr call. -# -# Processes median with interquartile range band plot data for CSV export using -# standard column naming (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Contains 'plot_df' (pandas DataFrame with median and IQR data) -# kwargs (dict): Keyword arguments passed to stx_median_iqr -# -# Returns: -# pd.DataFrame: Formatted median and IQR data with standard column names -# """ -# # Median-IQR plot data is passed in the tracked_dict -# if not tracked_dict: -# return pd.DataFrame() -# -# # Get the plot_df from tracked_dict -# plot_df = tracked_dict.get("plot_df") -# -# if plot_df is None or not isinstance(plot_df, pd.DataFrame): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Create a copy to avoid modifying the original -# result = plot_df.copy() -# -# # Rename columns using standard naming convention -# renamed = {} -# for col in result.columns: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# -# return result.rename(columns=renamed) - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_median_iqr.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_raster.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_raster.py deleted file mode 100644 index 933638c76..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_raster.py +++ /dev/null @@ -1,68 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_raster.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_raster.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import pandas as pd -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_raster(id, tracked_dict, kwargs): -# """Format data from a stx_raster call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing 'raster_digit_df' key with raster plot data -# kwargs (dict): Keyword arguments passed to stx_raster -# -# Returns: -# pd.DataFrame: Formatted raster plot data -# """ -# # Check if args is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get the raster_digit_df from args -# raster_df = tracked_dict.get("raster_digit_df") -# -# if raster_df is None or not isinstance(raster_df, pd.DataFrame): -# return pd.DataFrame() -# -# # Create a copy to avoid modifying the original -# result = raster_df.copy() -# -# # Add prefix to column names using single source of truth -# if id is not None: -# # Rename columns with ID prefix -# result.columns = [ -# get_csv_column_name(f"raster-{col}", ax_row, ax_col, trace_id=trace_id) -# for col in result.columns -# ] -# -# return result - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_raster.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_rectangle.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_rectangle.py deleted file mode 100644 index 8b32fe041..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_rectangle.py +++ /dev/null @@ -1,145 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_rectangle.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 12:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_rectangle.py -# -# """CSV formatter for stx_rectangle() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_rectangle(id, tracked_dict, kwargs): -# """Format data from a stx_rectangle call. -# -# Uses standard column naming convention: -# (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_rectangle -# -# Returns: -# pd.DataFrame: Formatted rectangle data -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get standard column names -# x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# width_col = get_csv_column_name("width", ax_row, ax_col, trace_id=trace_id) -# height_col = get_csv_column_name("height", ax_row, ax_col, trace_id=trace_id) -# -# # Try to get rectangle parameters directly from tracked_dict -# x = tracked_dict.get("x") -# y = tracked_dict.get("y") -# width = tracked_dict.get("width") -# height = tracked_dict.get("height") -# -# # If direct parameters aren't available, try the args -# if any(param is None for param in [x, y, width, height]): -# args = tracked_dict.get("args", []) -# -# # Rectangles defined by [x, y, width, height] -# if len(args) >= 4: -# x, y, width, height = args[0], args[1], args[2], args[3] -# -# # If we have all required parameters, create the DataFrame -# if all(param is not None for param in [x, y, width, height]): -# try: -# # Handle single rectangle -# if all( -# isinstance(val, (int, float, np.number)) -# for val in [x, y, width, height] -# ): -# return pd.DataFrame( -# { -# x_col: [x], -# y_col: [y], -# width_col: [width], -# height_col: [height], -# } -# ) -# -# # Handle multiple rectangles (arrays) -# elif all( -# isinstance(val, (np.ndarray, list)) for val in [x, y, width, height] -# ): -# try: -# return pd.DataFrame( -# { -# x_col: x, -# y_col: y, -# width_col: width, -# height_col: height, -# } -# ) -# except ValueError: -# # Handle case where arrays might be different lengths -# result = pd.DataFrame() -# result[x_col] = pd.Series(x) -# result[y_col] = pd.Series(y) -# result[width_col] = pd.Series(width) -# result[height_col] = pd.Series(height) -# return result -# except Exception: -# # Fallback for rectangle in case of any errors -# try: -# return pd.DataFrame( -# { -# x_col: [float(x) if x is not None else 0], -# y_col: [float(y) if y is not None else 0], -# width_col: [float(width) if width is not None else 0], -# height_col: [float(height) if height is not None else 0], -# } -# ) -# except (TypeError, ValueError): -# pass -# -# # Check directly in the kwargs for the parameters -# rect_x = kwargs.get("x") -# rect_y = kwargs.get("y") -# rect_w = kwargs.get("width") -# rect_h = kwargs.get("height") -# -# if all(param is not None for param in [rect_x, rect_y, rect_w, rect_h]): -# try: -# return pd.DataFrame( -# { -# x_col: [float(rect_x)], -# y_col: [float(rect_y)], -# width_col: [float(rect_w)], -# height_col: [float(rect_h)], -# } -# ) -# except (TypeError, ValueError): -# pass -# -# # Default empty DataFrame if nothing could be processed -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_rectangle.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_scatter.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_scatter.py deleted file mode 100644 index a2f548a8f..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_scatter.py +++ /dev/null @@ -1,67 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# """CSV formatter for stx_scatter() calls.""" -# -# import pandas as pd -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_stx_scatter(id, tracked_dict, kwargs): -# """Format data from stx_scatter call for CSV export. -# -# Parameters -# ---------- -# id : str -# Tracking identifier -# tracked_dict : dict -# Dictionary containing tracked data with 'scatter_df' key -# kwargs : dict -# Additional keyword arguments (unused) -# -# Returns -# ------- -# pd.DataFrame -# Formatted scatter data with standardized column names -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get scatter_df from tracked data -# scatter_df = tracked_dict.get("scatter_df") -# if scatter_df is not None and isinstance(scatter_df, pd.DataFrame): -# result = scatter_df.copy() -# renamed = {} -# for col in result.columns: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# return result.rename(columns=renamed) -# -# # Fallback to args if scatter_df not found -# args = tracked_dict.get("args", []) -# if len(args) >= 2: -# col_x = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# col_y = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# return pd.DataFrame({col_x: args[0], col_y: args[1]}) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_scatter_hist.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_scatter_hist.py deleted file mode 100644 index e104c4d45..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_scatter_hist.py +++ /dev/null @@ -1,108 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter_hist.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 18:14:26 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_plot_scatter_hist.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_scatter_hist(id, tracked_dict, kwargs): -# """Format data from a stx_scatter_hist call. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_scatter_hist -# -# Returns: -# pd.DataFrame: Formatted scatter histogram data -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to extract axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Extract data from tracked_dict -# x = tracked_dict.get("x") -# y = tracked_dict.get("y") -# -# if x is not None and y is not None: -# # Create base DataFrame with x and y values -# df = pd.DataFrame({ -# get_csv_column_name("scatter_hist_x", ax_row, ax_col, trace_id=trace_id): x, -# get_csv_column_name("scatter_hist_y", ax_row, ax_col, trace_id=trace_id): y, -# }) -# -# # Add histogram data if available -# hist_x = tracked_dict.get("hist_x") -# hist_y = tracked_dict.get("hist_y") -# bin_edges_x = tracked_dict.get("bin_edges_x") -# bin_edges_y = tracked_dict.get("bin_edges_y") -# -# # If we have histogram data -# if hist_x is not None and bin_edges_x is not None: -# # Calculate bin centers for x-axis histogram -# bin_centers_x = 0.5 * (bin_edges_x[1:] + bin_edges_x[:-1]) -# -# # Create a DataFrame for x histogram data -# hist_x_df = pd.DataFrame( -# { -# get_csv_column_name("hist_x_bin_centers", ax_row, ax_col, trace_id=trace_id): bin_centers_x, -# get_csv_column_name("hist_x_counts", ax_row, ax_col, trace_id=trace_id): hist_x, -# } -# ) -# -# # Add it to the main DataFrame using a MultiIndex -# for i, (center, count) in enumerate(zip(bin_centers_x, hist_x)): -# df.loc[f"hist_x_{i}", get_csv_column_name("hist_x_bin", ax_row, ax_col, trace_id=trace_id)] = center -# df.loc[f"hist_x_{i}", get_csv_column_name("hist_x_count", ax_row, ax_col, trace_id=trace_id)] = count -# -# # If we have y histogram data -# if hist_y is not None and bin_edges_y is not None: -# # Calculate bin centers for y-axis histogram -# bin_centers_y = 0.5 * (bin_edges_y[1:] + bin_edges_y[:-1]) -# -# # Create a DataFrame for y histogram data -# hist_y_df = pd.DataFrame( -# { -# get_csv_column_name("hist_y_bin_centers", ax_row, ax_col, trace_id=trace_id): bin_centers_y, -# get_csv_column_name("hist_y_counts", ax_row, ax_col, trace_id=trace_id): hist_y, -# } -# ) -# -# # Add it to the main DataFrame using a MultiIndex -# for i, (center, count) in enumerate(zip(bin_centers_y, hist_y)): -# df.loc[f"hist_y_{i}", get_csv_column_name("hist_y_bin", ax_row, ax_col, trace_id=trace_id)] = center -# df.loc[f"hist_y_{i}", get_csv_column_name("hist_y_count", ax_row, ax_col, trace_id=trace_id)] = count -# -# return df -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_scatter_hist.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_shaded_line.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_shaded_line.py deleted file mode 100644 index bafb0f34e..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_shaded_line.py +++ /dev/null @@ -1,88 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_shaded_line.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 03:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_shaded_line.py -# -# """CSV formatter for stx_shaded_line() calls - uses standard column naming.""" -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_shaded_line(id, tracked_dict, kwargs): -# """Format data from a stx_shaded_line call. -# -# Processes stx_shaded_line data for CSV export using standard column naming -# (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_shaded_line -# -# Returns: -# pd.DataFrame: Formatted shaded line data with standard column names -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # If we have a plot_df from plotting methods, use that directly -# if "plot_df" in tracked_dict and isinstance(tracked_dict["plot_df"], pd.DataFrame): -# plot_df = tracked_dict["plot_df"] -# # Rename columns using standard naming convention -# renamed = {} -# for col in plot_df.columns: -# renamed[col] = get_csv_column_name(col, ax_row, ax_col, trace_id=trace_id) -# return plot_df.rename(columns=renamed) -# -# # Try getting the individual components -# x = tracked_dict.get("x") -# y_middle = tracked_dict.get("y_middle") -# y_lower = tracked_dict.get("y_lower") -# y_upper = tracked_dict.get("y_upper") -# -# # If we have all necessary components -# if x is not None and y_middle is not None and y_lower is not None: -# x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# y_col = get_csv_column_name("y-middle", ax_row, ax_col, trace_id=trace_id) -# lower_col = get_csv_column_name("y-lower", ax_row, ax_col, trace_id=trace_id) -# upper_col = get_csv_column_name("y-upper", ax_row, ax_col, trace_id=trace_id) -# -# data = { -# x_col: x, -# y_col: y_middle, -# lower_col: y_lower, -# } -# -# if y_upper is not None: -# data[upper_col] = y_upper -# else: -# # If only y_lower is provided, assume it's symmetric around y_middle -# data[upper_col] = y_middle + (y_middle - y_lower) -# -# return pd.DataFrame(data) -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_shaded_line.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_violin.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_violin.py deleted file mode 100644 index 6d1847324..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_stx_violin.py +++ /dev/null @@ -1,131 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_violin.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 12:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_violin.py -# -# """CSV formatter for stx_violin() calls - uses standard column naming.""" -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_plot_violin(id, tracked_dict, kwargs): -# """Format data from a stx_violin call. -# -# Formats data in a long-format for better compatibility. -# Uses standard column naming convention: -# (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to stx_violin -# -# Returns: -# pd.DataFrame: Formatted violin plot data in long format -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get standard column names -# group_col = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) -# value_col = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# -# # Extract data from tracked_dict -# data = tracked_dict.get("data") -# -# if data is not None: -# # If data is a simple array or list -# if isinstance(data, (np.ndarray, list)) and not isinstance( -# data[0], (list, np.ndarray, dict) -# ): -# # Convert to long format with group and value columns -# rows = [{group_col: "0", value_col: val} for val in data] -# return pd.DataFrame(rows) -# -# # If data is a list of arrays (multiple violin plots) -# elif isinstance(data, (list, tuple)) and all( -# isinstance(x, (list, np.ndarray)) for x in data -# ): -# # Get labels if available -# labels = tracked_dict.get("labels") -# -# # Convert to long format -# rows = [] -# for i, values in enumerate(data): -# # Use label if available, otherwise use index -# group = labels[i] if labels and i < len(labels) else f"group{i:02d}" -# for val in values: -# rows.append({group_col: str(group), value_col: val}) -# -# if rows: -# return pd.DataFrame(rows) -# -# # If data is a dictionary -# elif isinstance(data, dict): -# # Convert to long format -# rows = [] -# for group, values in data.items(): -# for val in values: -# rows.append({group_col: str(group), value_col: val}) -# -# if rows: -# return pd.DataFrame(rows) -# -# # If data is a DataFrame -# elif isinstance(data, pd.DataFrame): -# # For DataFrame data with x and y columns -# x = tracked_dict.get("x") -# y = tracked_dict.get("y") -# -# if ( -# x is not None -# and y is not None -# and x in data.columns -# and y in data.columns -# ): -# # Convert to long format -# rows = [] -# for group_name, group_data in data.groupby(x): -# for val in group_data[y]: -# rows.append({group_col: str(group_name), value_col: val}) -# -# if rows: -# return pd.DataFrame(rows) -# else: -# # For other dataframes, melt to long format -# try: -# # Try to melt to long format -# result = pd.melt(data) -# # Rename columns using standard naming -# result.columns = [group_col, value_col] -# return result -# except Exception: -# # If melt fails, just return empty -# pass -# -# return pd.DataFrame() - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_stx_violin.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_text.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_text.py deleted file mode 100644 index 53915cc74..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_text.py +++ /dev/null @@ -1,77 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-10 12:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py -# -# """CSV formatter for text() calls - uses standard column naming.""" -# -# from __future__ import annotations -# -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# -# from ._format_plot import _parse_tracking_id -# -# -# def _format_text(id, tracked_dict, kwargs): -# """Format data from a text call. -# -# Uses standard column naming convention: -# (ax-row-{r}-col-{c}_trace-id-{id}_variable-{var}). -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to text -# -# Returns: -# pd.DataFrame: Formatted text position data -# """ -# # Check if tracked_dict is empty or not a dictionary -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# # Get standard column names -# x_col = get_csv_column_name("x", ax_row, ax_col, trace_id=trace_id) -# y_col = get_csv_column_name("y", ax_row, ax_col, trace_id=trace_id) -# content_col = get_csv_column_name("content", ax_row, ax_col, trace_id=trace_id) -# -# # Get the args from tracked_dict -# args = tracked_dict.get("args", []) -# -# # Extract x, y, and text content if available -# if len(args) >= 2: -# x, y = args[0], args[1] -# text_content = args[2] if len(args) >= 3 else None -# -# data = {x_col: [x], y_col: [y]} -# -# if text_content is not None: -# data[content_col] = [text_content] -# -# return pd.DataFrame(data) -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_text.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_violin.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_violin.py deleted file mode 100644 index 2808c5855..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_violin.py +++ /dev/null @@ -1,81 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_violin(id, tracked_dict, kwargs): -# """Format data from a violin call. -# -# Formats data in a long-format for better compatibility. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to violin plot -# -# Returns: -# pd.DataFrame: Formatted violin data in long format -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", []) -# -# if len(args) >= 1: -# data = args[0] -# -# # Handle case when data is a simple array or list -# if isinstance(data, (list, np.ndarray)) and not isinstance( -# data[0], (list, np.ndarray, dict) -# ): -# rows = [{"group": "0", "value": val} for val in data] -# df = pd.DataFrame(rows) -# col_group = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) -# col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# df.columns = [col_group, col_value] -# return df -# -# # Handle case when data is a dictionary -# elif isinstance(data, dict): -# rows = [] -# for group, values in data.items(): -# for val in values: -# rows.append({"group": str(group), "value": val}) -# -# if rows: -# df = pd.DataFrame(rows) -# col_group = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) -# col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# df.columns = [col_group, col_value] -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violin.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_violinplot.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_violinplot.py deleted file mode 100644 index e95170dcb..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test__format_violinplot.py +++ /dev/null @@ -1,98 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-09 12:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py -# -# import numpy as np -# import pandas as pd -# -# from scitex.plt.utils._csv_column_naming import get_csv_column_name -# from ._format_plot import _parse_tracking_id -# -# -# def _format_violinplot(id, tracked_dict, kwargs): -# """Format data from a violinplot call. -# -# Formats data in a long-format for better compatibility. -# -# Args: -# id (str): Identifier for the plot -# tracked_dict (dict): Dictionary containing tracked data -# kwargs (dict): Keyword arguments passed to violinplot -# -# Returns: -# pd.DataFrame: Formatted violinplot data in long format -# """ -# if not tracked_dict or not isinstance(tracked_dict, dict): -# return pd.DataFrame() -# -# # Parse tracking ID to get axes position and trace ID -# ax_row, ax_col, trace_id = _parse_tracking_id(id) -# -# args = tracked_dict.get("args", []) -# -# if len(args) >= 1: -# data = args[0] -# -# # Handle case when data is a simple array or list -# if isinstance(data, (list, np.ndarray)) and not isinstance( -# data[0], (list, np.ndarray, dict) -# ): -# rows = [{"group": "0", "value": val} for val in data] -# df = pd.DataFrame(rows) -# # Use structured column naming -# col_group = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) -# col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# df.columns = [col_group, col_value] -# return df -# -# # Handle case when data is a dictionary -# elif isinstance(data, dict): -# rows = [] -# for group, values in data.items(): -# for val in values: -# rows.append({"group": str(group), "value": val}) -# -# if rows: -# df = pd.DataFrame(rows) -# col_group = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) -# col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# df.columns = [col_group, col_value] -# return df -# -# # Handle case when data is a list of arrays -# elif isinstance(data, (list, tuple)) and all( -# isinstance(x, (list, np.ndarray)) for x in data -# ): -# rows = [] -# for i, values in enumerate(data): -# for val in values: -# rows.append({"group": str(i), "value": val}) -# -# if rows: -# df = pd.DataFrame(rows) -# col_group = get_csv_column_name("group", ax_row, ax_col, trace_id=trace_id) -# col_value = get_csv_column_name("value", ax_row, ax_col, trace_id=trace_id) -# df.columns = [col_group, col_value] -# return df -# -# return pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/_format_violinplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test_test_formatters.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test_test_formatters.py deleted file mode 100644 index a68598cd0..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test_test_formatters.py +++ /dev/null @@ -1,223 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 22:05:10 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py -# # ---------------------------------------- -# import os -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import unittest -# import numpy as np -# import pandas as pd -# -# # Import formatters directly -# from ._format_plot import _format_plot -# from ._format_plot_kde import _format_plot_kde -# from ._format_plot_ecdf import _format_plot_ecdf -# from ._format_plot_heatmap import _format_plot_heatmap -# from ._format_plot_violin import _format_plot_violin -# from ._format_plot_shaded_line import _format_plot_shaded_line -# from ._format_plot_scatter_hist import _format_plot_scatter_hist -# -# -# class FormattersTest(unittest.TestCase): -# """Test the formatter functions.""" -# -# def test_format_plot_kde(self): -# """Test _format_plot_kde function.""" -# # Test case 1: Normal input -# tracked_dict = { -# 'x': np.linspace(-3, 3, 100), -# 'kde': np.exp(-np.linspace(-3, 3, 100)**2/2), -# 'n': 500 -# } -# id = 'test_kde' -# df = _format_plot_kde(id, tracked_dict, {}) -# -# # Verify columns -# self.assertIn(f"{id}_kde_x", df.columns) -# self.assertIn(f"{id}_kde_density", df.columns) -# self.assertIn(f"{id}_kde_n", df.columns) -# -# # Test case 2: Empty tracked_dict -# df = _format_plot_kde(id, {}, {}) -# self.assertTrue(df.empty) -# -# # Test case 3: Missing 'x' key -# tracked_dict = {'kde': np.exp(-np.linspace(-3, 3, 100)**2/2)} -# df = _format_plot_kde(id, tracked_dict, {}) -# self.assertTrue(df.empty) -# -# def test_format_plot(self): -# """Test _format_plot function.""" -# # Test case 1: Normal input -# tracked_dict = { -# 'plot_df': pd.DataFrame({ -# 'x': np.linspace(0, 10, 100), -# 'y': np.sin(np.linspace(0, 10, 100)) -# }) -# } -# id = 'test_plot' -# df = _format_plot(id, tracked_dict, {}) -# -# # Verify it returned the DataFrame with added prefix -# self.assertFalse(df.empty) -# -# # Test case 2: Empty tracked_dict -# df = _format_plot(id, {}, {}) -# self.assertTrue(df.empty) -# -# def test_format_plot_ecdf(self): -# """Test _format_plot_ecdf function.""" -# # Test case 1: Normal input -# tracked_dict = { -# 'ecdf_df': pd.DataFrame({ -# 'x': np.linspace(-3, 3, 100), -# 'ecdf': np.linspace(0, 1, 100) -# }) -# } -# id = 'test_ecdf' -# df = _format_plot_ecdf(id, tracked_dict, {}) -# -# # Verify it returned the DataFrame -# self.assertFalse(df.empty) -# -# # Test case 2: Empty tracked_dict -# df = _format_plot_ecdf(id, {}, {}) -# self.assertTrue(df.empty) -# -# def test_format_plot_heatmap(self): -# """Test _format_plot_heatmap function.""" -# # Test case 1: Normal input with labels -# data = np.random.rand(3, 4) -# x_labels = ['A', 'B', 'C'] -# y_labels = ['W', 'X', 'Y', 'Z'] -# -# tracked_dict = { -# 'data': data, -# 'x_labels': x_labels, -# 'y_labels': y_labels -# } -# id = 'test_heatmap' -# df = _format_plot_heatmap(id, tracked_dict, {}) -# -# # Verify it returned the DataFrame with the expected shape -# self.assertFalse(df.empty) -# self.assertEqual(df.shape[0], 12) # 3 rows * 4 columns = 12 cells -# # We should have 5 columns: row, col, value, row_label, col_label -# self.assertEqual(df.shape[1], 5) -# -# # Test case 2: No labels -# tracked_dict = {'data': data} -# df = _format_plot_heatmap(id, tracked_dict, {}) -# self.assertFalse(df.empty) -# -# # Test case 3: Empty tracked_dict -# df = _format_plot_heatmap(id, {}, {}) -# self.assertTrue(df.empty) -# -# def test_format_plot_violin(self): -# """Test _format_plot_violin function.""" -# # Test case 1: List data -# data = [np.random.normal(0, 1, 100), np.random.normal(2, 1, 100)] -# labels = ['Group A', 'Group B'] -# -# tracked_dict = { -# 'data': data, -# 'labels': labels -# } -# id = 'test_violin' -# df = _format_plot_violin(id, tracked_dict, {}) -# -# # Verify it returned the DataFrame -# self.assertFalse(df.empty) -# -# # Test case 2: DataFrame data -# data_df = pd.DataFrame({ -# 'values': np.concatenate([np.random.normal(0, 1, 100), np.random.normal(2, 1, 100)]), -# 'group': ['A'] * 100 + ['B'] * 100 -# }) -# tracked_dict = { -# 'data': data_df, -# 'x': 'group', -# 'y': 'values' -# } -# df = _format_plot_violin(id, tracked_dict, {}) -# self.assertFalse(df.empty) -# -# # Test case 3: Empty tracked_dict -# df = _format_plot_violin(id, {}, {}) -# self.assertTrue(df.empty) -# -# def test_format_plot_shaded_line(self): -# """Test _format_plot_shaded_line function.""" -# # Test case 1: Normal input -# tracked_dict = { -# 'plot_df': pd.DataFrame({ -# 'x': np.linspace(0, 10, 100), -# 'y_lower': np.sin(np.linspace(0, 10, 100)) - 0.2, -# 'y_middle': np.sin(np.linspace(0, 10, 100)), -# 'y_upper': np.sin(np.linspace(0, 10, 100)) + 0.2 -# }) -# } -# id = 'test_shaded' -# df = _format_plot_shaded_line(id, tracked_dict, {}) -# -# # Verify it returned the DataFrame -# self.assertFalse(df.empty) -# -# # Test case 2: Empty tracked_dict -# df = _format_plot_shaded_line(id, {}, {}) -# self.assertTrue(df.empty) -# -# def test_format_plot_scatter_hist(self): -# """Test _format_plot_scatter_hist function.""" -# # Test case 1: Normal input -# tracked_dict = { -# 'x': np.random.normal(0, 1, 100), -# 'y': np.random.normal(0, 1, 100), -# 'hist_x': np.random.rand(10), -# 'hist_y': np.random.rand(10), -# 'bin_edges_x': np.linspace(-3, 3, 11), -# 'bin_edges_y': np.linspace(-3, 3, 11) -# } -# id = 'test_scatter_hist' -# df = _format_plot_scatter_hist(id, tracked_dict, {}) -# -# # Verify it returned the DataFrame with expected columns -# self.assertFalse(df.empty) -# self.assertTrue(any(col.startswith(f"{id}_scatter_hist_x") for col in df.columns)) -# self.assertTrue(any(col.startswith(f"{id}_scatter_hist_y") for col in df.columns)) -# -# # Test case 2: Missing keys -# tracked_dict = { -# 'x': np.random.normal(0, 1, 100), -# 'y': np.random.normal(0, 1, 100) -# } -# df = _format_plot_scatter_hist(id, tracked_dict, {}) -# self.assertFalse(df.empty) # Should still work with just x,y -# -# # Test case 3: Empty tracked_dict -# df = _format_plot_scatter_hist(id, {}, {}) -# self.assertTrue(df.empty) -# -# -# if __name__ == '__main__': -# unittest.main() -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/test_formatters.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test_verify_formatters.py b/tests/scitex/plt/_subplots/_export_as_csv_formatters/test_verify_formatters.py deleted file mode 100644 index 06b22c0e6..000000000 --- a/tests/scitex/plt/_subplots/_export_as_csv_formatters/test_verify_formatters.py +++ /dev/null @@ -1,375 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 23:14:10 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py -# # ---------------------------------------- -# import os -# import sys -# import numpy as np -# import pandas as pd -# import matplotlib -# -# matplotlib.use("Agg") # Non-interactive backend -# -# # Add src to path if needed -# src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../../")) -# if src_path not in sys.path: -# sys.path.insert(0, src_path) -# -# import scitex -# -# # Create output directory -# OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "formatter_test_output") -# os.makedirs(OUTPUT_DIR, exist_ok=True) -# -# -# def test_all_formatters(): -# """ -# Test all formatters by creating actual plots and saving both image and CSV files. -# Each function will create a different type of plot, save it, and verify the CSV export. -# """ -# # Test each formatter with a real plot -# test_plot_kde() -# test_plot_image() -# test_plot_shaded_line() -# test_plot_scatter_hist() -# test_plot_violin() -# test_plot_heatmap() -# test_plot_ecdf() -# test_multiple_plots() -# -# -# def test_plot_kde(): -# """Test KDE plotting and CSV export.""" -# print("Testing stx_kde...") -# -# # Create figure -# fig, ax = scitex.plt.subplots() -# -# # Generate data -# np.random.seed(42) # For reproducibility -# data = np.concatenate([np.random.normal(0, 1, 500), np.random.normal(5, 1, 300)]) -# -# # Plot with ID for tracking -# ax.stx_kde(data, label="Bimodal Distribution", id="kde_test") -# -# # Style the plot -# ax.set_xyt("Value", "Density", "KDE Test") -# ax.legend() -# -# # Save both image and data -# save_path = os.path.join(OUTPUT_DIR, "kde_test.png") -# scitex.io.save(fig, save_path) -# -# # Verify CSV was created -# csv_path = save_path.replace(".png", ".csv") -# assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" -# -# # Read CSV and verify contents -# df = pd.read_csv(csv_path) -# assert "kde_test_kde_x" in df.columns, "Expected column 'kde_test_kde_x' not found" -# assert "kde_test_kde_density" in df.columns, ( -# "Expected column 'kde_test_kde_density' not found" -# ) -# -# # Close figure -# scitex.plt.close(fig) -# print("✓ stx_kde test successful") -# -# -# def test_plot_image(): -# """Test image plotting and CSV export.""" -# print("Testing stx_image...") -# -# # Create figure -# fig, ax = scitex.plt.subplots() -# -# # Generate data -# np.random.seed(42) # For reproducibility -# data = np.random.rand(20, 20) -# -# # Plot with ID for tracking -# ax.stx_image(data, cmap="viridis", id="image_test") -# -# # Style the plot -# ax.set_xyt("X", "Y", "Image Test") -# -# # Save both image and data -# save_path = os.path.join(OUTPUT_DIR, "image_test.png") -# scitex.io.save(fig, save_path) -# -# # Verify CSV was created -# csv_path = save_path.replace(".png", ".csv") -# assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" -# -# # Read CSV and verify contents -# df = pd.read_csv(csv_path) -# # The formatter should have converted the 2D array to a DataFrame -# assert not df.empty, "CSV file is empty" -# -# # Close figure -# scitex.plt.close(fig) -# print("✓ stx_image test successful") -# -# -# def test_plot_shaded_line(): -# """Test shaded line plotting and CSV export.""" -# print("Testing stx_shaded_line...") -# -# # Create figure -# fig, ax = scitex.plt.subplots() -# -# # Generate data -# np.random.seed(42) # For reproducibility -# x = np.linspace(0, 10, 100) -# y_middle = np.sin(x) -# y_lower = y_middle - 0.2 -# y_upper = y_middle + 0.2 -# -# # Plot with ID for tracking -# ax.stx_shaded_line( -# x, y_lower, y_middle, y_upper, label="Sine with error", id="shaded_line_test" -# ) -# -# # Style the plot -# ax.set_xyt("X", "Y", "Shaded Line Test") -# ax.legend() -# -# # Save both image and data -# save_path = os.path.join(OUTPUT_DIR, "shaded_line_test.png") -# scitex.io.save(fig, save_path) -# -# # Verify CSV was created -# csv_path = save_path.replace(".png", ".csv") -# assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" -# -# # Read CSV and verify contents -# df = pd.read_csv(csv_path) -# assert not df.empty, "CSV file is empty" -# -# # Close figure -# scitex.plt.close(fig) -# print("✓ stx_shaded_line test successful") -# -# -# def test_plot_scatter_hist(): -# """Test scatter histogram plotting and CSV export.""" -# print("Testing stx_scatter_hist...") -# -# # Create figure -# fig, ax = scitex.plt.subplots(figsize=(8, 8)) -# -# # Generate data -# np.random.seed(42) # For reproducibility -# x = np.random.normal(0, 1, 500) -# y = x + np.random.normal(0, 0.5, 500) -# -# # Plot with ID for tracking -# ax.stx_scatter_hist(x, y, hist_bins=30, scatter_alpha=0.7, id="scatter_hist_test") -# -# # Style the plot -# ax.set_xyt("X Values", "Y Values", "Scatter Histogram Test") -# -# # Save both image and data -# save_path = os.path.join(OUTPUT_DIR, "scatter_hist_test.png") -# scitex.io.save(fig, save_path) -# -# # Verify CSV was created -# csv_path = save_path.replace(".png", ".csv") -# assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" -# -# # Read CSV and verify contents -# df = pd.read_csv(csv_path) -# assert not df.empty, "CSV file is empty" -# -# # Close figure -# scitex.plt.close(fig) -# print("✓ stx_scatter_hist test successful") -# -# -# def test_plot_violin(): -# """Test violin plotting and CSV export.""" -# print("Testing stx_violin...") -# -# # Create figure -# fig, ax = scitex.plt.subplots() -# -# # Generate data -# np.random.seed(42) # For reproducibility -# data = [ -# np.random.normal(0, 1, 100), -# np.random.normal(2, 1.5, 100), -# np.random.normal(5, 0.8, 100), -# ] -# labels = ["Group A", "Group B", "Group C"] -# -# # Plot with ID for tracking -# ax.stx_violin( -# data, labels=labels, colors=["red", "blue", "green"], id="violin_test" -# ) -# -# # Style the plot -# ax.set_xyt("Groups", "Values", "Violin Plot Test") -# -# # Save both image and data -# save_path = os.path.join(OUTPUT_DIR, "violin_test.png") -# scitex.io.save(fig, save_path) -# -# # Verify CSV was created -# csv_path = save_path.replace(".png", ".csv") -# assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" -# -# # Read CSV and verify contents -# df = pd.read_csv(csv_path) -# assert not df.empty, "CSV file is empty" -# -# # Close figure -# scitex.plt.close(fig) -# print("✓ stx_violin test successful") -# -# -# def test_plot_heatmap(): -# """Test heatmap plotting and CSV export.""" -# print("Testing stx_heatmap...") -# -# # Create figure -# fig, ax = scitex.plt.subplots() -# -# # Generate data -# np.random.seed(42) # For reproducibility -# data = np.random.rand(5, 10) -# x_labels = [f"X{ii + 1}" for ii in range(5)] -# y_labels = [f"Y{ii + 1}" for ii in range(10)] -# -# # Plot with ID for tracking -# ax.stx_heatmap( -# data, -# x_labels=x_labels, -# y_labels=y_labels, -# cbar_label="Values", -# show_annot=True, -# value_format="{x:.2f}", -# cmap="viridis", -# id="heatmap_test", -# ) -# -# # Style the plot -# ax.set_title("Heatmap Test") -# -# # Save both image and data -# save_path = os.path.join(OUTPUT_DIR, "heatmap_test.png") -# scitex.io.save(fig, save_path) -# -# # Verify CSV was created -# csv_path = save_path.replace(".png", ".csv") -# assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" -# -# # Read CSV and verify contents -# df = pd.read_csv(csv_path) -# assert not df.empty, "CSV file is empty" -# -# # Close figure -# scitex.plt.close(fig) -# print("✓ stx_heatmap test successful") -# -# -# def test_plot_ecdf(): -# """Test ECDF plotting and CSV export.""" -# print("Testing stx_ecdf...") -# -# # Create figure -# fig, ax = scitex.plt.subplots() -# -# # Generate data -# np.random.seed(42) # For reproducibility -# data = np.random.normal(0, 1, 1000) -# -# # Plot with ID for tracking -# ax.stx_ecdf(data, label="Normal Distribution", id="ecdf_test") -# -# # Style the plot -# ax.set_xyt("Value", "Cumulative Probability", "ECDF Test") -# ax.legend() -# -# # Save both image and data -# save_path = os.path.join(OUTPUT_DIR, "ecdf_test.png") -# scitex.io.save(fig, save_path) -# -# # Verify CSV was created -# csv_path = save_path.replace(".png", ".csv") -# assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" -# -# # Read CSV and verify contents -# df = pd.read_csv(csv_path) -# assert not df.empty, "CSV file is empty" -# -# # Close figure -# scitex.plt.close(fig) -# print("✓ stx_ecdf test successful") -# -# -# def test_multiple_plots(): -# """Test multiple plots on the same axis.""" -# print("Testing multiple plots on the same axis...") -# -# # Create figure -# fig, ax = scitex.plt.subplots() -# -# # Generate data -# np.random.seed(42) # For reproducibility -# x = np.linspace(0, 10, 100) -# y1 = np.sin(x) -# y2 = np.cos(x) -# -# # Create multiple plots with different IDs -# ax.stx_line(y1, label="Sine", id="multi_test_sine") -# ax.stx_line(y2, label="Cosine", id="multi_test_cosine") -# -# # Style the plot -# ax.set_xyt("X", "Y", "Multiple Plots Test") -# ax.legend() -# -# # Save both image and data -# save_path = os.path.join(OUTPUT_DIR, "multiple_plots_test.png") -# scitex.io.save(fig, save_path) -# -# # Verify CSV was created -# csv_path = save_path.replace(".png", ".csv") -# assert os.path.exists(csv_path), f"CSV file not created: {csv_path}" -# -# # Read CSV and verify contents -# df = pd.read_csv(csv_path) -# assert not df.empty, "CSV file is empty" -# -# # Check that both plots are in the CSV -# sine_cols = [col for col in df.columns if col.startswith("multi_test_sine")] -# cosine_cols = [col for col in df.columns if col.startswith("multi_test_cosine")] -# assert len(sine_cols) > 0, "Sine plot data not found in CSV" -# assert len(cosine_cols) > 0, "Cosine plot data not found in CSV" -# -# # Close figure -# scitex.plt.close(fig) -# print("✓ Multiple plots test successful") -# -# -# if __name__ == "__main__": -# print("Starting formatter verification tests...") -# test_all_formatters() -# print("\nAll formatter tests completed successfully!") -# print(f"Output files are in: {OUTPUT_DIR}") - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv_formatters/verify_formatters.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/test__AxesWrapper.py b/tests/scitex/plt/_subplots/test__AxesWrapper.py deleted file mode 100644 index f8c34b99c..000000000 --- a/tests/scitex/plt/_subplots/test__AxesWrapper.py +++ /dev/null @@ -1,331 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-03 12:35:16 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/_subplots/test__AxesWrapper.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/_subplots/test__AxesWrapper.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pytest - -# class TestAxesWrapper: -# def setup_method(self): -# self.fig_mock = MagicMock() -# self.ax1 = MagicMock() -# self.ax2 = MagicMock() -# self.axes_array = np.array([[self.ax1, self.ax2]]) -# self.wrapper = AxesWrapper(self.fig_mock, self.axes_array) - -# def test_init(self): -# assert self.wrapper.fig is self.fig_mock -# assert np.array_equal(self.wrapper.axes, self.axes_array) - -# def test_get_figure(self): -# assert self.wrapper.get_figure() is self.fig_mock - -# def test_getattr_existing_method(self): -# # Set up a method on both axes -# self.ax1.set_title = MagicMock() -# self.ax2.set_title = MagicMock() - -# # Call the method on the wrapper -# self.wrapper.set_title("Test Title") - -# # Check that it was called on both axes -# self.ax1.set_title.assert_called_once_with("Test Title") -# self.ax2.set_title.assert_called_once_with("Test Title") - -# def test_getattr_property(self): -# # Set up a property on both axes -# type(self.ax1).figbox = PropertyMock(return_value="figbox1") -# type(self.ax2).figbox = PropertyMock(return_value="figbox2") - -# # Get the property from the wrapper -# result = self.wrapper.figbox - -# # Should return a list of the property values -# assert result == ["figbox1", "figbox2"] - -# def test_getattr_warning(self): -# # Test attempting to access a non-existent attribute -# with pytest.warns(UserWarning, match="not implemented, ignored"): -# result = self.wrapper.nonexistent_method() -# assert result is None - -# def test_getitem(self): -# # Test accessing by index -# result = self.wrapper[0, 0] -# assert result is self.ax1 - -# # Test slice returning AxesWrapper -# result = self.wrapper[0, :] -# assert isinstance(result, AxesWrapper) -# assert np.array_equal(result.axes, np.array([self.ax1, self.ax2])) - -# def test_iteration(self): -# # Test iteration through axes -# axes = list(self.wrapper) -# assert axes == [self.ax1, self.ax2] - -# def test_len(self): -# # Test len() returns number of axes -# assert len(self.wrapper) == 2 - -# def test_legend(self): -# # Test legend method -# self.wrapper.legend(loc="upper left") - -# # Should call legend on both axes -# self.ax1.legend.assert_called_once_with(loc="upper left") -# self.ax2.legend.assert_called_once_with(loc="upper left") - -# def test_history_property(self): -# # Set up history on both axes -# self.ax1.history = {"plot1": "data1"} -# self.ax2.history = {"plot2": "data2"} - -# # Get history from wrapper -# result = self.wrapper.history - -# # Should return list of histories -# assert result == [{"plot1": "data1"}, {"plot2": "data2"}] - -# def test_shape_property(self): -# # Test shape property -# assert self.wrapper.shape == (1, 2) - -# def test_export_as_csv(self): -# # Set up export_as_csv on both axes -# self.ax1.export_as_csv = MagicMock( -# return_value=pd.DataFrame({"x1": [1, 2, 3], "y1": [4, 5, 6]}) -# ) -# self.ax2.export_as_csv = MagicMock( -# return_value=pd.DataFrame({"x2": [7, 8, 9], "y2": [10, 11, 12]}) -# ) - -# # Call export_as_csv on wrapper -# result = self.wrapper.export_as_csv() - -# # Check the result -# assert isinstance(result, pd.DataFrame) -# assert list(result.columns) == [ -# "ax_00_x1", -# "ax_00_y1", -# "ax_01_x2", -# "ax_01_y2", -# ] - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxesWrapper.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-19 15:36:54 (ywatanabe)" -# # File: /ssh:ywatanabe@sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/_subplots/_AxesWrapper.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# from functools import wraps -# -# import numpy as np -# import pandas as pd -# -# from scitex import logging -# -# logger = logging.getLogger(__name__) -# -# -# class AxesWrapper: -# def __init__(self, fig_scitex, axes_scitex): -# self._fig_scitex = fig_scitex -# self._axes_scitex = axes_scitex -# -# def get_figure(self, root=True): -# """Get the figure, compatible with matplotlib 3.8+""" -# return self._fig_scitex -# -# def __dir__(self): -# # Combine attributes from both self and the wrapped matplotlib axes -# attrs = set(dir(self.__class__)) -# attrs.update(object.__dir__(self)) -# -# # Add attributes from the axes objects if available -# if hasattr(self, "_axes_scitex") and self._axes_scitex is not None: -# # Get attributes from the first axis if there are any -# if self._axes_scitex.size > 0: -# first_ax = self._axes_scitex.flat[0] -# attrs.update(dir(first_ax)) -# -# return sorted(attrs) -# -# def __getattr__(self, name): -# # Note that self._axes_scitex is "numpy.ndarray" -# # print(f"Attribute of AxesWrapper: {name}") -# methods = [] -# try: -# for axis in self._axes_scitex.flat: -# methods.append(getattr(axis, name)) -# except Exception: -# methods = [] -# -# if methods and all(callable(m) for m in methods): -# -# @wraps(methods[0]) -# def wrapper(*args, **kwargs): -# return [ -# getattr(ax, name)(*args, **kwargs) for ax in self._axes_scitex.flat -# ] -# -# return wrapper -# -# if methods and not callable(methods[0]): -# return methods -# -# def dummy(*args, **kwargs): -# return None -# -# return dummy -# -# # def __getitem__(self, index): -# # subset = self._axes_scitex[index] -# # if isinstance(index, slice): -# # return AxesWrapper(self._fig_scitex, subset) -# # return subset -# -# def __getitem__(self, index): -# # Handle 1D-like arrays (single row or single column) -# # For (1, n) shape with integer index, return the element from the row -# # For (n, 1) shape with integer index, return the element from the column -# if isinstance(index, int): -# shape = self._axes_scitex.shape -# if len(shape) == 2: -# if shape[0] == 1: -# # Single row case: axes[i] should return axes[0, i] -# return self._axes_scitex[0, index] -# elif shape[1] == 1: -# # Single column case: axes[i] should return axes[i, 0] -# return self._axes_scitex[index, 0] -# -# subset = self._axes_scitex[index] -# if isinstance(subset, np.ndarray): -# return AxesWrapper(self._fig_scitex, subset) -# return subset -# -# def __setitem__(self, index, value): -# """Support item assignment for axes[row, col] = new_axis operations.""" -# self._axes_scitex[index] = value -# -# def __iter__(self): -# # Iterate over flattened axes for backward compatibility -# return iter(self._axes_scitex.flat) -# -# def __len__(self): -# return self._axes_scitex.size -# -# def __array__(self): -# """Support conversion to numpy array. -# -# This allows using np.array(axes) on an AxesWrapper instance, returning -# a NumPy array with the same shape as the original axes array. -# -# Notes: -# - While this enables compatibility with NumPy functions, not all -# operations will work correctly due to the nature of the wrapped -# objects. -# - For flattening operations, use the dedicated `flatten()` method -# instead of `np.array(axes).flatten()`: -# -# # RECOMMENDED: -# flat_axes = list(axes.flatten()) -# -# # AVOID (may cause "invalid __array_struct__" error): -# flat_axes = np.array(axes).flatten() -# -# Returns: -# np.ndarray: Array of wrapped axes with the same shape -# """ -# # Show a warning to help users avoid common mistakes -# logger.warning( -# "Converting AxesWrapper to numpy array. If you're trying to flatten " -# "the axes, use 'list(axes.flatten())' instead of 'np.array(axes).flatten()'." -# ) -# -# # Convert the underlying axes to a compatible numpy array representation -# flat_axes = [ax for ax in self._axes_scitex.flat] -# array_compatible = np.empty(len(flat_axes), dtype=object) -# for idx, ax in enumerate(flat_axes): -# array_compatible[idx] = ax -# return array_compatible.reshape(self._axes_scitex.shape) -# -# def legend(self, loc="best"): -# """Add legend to all axes with 'best' automatic placement by default.""" -# return [ax.legend(loc=loc) for ax in self._axes_scitex.flat] -# -# @property -# def history(self): -# return [ax.history for ax in self._axes_scitex.flat] -# -# @property -# def shape(self): -# return self._axes_scitex.shape -# -# @property -# def flat(self): -# """Return a flat iterator over all axes. -# -# This property provides direct access to the flattened axes array, -# matching numpy array behavior. -# -# Returns: -# Iterator over all axes in row-major (C-style) order -# """ -# return self._axes_scitex.flat -# -# def flatten(self): -# """Return a flattened array of all axes in the AxesWrapper. -# -# This method collects all axes from the flat iterator and returns them -# as a NumPy array. This ensures compatibility with code that expects -# a flat collection of axes. -# -# Returns: -# np.ndarray: A flattened array containing all axes -# -# Example: -# # Preferred way to get a list of all axes: -# axes_list = list(axes.flatten()) -# -# # Alternatively, if you need a NumPy array: -# axes_array = axes.flatten() -# """ -# return np.array([ax for ax in self._axes_scitex.flat]) -# -# def export_as_csv(self): -# dfs = [] -# for ii, ax in enumerate(self._axes_scitex.flat): -# df = ax.export_as_csv() -# # Column names already include axis position via get_csv_column_name -# # No need to add extra prefix -# dfs.append(df) -# return pd.concat(dfs, axis=1) if dfs else pd.DataFrame() -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxesWrapper.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/test__AxisWrapper.py b/tests/scitex/plt/_subplots/test__AxisWrapper.py deleted file mode 100644 index 65b911113..000000000 --- a/tests/scitex/plt/_subplots/test__AxisWrapper.py +++ /dev/null @@ -1,443 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-03 12:35:20 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/_subplots/test__AxisWrapper.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/_subplots/test__AxisWrapper.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pytest - -# class TestAxisWrapper: -# def setup_method(self): -# self.fig_mock = MagicMock() -# self.axis_mock = MagicMock() -# self.wrapper = AxisWrapper(self.fig_mock, self.axis_mock, track=True) - -# def test_init(self): -# assert self.wrapper.fig is self.fig_mock -# assert self.wrapper.axis is self.axis_mock -# assert self.wrapper._ax_history == {} -# assert self.wrapper.track is True -# assert self.wrapper.id == 0 - -# def test_get_figure(self): -# assert self.wrapper.get_figure() is self.fig_mock - -# def test_getattr_existing_attribute(self): -# # Test accessing an existing attribute on the axis -# self.axis_mock.get_xlim = lambda: (0, 1) -# assert self.wrapper.get_xlim() == (0, 1) - -# def test_getattr_warning(self): -# # Test attempting to access a non-existent attribute -# with pytest.warns(UserWarning, match="not implemented, ignored"): -# result = self.wrapper.nonexistent_method() -# assert result is None - -# def test_function_with_id_parameter(self): -# # Test that id parameter is handled correctly -# self.axis_mock.plot = MagicMock(return_value="plot_result") - -# # Call plot with id -# result = self.wrapper.plot([1, 2, 3], [4, 5, 6], id="test_plot") - -# # Check that plot was called without the id parameter -# self.axis_mock.plot.assert_called_once() -# args, kwargs = self.axis_mock.plot.call_args -# assert "id" not in kwargs - -# # And the result should be what the original method returned -# assert result == "plot_result" - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapper.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 10:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapper.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import warnings -# from functools import wraps -# -# import matplotlib -# -# from scitex import logging -# -# logger = logging.getLogger(__name__) -# -# from ._AxisWrapperMixins import ( -# AdjustmentMixin, -# MatplotlibPlotMixin, -# RawMatplotlibMixin, -# SeabornMixin, -# TrackingMixin, -# UnitAwareMixin, -# ) -# from scitex.plt.styles import apply_plot_defaults, apply_plot_postprocess -# -# -# class AxisWrapper( -# MatplotlibPlotMixin, -# SeabornMixin, -# RawMatplotlibMixin, -# AdjustmentMixin, -# TrackingMixin, -# UnitAwareMixin, -# ): -# def __init__(self, fig_scitex, axis_mpl, track): -# """Initialize the AxisWrapper. -# -# Args: -# fig_scitex: Parent figure wrapper -# axis_mpl: Matplotlib axis to wrap -# track: Whether to track plotting operations -# """ -# self._fig_mpl = fig_scitex._fig_mpl -# # Axis Properties -# # self.axis = axis_mpl -# # self._axis = axis_mpl -# # self._axis_scitex = self -# self._axis_mpl = axis_mpl -# -# # Axes Properties -# # self.axes = axis_mpl -# # self._axes = axis_mpl -# self._axes_mpl = axis_mpl -# # self._axes_scitex = self -# -# # Tracking properties -# self._ax_history = {} -# self._method_counters = {} # Track method counts for auto-generated IDs -# self.track = track -# self.id = 0 -# self._counter_part = matplotlib.axes.Axes -# self._tracking_depth = 0 # Depth counter to prevent tracking internal calls -# -# # Initialize unit awareness -# UnitAwareMixin.__init__(self) -# -# def get_figure(self, root=True): -# """Get the figure, compatible with matplotlib 3.8+""" -# return self._fig_mpl -# -# def twinx(self): -# """Create a twin y-axis and wrap it with AxisWrapper.""" -# twin_ax = self._axes_mpl.twinx() -# -# # Create a mock figure wrapper for the twin axis -# class MockFigWrapper: -# def __init__(self, fig_mpl): -# self._fig_mpl = fig_mpl -# -# mock_fig = MockFigWrapper(self._fig_mpl) -# return AxisWrapper(fig_scitex=mock_fig, axis_mpl=twin_ax, track=self.track) -# -# def twiny(self): -# """Create a twin x-axis and wrap it with AxisWrapper.""" -# twin_ax = self._axes_mpl.twiny() -# -# # Create a mock figure wrapper for the twin axis -# class MockFigWrapper: -# def __init__(self, fig_mpl): -# self._fig_mpl = fig_mpl -# -# mock_fig = MockFigWrapper(self._fig_mpl) -# return AxisWrapper(fig_scitex=mock_fig, axis_mpl=twin_ax, track=self.track) -# -# def __getattr__(self, name): -# # 0. Check if the attribute is explicitly defined in AxisWrapper or its Mixins -# # This check happens implicitly before __getattr__ is called. -# # If a method like `plot` is defined in BasicPlotMixin, it will be found first. -# -# # print(f"Attribute of AxisWrapper: {name}") -# -# # 1. Try to get the attribute from the wrapped axes instance -# if hasattr(self._axes_mpl, name): -# orig_attr = getattr(self._axes_mpl, name) -# -# if callable(orig_attr): -# -# @wraps(orig_attr) -# def wrapper(*args, __method_name__=name, **kwargs): -# id_value = kwargs.pop("id", None) -# track_override = kwargs.pop("track", None) -# -# # Increment tracking depth to detect internal calls -# # Internal calls (depth > 1) won't be tracked -# self._tracking_depth += 1 -# is_top_level_call = self._tracking_depth == 1 -# -# try: -# # Apply pre-processing defaults from styles module -# apply_plot_defaults( -# __method_name__, kwargs, id_value, self._axes_mpl -# ) -# -# # Pop scitex-specific kwargs before calling matplotlib -# # These are handled in post-processing -# scitex_kwargs = {} -# if __method_name__ == "violinplot": -# scitex_kwargs["boxplot"] = kwargs.pop("boxplot", True) -# -# # Call the original matplotlib method -# result = orig_attr(*args, **kwargs) -# -# # Store the scitex id on the result for later retrieval -# # This is used by _collect_figure_metadata to map traces to CSV columns -# if id_value is not None: -# if isinstance(result, list): -# # plot() returns list of Line2D objects -# for item in result: -# item._scitex_id = id_value -# elif hasattr(result, "__iter__") and not isinstance( -# result, str -# ): -# # Other containers (e.g., bar containers) -# try: -# for item in result: -# item._scitex_id = id_value -# except (TypeError, AttributeError): -# pass -# else: -# # Single object -# try: -# result._scitex_id = id_value -# except AttributeError: -# pass -# -# # Restore scitex kwargs for post-processing -# kwargs.update(scitex_kwargs) -# -# # Apply post-processing styling from styles module -# apply_plot_postprocess( -# __method_name__, result, self._axes_mpl, kwargs, args -# ) -# -# # Determine if tracking should occur -# # Only track top-level calls (depth == 1), not internal matplotlib calls -# should_track = ( -# track_override if track_override is not None else self.track -# ) and is_top_level_call -# -# # Track the method call if tracking enabled -# # Expanded list of matplotlib plotting methods to track -# tracking_methods = { -# # Basic plots -# "plot", -# "scatter", -# "bar", -# "barh", -# "hist", -# "boxplot", -# "violinplot", -# # Line plots -# "fill_between", -# "fill_betweenx", -# "errorbar", -# "step", -# "stem", -# # Fill and area plots -# "fill", -# "stackplot", -# # Statistical plots -# "hist2d", -# "hexbin", -# "pie", -# "eventplot", -# # Contour plots -# "contour", -# "contourf", -# "tricontour", -# "tricontourf", -# # Image plots -# "imshow", -# "matshow", -# "spy", -# "pcolormesh", -# "pcolor", -# # Quiver plots -# "quiver", -# "streamplot", -# # 3D-related (if axes3d) -# "plot3D", -# "scatter3D", -# "bar3d", -# "plot_surface", -# "plot_wireframe", -# # Text and annotations (data-containing) -# "annotate", -# "text", -# } -# if should_track and __method_name__ in tracking_methods: -# # Use the _track method from TrackingMixin -# # If no id provided, it will auto-generate one -# try: -# # Convert args to tracked_dict for consistency with other tracking -# tracked_dict = {"args": args} -# self._track( -# should_track, -# id_value, -# __method_name__, -# tracked_dict, -# kwargs, -# ) -# except AttributeError: -# logger.warning( -# f"Tracking setup incomplete for AxisWrapper ({__method_name__})." -# ) -# except Exception as e: -# # Silently continue if tracking fails to not break plotting -# pass -# return result # Return the result of the original call -# finally: -# # Always decrement depth, even if exception occurs -# self._tracking_depth -= 1 -# -# return wrapper -# else: -# # If it's a non-callable attribute (property, etc.), return it directly -# return orig_attr -# -# # 2. If not found on instance, try the counterpart type (fallback) -# if hasattr(self._counter_part, name): -# counterpart_attr = getattr(self._counter_part, name) -# logger.warning( -# f"SciTeX Axis_MplWrapper: '{name}' not directly handled. " -# f"Falling back to underlying '{self._counter_part.__name__}' attribute." -# ) -# # If the counterpart attribute is callable (likely a method descriptor) -# if callable(counterpart_attr): -# # Return a new function that calls the counterpart method on self._axes_mpl -# @wraps(counterpart_attr) -# def fallback_method(*args, **kwargs): -# # Note: No id/track handling for fallback methods -# return counterpart_attr(self._axes_mpl, *args, **kwargs) -# -# return fallback_method -# else: -# # Non-callable class attribute. Attempt to get from instance again, -# # otherwise return the class attribute/descriptor. -# try: -# return getattr(self._axes_mpl, name) -# except AttributeError: -# return counterpart_attr -# -# # 3. If not found anywhere, raise AttributeError -# raise AttributeError( -# f"'{type(self).__name__}' object and its underlying '{self._counter_part.__name__}' " -# f"have no attribute '{name}'" -# ) -# -# def __dir__(self): -# # Start with attributes from the class and all parent classes (mixins) -# attrs = set() -# -# # Get attributes from all parent classes including mixins -# for cls in self.__class__.__mro__: -# attrs.update(cls.__dict__.keys()) -# -# # Add instance attributes -# attrs.update(self.__dict__.keys()) -# -# # Safely get matplotlib axes attributes -# try: -# # Get attributes from the wrapped matplotlib axes -# if hasattr(self._axes_mpl, "__class__"): -# # Get class methods from matplotlib.axes.Axes -# for cls in self._axes_mpl.__class__.__mro__: -# attrs.update( -# name for name in cls.__dict__.keys() if not name.startswith("_") -# ) -# -# # Add instance attributes of the matplotlib axes -# if hasattr(self._axes_mpl, "__dict__"): -# attrs.update( -# name -# for name in self._axes_mpl.__dict__.keys() -# if not name.startswith("_") -# ) -# -# except Exception: -# # If any error occurs, add common matplotlib methods manually -# attrs.update( -# [ -# "plot", -# "scatter", -# "bar", -# "barh", -# "hist", -# "boxplot", -# "set_xlabel", -# "set_ylabel", -# "set_title", -# "legend", -# "set_xlim", -# "set_ylim", -# "grid", -# "annotate", -# "text", -# ] -# ) -# -# # Remove private attributes -# attrs = {attr for attr in attrs if not attr.startswith("_")} -# -# return sorted(attrs) -# -# def flatten(self): -# """Return a list containing just this axis. -# -# This method makes AxisWrapper compatible with code that calls flatten() -# on an axes collection. It returns a list containing just this single axis -# to maintain consistency with AxesWrapper.flatten(). -# -# Returns: -# list: A list containing this axis wrapper -# -# Example: -# # When working with either AxesWrapper or AxisWrapper, this works: -# axes_list = list(axes.flatten()) -# """ -# return [self] -# -# -# """ -# import matplotlib.pyplot as plt -# import scitex.plt as mplt -# -# fig_scitex, axes = plt.subplots(ncols=2) -# mfig_scitex, maxes = mplt.subplots(ncols=2) -# -# print(set(dir(mfig_scitex)) - set(dir(fig_scitex))) -# print(set(dir(maxes)) - set(dir(axes))) -# -# is_compatible = np.all([kk in set(dir(msubplots)) for kk in set(dir(counter_part))]) -# if is_compatible: -# print(f"{msubplots.__name__} is compatible with {counter_part.__name__}") -# else: -# print(f"{msubplots.__name__} is incompatible with {counter_part.__name__}") -# """ -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_AxisWrapper.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/test__FigWrapper.py b/tests/scitex/plt/_subplots/test__FigWrapper.py deleted file mode 100644 index 424a2c190..000000000 --- a/tests/scitex/plt/_subplots/test__FigWrapper.py +++ /dev/null @@ -1,604 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-03 12:35:50 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/_subplots/test__FigWrapper.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/_subplots/test__FigWrapper.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import pytest - -# class TestFigWrapper: -# def setup_method(self): -# self.fig = plt.figure() -# self.wrapper = FigWrapper(self.fig) - -# def test_init(self): -# assert self.wrapper.fig is self.fig -# assert hasattr(self.wrapper, "axes") -# assert self.wrapper.axes == [] - -# def test_getattr_existing_attribute(self): -# # Test accessing an existing attribute on the figure -# assert hasattr(self.wrapper, "figsize") - -# def test_getattr_existing_method(self): -# # Test accessing an existing method on the figure -# assert callable(self.wrapper.add_subplot) - -# def test_getattr_warning(self): -# # Test attempting to access a non-existent attribute -# with pytest.warns(UserWarning, match="not implemented, ignored"): -# result = self.wrapper.nonexistent_method() -# assert result is None - -# def test_legend(self): -# # Create mock axes -# ax1 = MagicMock() -# ax2 = MagicMock() -# self.wrapper.axes = MagicMock() -# self.wrapper.axes.__iter__ = lambda _: iter([ax1, ax2]) - -# # Call legend -# self.wrapper.legend(loc="upper right") - -# # Check that legend was called on each axis -# ax1.legend.assert_called_once_with(loc="upper right") -# ax2.legend.assert_called_once_with(loc="upper right") - -# def test_export_as_csv_with_empty_axes(self): -# # Test with no axes -# self.wrapper.axes = MagicMock() -# self.wrapper.axes.flat = [] - -# result = self.wrapper.export_as_csv() -# assert isinstance(result, pd.DataFrame) -# assert result.empty - -# def test_export_as_csv_with_data(self): -# # Create mock axes with sigma data -# ax1 = MagicMock() -# ax1.export_as_csv.return_value = pd.DataFrame( -# {"x": [1, 2, 3], "y": [4, 5, 6]} -# ) - -# self.wrapper.axes = MagicMock() -# self.wrapper.axes.flat = [ax1] - -# result = self.wrapper.export_as_csv() -# assert isinstance(result, pd.DataFrame) -# assert not result.empty -# assert "ax_00_x" in result.columns -# assert "ax_00_y" in result.columns - -# def test_supxyt(self): -# # Test supxyt method -# self.wrapper.fig = MagicMock() - -# # Call with x and y labels -# self.wrapper.supxyt(x="X Label", y="Y Label") - -# # Check that appropriate methods were called -# self.wrapper.fig.supxlabel.assert_called_once_with("X Label") -# self.wrapper.fig.supylabel.assert_called_once_with("Y Label") -# self.wrapper.fig.suptitle.assert_not_called() - -# # Reset and test with title -# self.wrapper.fig.reset_mock() -# self.wrapper.supxyt(t="Title") - -# self.wrapper.fig.supxlabel.assert_not_called() -# self.wrapper.fig.supylabel.assert_not_called() -# self.wrapper.fig.suptitle.assert_called_once_with("Title") - -# def test_tight_layout(self): -# # Test tight_layout method -# self.wrapper.fig = MagicMock() - -# # Call with default rect -# self.wrapper.tight_layout() - -# # Check that tight_layout was called with the correct rect -# self.wrapper.fig.tight_layout.assert_called_once_with( -# rect=[0, 0.03, 1, 0.95] -# ) - -# # Reset and test with custom rect -# self.wrapper.fig.reset_mock() -# custom_rect = [0.1, 0.1, 0.9, 0.9] -# self.wrapper.tight_layout(rect=custom_rect) - -# self.wrapper.fig.tight_layout.assert_called_once_with(rect=custom_rect) - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_FigWrapper.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # Timestamp: "2025-05-19 02:53:28 (ywatanabe)" -# # File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/src/scitex/plt/_subplots/_FigWrapper.py.new -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/_subplots/_FigWrapper.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import warnings -# from functools import wraps -# -# import pandas as pd -# -# from scitex import logging -# -# logger = logging.getLogger(__name__) -# -# -# class FigWrapper: -# def __init__(self, fig_mpl): -# self._fig_mpl = fig_mpl -# self._axes = [] # Keep track of axes for synchronization -# self._last_saved_info = None -# self._not_saved_yet_flag = True -# self._called_from_mng_io_save = False -# -# @property -# def figure( -# self, -# ): -# return self._fig_mpl -# -# def __getattr__(self, attr): -# # print(f"Attribute of FigWrapper: {attr}") -# attr_mpl = getattr(self._fig_mpl, attr) -# -# if callable(attr_mpl): -# -# @wraps(attr_mpl) -# def wrapper(*args, track=None, id=None, **kwargs): -# # Suppress constrained_layout warnings for certain operations -# import warnings -# -# with warnings.catch_warnings(): -# if attr in ["subplots_adjust", "tight_layout"]: -# warnings.filterwarnings( -# "ignore", -# message=".*constrained_layout.*", -# category=UserWarning, -# ) -# warnings.filterwarnings( -# "ignore", -# message=".*layout engine.*incompatible.*", -# category=UserWarning, -# ) -# results = attr_mpl(*args, **kwargs) -# # self._track(track, id, attr, args, kwargs) -# return results -# -# return wrapper -# -# else: -# return attr_mpl -# -# def __dir__(self): -# # Combine attributes from both self and the wrapped matplotlib figure -# attrs = set(dir(self.__class__)) -# attrs.update(object.__dir__(self)) -# attrs.update(dir(self._fig_mpl)) -# return sorted(attrs) -# -# def savefig(self, fname, *args, embed_metadata=True, metadata=None, **kwargs): -# """ -# Save figure with automatic metadata embedding. -# -# Parameters -# ---------- -# fname : str -# Output file path -# embed_metadata : bool, optional -# Automatically embed dimension/style metadata in PNG/JPEG/TIFF/PDF (default: True) -# metadata : dict, optional -# Additional custom metadata to merge with auto-collected metadata -# *args, **kwargs -# Passed to scitex.io.save_image or matplotlib savefig -# -# Notes -# ----- -# For PNG/JPEG/TIFF/PDF formats, metadata is automatically embedded including: -# - Software versions (scitex, matplotlib) -# - Timestamp -# - Figure/axes dimensions (mm, inch, px) -# - DPI settings -# - Styling parameters (if available via _scitex_metadata) -# - Mode (display/publication) -# -# For other formats (SVG, etc.), delegates to matplotlib's savefig. -# -# When facecolor is specified (and is not 'none'), axes with transparent -# patches will temporarily have their alpha set to 1.0 to ensure the -# facecolor is visible. -# -# Examples -# -------- -# >>> fig, ax = splt.subplots(fig_mm={'width': 35, 'height': 24.5}) -# >>> ax.plot(x, y) -# >>> fig.savefig('result.png', dpi=300) # Metadata embedded automatically! -# -# >>> # Add custom metadata -# >>> fig.savefig('result.png', dpi=300, metadata={'experiment': 'test_001'}) -# -# >>> # Disable metadata embedding -# >>> fig.savefig('result.png', embed_metadata=False) -# -# >>> # Override transparent background with white -# >>> fig.savefig('result.png', facecolor='white') -# """ -# # Handle facecolor override for transparent figures -# # When facecolor is specified (not 'none'), temporarily make axes and figure opaque -# facecolor = kwargs.get("facecolor", None) -# patches_backup = [] # List of (patch, original_alpha, original_facecolor) -# -# if facecolor is not None: -# # Check if facecolor indicates a non-transparent background -# is_opaque_facecolor = True -# if isinstance(facecolor, str): -# if facecolor.lower() in ("none", "transparent"): -# is_opaque_facecolor = False -# -# if is_opaque_facecolor: -# # Backup and set figure patch to opaque -# fig_patch = self._fig_mpl.patch -# fig_alpha = fig_patch.get_alpha() -# fig_fc = fig_patch.get_facecolor() -# patches_backup.append((fig_patch, fig_alpha, fig_fc)) -# fig_patch.set_alpha(1.0) -# fig_patch.set_facecolor(facecolor) -# -# # Backup and set axes patches to opaque -# for ax_mpl in self._fig_mpl.axes: -# ax_patch = ax_mpl.patch -# original_alpha = ax_patch.get_alpha() -# original_fc = ax_patch.get_facecolor() -# patches_backup.append((ax_patch, original_alpha, original_fc)) -# ax_patch.set_alpha(1.0) -# # Set axes facecolor to match figure facecolor if it was transparent -# if original_alpha == 0.0 or original_alpha is None: -# ax_patch.set_facecolor(facecolor) -# -# # Ensure transparent=False so matplotlib respects the facecolor -# if "transparent" not in kwargs: -# kwargs["transparent"] = False -# # Wrap save logic in try/finally to restore axes alpha -# try: -# # Check if this is a format that can have metadata (PNG/JPEG/TIFF/PDF) -# # Handle both string paths and file-like objects (e.g., BytesIO) -# if isinstance(fname, str): -# is_image_format = fname.lower().endswith( -# (".png", ".jpg", ".jpeg", ".tiff", ".tif", ".pdf") -# ) -# else: -# # For file-like objects, check the 'format' kwarg if provided -# # Otherwise default to False (no metadata embedding for BytesIO etc.) -# fmt = kwargs.get("format", "").lower() if kwargs.get("format") else "" -# is_image_format = fmt in ("png", "jpg", "jpeg", "tiff", "tif", "pdf") -# -# if is_image_format and embed_metadata: -# # Collect automatic metadata -# auto_metadata = None -# -# # Get first axes if available -# # Keep the scitex AxisWrapper (for history tracking) separate from matplotlib axes -# ax = None -# ax_scitex = None # scitex AxisWrapper with history -# if hasattr(self, "axes"): -# try: -# # Try to get first axes from various wrapper types -# if hasattr(self.axes, "_ax"): # AxisWrapper -# ax = self.axes._ax -# ax_scitex = self.axes # Keep the wrapper for history -# elif hasattr(self.axes, "_axis_mpl"): # Alternative -# ax = self.axes._axis_mpl -# ax_scitex = self.axes -# elif hasattr(self.axes, "flatten"): # AxesWrapper -# flat = list(self.axes.flatten()) -# if flat and hasattr(flat[0], "_ax"): -# ax = flat[0]._ax -# ax_scitex = flat[0] # Keep the wrapper for history -# elif flat and hasattr(flat[0], "_axis_mpl"): -# ax = flat[0]._axis_mpl -# ax_scitex = flat[0] -# except Exception: -# pass -# -# # If still no axes, try from figure -# if ( -# ax is None -# and hasattr(self._fig_mpl, "axes") -# and len(self._fig_mpl.axes) > 0 -# ): -# ax = self._fig_mpl.axes[0] -# -# # Collect metadata -# # Pass ax_scitex if available (has history for plot type detection) -# try: -# from scitex.plt.utils import collect_figure_metadata -# -# auto_metadata = collect_figure_metadata( -# self._fig_mpl, ax_scitex if ax_scitex else ax -# ) -# -# # Merge with custom metadata -# if metadata: -# if "custom" not in auto_metadata: -# auto_metadata["custom"] = {} -# auto_metadata["custom"].update(metadata) -# except Exception as e: -# # If metadata collection fails, warn but continue -# logger.warning(f"Could not collect metadata: {e}") -# auto_metadata = metadata -# -# # Use scitex.io.save_image for metadata embedding -# try: -# from scitex.io._save_modules import save_image -# -# save_image( -# self._fig_mpl, fname, metadata=auto_metadata, *args, **kwargs -# ) -# except Exception as e: -# # Fallback to regular matplotlib savefig -# logger.warning( -# f"Metadata embedding failed, using regular savefig: {e}" -# ) -# self._fig_mpl.savefig(fname, *args, **kwargs) -# else: -# # For non-image formats or when metadata disabled, use regular savefig -# self._fig_mpl.savefig(fname, *args, **kwargs) -# finally: -# # Restore patch alpha and facecolor values if they were modified -# for patch, original_alpha, original_fc in patches_backup: -# patch.set_alpha(original_alpha) -# patch.set_facecolor(original_fc) -# -# def export_as_csv(self): -# """Export plotted data from all axes. -# -# This method collects data from all axes in the figure and combines -# them into a single DataFrame with appropriate axis identifiers in -# the column names. -# -# Returns -# ------- -# pd.DataFrame: Combined DataFrame with data from all axes, -# with axis ID prefixes for each column. -# """ -# dfs = [] -# -# # Use the _traverse_axes helper method to iterate through all axes -# # regardless of their structure (single, array, list, etc.) -# for ii, ax in enumerate(self._traverse_axes()): -# # Try different ways to access the export_as_csv method -# df = None -# try: -# if hasattr(ax, "_axis_mpl") and hasattr(ax._axis_mpl, "export_as_csv"): -# # If it's a nested structure with _axis_mpl having export_as_csv -# df = ax._axis_mpl.export_as_csv() -# elif hasattr(ax, "export_as_csv"): -# # Direct AxisWrapper object -# df = ax.export_as_csv() -# else: -# # Skip if no export method available -# continue -# except Exception: -# continue -# -# # Process the DataFrame if it's not empty -# if df is not None and not df.empty: -# # Column names already include axis position via get_csv_column_name -# # (single source of truth from _csv_column_naming.py) -# # Only handle duplicates by adding a counter -# new_cols = [] -# col_counts = {} -# for col in df.columns: -# col_str = str(col) -# -# # Handle duplicates by adding a counter -# if col_str in col_counts: -# col_counts[col_str] += 1 -# col_str = f"{col_str}_{col_counts[col_str]}" -# else: -# col_counts[col_str] = 0 -# -# new_cols.append(col_str) -# -# df.columns = new_cols -# dfs.append(df) -# -# # Return concatenated DataFrame or empty DataFrame if no data -# return pd.concat(dfs, axis=1) if dfs else pd.DataFrame() -# -# def colorbar(self, mappable, ax=None, **kwargs): -# """Add a colorbar to the figure, automatically unwrapping SciTeX axes. -# -# This method properly handles both regular matplotlib axes and SciTeX -# AxisWrapper objects when creating colorbars. -# -# Parameters -# ---------- -# mappable : ScalarMappable -# The image, contour set, etc. to which the colorbar applies -# ax : Axes or AxisWrapper, optional -# The axes to attach the colorbar to. If not specified, uses current axes. -# **kwargs : dict -# Additional keyword arguments passed to matplotlib's colorbar -# -# Returns -# ------- -# Colorbar -# The created colorbar object -# """ -# # Unwrap axes if it's a SciTeX AxisWrapper -# if ax is not None: -# ax_mpl = ax._axis_mpl if hasattr(ax, "_axis_mpl") else ax -# else: -# ax_mpl = None -# -# # Call matplotlib's colorbar with the unwrapped axes -# return self._fig_mpl.colorbar(mappable, ax=ax_mpl, **kwargs) -# -# def _traverse_axes(self): -# """Helper method to traverse all axis wrappers in the figure.""" -# if hasattr(self, "axes"): -# # Check if we're dealing with an AxesWrapper instance -# if hasattr(self.axes, "_axes_scitex") and hasattr( -# self.axes._axes_scitex, "flat" -# ): -# # This is an AxesWrapper, get the individual AxisWrapper objects -# for ax in self.axes._axes_scitex.flat: -# yield ax -# elif not hasattr(self.axes, "__iter__"): -# # Single axis case -# yield self.axes -# else: -# # Multiple axes case -# if hasattr(self.axes, "flat"): -# # 2D array of axes -# for ax in self.axes.flat: -# yield ax -# elif hasattr(self.axes, "ravel"): -# # Numpy array -# for ax in self.axes.ravel(): -# yield ax -# elif isinstance(self.axes, (list, tuple)): -# # List of axes -# for ax in self.axes: -# yield ax -# -# @property -# def history(self): -# """Aggregate tracking history from all axes in the figure. -# -# Returns a combined OrderedDict of all tracking records from all axes, -# enabling FTS bundle creation to build encoding from plot operations. -# """ -# from collections import OrderedDict -# -# combined = OrderedDict() -# for ax in self._traverse_axes(): -# if hasattr(ax, "history") and ax.history: -# combined.update(ax.history) -# return combined -# -# def legend(self, *args, loc="best", **kwargs): -# """Legend with 'best' automatic placement by default for all axes.""" -# for ax in self._traverse_axes(): -# try: -# ax.legend(*args, loc=loc, **kwargs) -# except Exception: -# pass -# -# def supxyt(self, x=False, y=False, t=False): -# """Wrapper for supxlabel, supylabel, and suptitle""" -# if x is not False: -# self._fig_mpl.supxlabel(x) -# if y is not False: -# self._fig_mpl.supylabel(y) -# if t is not False: -# self._fig_mpl.suptitle(t) -# return self._fig_mpl -# -# def tight_layout(self, *, rect=[0, 0.03, 1, 0.95], **kwargs): -# """Wrapper for tight_layout with rect=[0, 0.03, 1, 0.95] by default. -# -# Handles cases where certain axes (like colorbars) are incompatible -# with tight_layout. If the figure is using constrained_layout, this -# method does nothing as constrained_layout handles spacing automatically. -# """ -# # Check if figure is already using constrained_layout -# if ( -# hasattr(self._fig_mpl, "get_constrained_layout") -# and self._fig_mpl.get_constrained_layout() -# ): -# # Figure is using constrained_layout, which handles colorbars better -# # No need to call tight_layout -# return -# -# try: -# with warnings.catch_warnings(): -# # Suppress the specific warning about incompatible axes -# warnings.filterwarnings( -# "ignore", -# message="This figure includes Axes that are not compatible with tight_layout", -# ) -# self._fig_mpl.tight_layout(rect=rect, **kwargs) -# except Exception: -# # If tight_layout fails completely, try constrained_layout as fallback -# try: -# self._fig_mpl.set_constrained_layout(True) -# self._fig_mpl.set_constrained_layout_pads(w_pad=0.04, h_pad=0.04) -# except Exception: -# # If both fail, do nothing - figure will use default layout -# pass -# -# def adjust_layout(self, **kwargs): -# """Adjust the constrained layout parameters. -# -# Parameters -# ---------- -# w_pad : float, optional -# Width padding around axes (default: 0.05) -# h_pad : float, optional -# Height padding around axes (default: 0.05) -# wspace : float, optional -# Width space between subplots (default: 0.02) -# hspace : float, optional -# Height space between subplots (default: 0.02) -# rect : list of 4 floats, optional -# Rectangle in normalized figure coordinates to fit the whole layout -# [left, bottom, right, top] (default: [0, 0, 1, 1]) -# """ -# if ( -# hasattr(self._fig_mpl, "get_constrained_layout") -# and self._fig_mpl.get_constrained_layout() -# ): -# # Update constrained layout parameters -# self._fig_mpl.set_constrained_layout_pads(**kwargs) -# else: -# # Fall back to tight_layout with rect parameter if provided -# if "rect" in kwargs: -# self.tight_layout(rect=kwargs["rect"]) -# -# def close(self): -# """Close the underlying matplotlib figure""" -# import matplotlib.pyplot as plt -# -# plt.close(self._fig_mpl) -# -# @property -# def number(self): -# """Return the figure number for matplotlib.pyplot.close() compatibility""" -# return self._fig_mpl.number -# -# def __del__(self): -# """Cleanup when FigWrapper is deleted""" -# try: -# import matplotlib.pyplot as plt -# -# plt.close(self._fig_mpl) -# except: -# pass -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_FigWrapper.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/test__SubplotsWrapper.py b/tests/scitex/plt/_subplots/test__SubplotsWrapper.py deleted file mode 100755 index 41c358cb7..000000000 --- a/tests/scitex/plt/_subplots/test__SubplotsWrapper.py +++ /dev/null @@ -1,483 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-05-29 (ywatanabe)" -# File: /data/gpfs/projects/punim2354/ywatanabe/scitex_repo/tests/scitex/plt/_subplots/test__SubplotsWrapper.py - -import matplotlib -import pytest - -matplotlib.use("Agg") # Use non-interactive backend for testing -import matplotlib.pyplot as plt # noqa: E402 - -import scitex # noqa: E402 - - -class TestSubplotsWrapper: - """Test cases for scitex.plt.subplots wrapper functionality.""" - - def test_single_axis(self): - """Test that single axis returns an AxisWrapper object.""" - fig, ax = scitex.plt.subplots() - assert hasattr(ax, "plot"), "Single axis should have plot method" - assert hasattr(ax, "plot"), "Should have plot method" - scitex.plt.close(fig) - - def test_1d_array_single_row(self): - """Test that single row multiple columns returns 1D array.""" - fig, axes = scitex.plt.subplots(1, 3) - assert hasattr(axes, "__len__"), "Should return array-like object" - assert len(axes) == 3, "Should have 3 axes" - # Test individual axis access - for i in range(3): - assert hasattr(axes[i], "plot"), f"axes[{i}] should have plot method" - scitex.plt.close(fig) - - def test_1d_array_single_column(self): - """Test that multiple rows single column returns 1D array.""" - fig, axes = scitex.plt.subplots(3, 1) - assert hasattr(axes, "__len__"), "Should return array-like object" - assert len(axes) == 3, "Should have 3 axes" - # Test individual axis access - for i in range(3): - assert hasattr(axes[i], "plot"), f"axes[{i}] should have plot method" - scitex.plt.close(fig) - - def test_2d_array_indexing(self): - """Test that 2D grid allows 2D indexing (the main bug fix).""" - fig, axes = scitex.plt.subplots(4, 3) - - # Test shape property - assert hasattr(axes, "shape"), "Should have shape property" - assert axes.shape == (4, 3), "Shape should be (4, 3)" - - # Test 2D indexing - this is the core fix - for row in range(4): - for col in range(3): - ax = axes[row, col] - assert hasattr( - ax, "plot" - ), f"axes[{row}, {col}] should have plot method" - # Test that we can actually plot - ax.plot([1, 2, 3], [1, 2, 3]) - - scitex.plt.close(fig) - - def test_2d_array_row_access(self): - """Test accessing entire rows from 2D array.""" - fig, axes = scitex.plt.subplots(4, 3) - - # Access entire row - row_axes = axes[0] # First row - assert len(row_axes) == 3, "Row should have 3 axes" - - # Each element in row should be plottable - for i, ax in enumerate(row_axes): - assert hasattr(ax, "plot"), f"Row axis [{i}] should have plot method" - - scitex.plt.close(fig) - - def test_2d_array_slice_access(self): - """Test slice access on 2D array.""" - fig, axes = scitex.plt.subplots(4, 3) - - # Access slice of rows - slice_axes = axes[1:3] # Rows 1 and 2 - assert hasattr(slice_axes, "shape"), "Slice should return AxesWrapper" - assert slice_axes.shape == (2, 3), "Slice shape should be (2, 3)" - - scitex.plt.close(fig) - - def test_backward_compatibility_flat_iteration(self): - """Test that flat iteration works via .flat for backward compatibility.""" - fig, axes = scitex.plt.subplots(4, 3) - - # Test flat iteration yields all 12 axes - ax_list = list(axes.flat) - assert len(ax_list) == 12, "Flat iteration should yield 12 axes" - - # Test each axis is plottable - for i, ax in enumerate(axes.flat): - assert hasattr(ax, "plot"), f"Iterated axis {i} should have plot method" - - scitex.plt.close(fig) - - def test_multi_axes_plotting(self): - """Test plotting on multiple axes.""" - fig, axes = scitex.plt.subplots(2, 2) - - # Plot on each axis - axes[0, 0].plot([1, 2, 3], [1, 2, 3], id="plot00") - axes[0, 1].plot([1, 2, 3], [3, 2, 1], id="plot01") - axes[1, 0].plot([1, 2, 3], [2, 3, 4], id="plot10") - axes[1, 1].plot([1, 2, 3], [4, 3, 2], id="plot11") - - # Verify each axis has plotted data - for i in range(2): - for j in range(2): - assert len(axes[i, j].lines) > 0, f"Axis [{i},{j}] should have lines" - - scitex.plt.close(fig) - - def test_matplotlib_compatibility(self): - """Test that the behavior matches matplotlib's for common use cases.""" - # Compare with matplotlib behavior - mpl_fig, mpl_axes = plt.subplots(3, 2) - scitex_fig, scitex_axes = scitex.plt.subplots(3, 2) - - # Both should have the same shape - assert ( - scitex_axes.shape == mpl_axes.shape - ), "Should have same shape as matplotlib" - - # Both should allow 2D indexing - for i in range(3): - for j in range(2): - # This should not raise an error - _ = mpl_axes[i, j] - scitex_ax = scitex_axes[i, j] - assert hasattr(scitex_ax, "plot"), "scitex axis should have plot method" - - plt.close(mpl_fig) - scitex.plt.close(scitex_fig) - - -if __name__ == "__main__": - import os - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_SubplotsWrapper.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# """SubplotsWrapper: Monitor data plotted using matplotlib for CSV export.""" -# -# import os -# from collections import OrderedDict -# -# import matplotlib.pyplot as plt -# -# __FILE__ = "./src/scitex/plt/_subplots/_SubplotsWrapper.py" -# __DIR__ = os.path.dirname(__FILE__) -# -# # Configure fonts at import -# from ._fonts import _arial_enabled # noqa: F401 -# from ._mm_layout import create_with_mm_control -# -# # Register Arial fonts at module import -# import matplotlib.font_manager as fm -# import matplotlib as mpl -# import os -# -# _arial_enabled = False -# -# # Try to find Arial -# try: -# fm.findfont("Arial", fallback_to_default=False) -# _arial_enabled = True -# except Exception: -# # Search for Arial font files and register them -# arial_paths = [ -# f -# for f in fm.findSystemFonts() -# if os.path.basename(f).lower().startswith("arial") -# ] -# -# if arial_paths: -# for path in arial_paths: -# try: -# fm.fontManager.addfont(path) -# except Exception: -# pass -# -# # Verify Arial is now available -# try: -# fm.findfont("Arial", fallback_to_default=False) -# _arial_enabled = True -# except Exception: -# pass -# -# # Configure matplotlib to use Arial if available -# if _arial_enabled: -# mpl.rcParams["font.family"] = "Arial" -# mpl.rcParams["font.sans-serif"] = [ -# "Arial", -# "Helvetica", -# "DejaVu Sans", -# "Liberation Sans", -# ] -# else: -# # Warn about missing Arial -# from scitex import logging as _logging -# -# _logger = _logging.getLogger(__name__) -# _logger.warning( -# "Arial font not found. Using fallback fonts (Helvetica/DejaVu Sans). " -# "For publication figures with Arial: sudo apt-get install ttf-mscorefonts-installer && fc-cache -fv" -# ) -# -# -# class SubplotsWrapper: -# """ -# A wrapper class monitors data plotted using the ax methods from matplotlib.pyplot. -# This data can be converted into a CSV file formatted for SigmaPlot compatibility. -# -# Supports optional figrecipe integration for reproducible figures. -# When figrecipe is available and `use_figrecipe=True`, figures are created -# with recipe recording capability for later reproduction. -# """ -# -# def __init__(self): -# self._subplots_wrapper_history = OrderedDict() -# self._fig_scitex = None -# self._counter_part = plt.subplots -# self._figrecipe_available = None # Lazy check -# -# def _check_figrecipe(self): -# """Check if figrecipe is available (lazy, cached).""" -# if self._figrecipe_available is None: -# try: -# import figrecipe # noqa: F401 -# -# self._figrecipe_available = True -# except ImportError: -# self._figrecipe_available = False -# return self._figrecipe_available -# -# def __call__( -# self, -# *args, -# track=True, -# sharex=False, -# sharey=False, -# constrained_layout=None, -# use_figrecipe=None, # NEW: Enable figrecipe recording -# # MM-control parameters (unified style system) -# axes_width_mm=None, -# axes_height_mm=None, -# margin_left_mm=None, -# margin_right_mm=None, -# margin_bottom_mm=None, -# margin_top_mm=None, -# space_w_mm=None, -# space_h_mm=None, -# axes_thickness_mm=None, -# tick_length_mm=None, -# tick_thickness_mm=None, -# trace_thickness_mm=None, -# marker_size_mm=None, -# axis_font_size_pt=None, -# tick_font_size_pt=None, -# title_font_size_pt=None, -# legend_font_size_pt=None, -# suptitle_font_size_pt=None, -# n_ticks=None, -# mode=None, -# dpi=None, -# styles=None, -# transparent=None, -# theme=None, -# **kwargs, -# ): -# """ -# Create figure and axes with optional millimeter-based control. -# -# Parameters -# ---------- -# *args : int -# nrows, ncols passed to matplotlib.pyplot.subplots -# track : bool, optional -# Track plotting operations for CSV export (default: True) -# use_figrecipe : bool or None, optional -# If True, use figrecipe for recipe recording. -# If None (default), auto-detect figrecipe availability. -# If False, disable figrecipe even if available. -# -# MM-Control Parameters -# --------------------- -# axes_width_mm, axes_height_mm : float or list -# Axes dimensions in mm -# margin_*_mm : float -# Figure margins in mm -# space_w_mm, space_h_mm : float -# Spacing between axes in mm -# mode : str -# 'publication' or 'display' -# -# Returns -# ------- -# fig : FigWrapper -# Wrapped matplotlib Figure (with optional RecordingFigure) -# ax or axes : AxisWrapper or AxesWrapper -# Wrapped matplotlib Axes -# """ -# # Resolve style values -# from scitex.plt.styles import SCITEX_STYLE as _S -# from scitex.plt.styles import resolve_style_value as _resolve -# -# axes_width_mm = _resolve( -# "axes.width_mm", axes_width_mm, _S.get("axes_width_mm") -# ) -# axes_height_mm = _resolve( -# "axes.height_mm", axes_height_mm, _S.get("axes_height_mm") -# ) -# margin_left_mm = _resolve( -# "margins.left_mm", margin_left_mm, _S.get("margin_left_mm") -# ) -# margin_right_mm = _resolve( -# "margins.right_mm", margin_right_mm, _S.get("margin_right_mm") -# ) -# margin_bottom_mm = _resolve( -# "margins.bottom_mm", margin_bottom_mm, _S.get("margin_bottom_mm") -# ) -# margin_top_mm = _resolve( -# "margins.top_mm", margin_top_mm, _S.get("margin_top_mm") -# ) -# space_w_mm = _resolve("spacing.horizontal_mm", space_w_mm, _S.get("space_w_mm")) -# space_h_mm = _resolve("spacing.vertical_mm", space_h_mm, _S.get("space_h_mm")) -# axes_thickness_mm = _resolve( -# "axes.thickness_mm", axes_thickness_mm, _S.get("axes_thickness_mm") -# ) -# tick_length_mm = _resolve( -# "ticks.length_mm", tick_length_mm, _S.get("tick_length_mm") -# ) -# tick_thickness_mm = _resolve( -# "ticks.thickness_mm", tick_thickness_mm, _S.get("tick_thickness_mm") -# ) -# trace_thickness_mm = _resolve( -# "lines.trace_mm", trace_thickness_mm, _S.get("trace_thickness_mm") -# ) -# marker_size_mm = _resolve( -# "markers.size_mm", marker_size_mm, _S.get("marker_size_mm") -# ) -# axis_font_size_pt = _resolve( -# "fonts.axis_label_pt", axis_font_size_pt, _S.get("axis_font_size_pt") -# ) -# tick_font_size_pt = _resolve( -# "fonts.tick_label_pt", tick_font_size_pt, _S.get("tick_font_size_pt") -# ) -# title_font_size_pt = _resolve( -# "fonts.title_pt", title_font_size_pt, _S.get("title_font_size_pt") -# ) -# legend_font_size_pt = _resolve( -# "fonts.legend_pt", legend_font_size_pt, _S.get("legend_font_size_pt") -# ) -# suptitle_font_size_pt = _resolve( -# "fonts.suptitle_pt", suptitle_font_size_pt, _S.get("suptitle_font_size_pt") -# ) -# n_ticks = _resolve("ticks.n_ticks", n_ticks, _S.get("n_ticks"), int) -# dpi = _resolve("output.dpi", dpi, _S.get("dpi"), int) -# -# if transparent is None: -# transparent = _S.get("transparent", True) -# if mode is None: -# mode = _S.get("mode", "publication") -# if theme is None: -# theme = _resolve("theme.mode", None, "light", str) -# -# # Determine figrecipe usage -# if use_figrecipe is None: -# use_figrecipe = self._check_figrecipe() -# -# # Create figure with mm-control -# fig, axes = create_with_mm_control( -# *args, -# track=track, -# sharex=sharex, -# sharey=sharey, -# axes_width_mm=axes_width_mm, -# axes_height_mm=axes_height_mm, -# margin_left_mm=margin_left_mm, -# margin_right_mm=margin_right_mm, -# margin_bottom_mm=margin_bottom_mm, -# margin_top_mm=margin_top_mm, -# space_w_mm=space_w_mm, -# space_h_mm=space_h_mm, -# axes_thickness_mm=axes_thickness_mm, -# tick_length_mm=tick_length_mm, -# tick_thickness_mm=tick_thickness_mm, -# trace_thickness_mm=trace_thickness_mm, -# marker_size_mm=marker_size_mm, -# axis_font_size_pt=axis_font_size_pt, -# tick_font_size_pt=tick_font_size_pt, -# title_font_size_pt=title_font_size_pt, -# legend_font_size_pt=legend_font_size_pt, -# suptitle_font_size_pt=suptitle_font_size_pt, -# n_ticks=n_ticks, -# mode=mode, -# dpi=dpi, -# styles=styles, -# transparent=transparent, -# theme=theme, -# **kwargs, -# ) -# -# # If figrecipe enabled, create recording layer -# if use_figrecipe: -# self._attach_figrecipe_recorder(fig) -# -# self._fig_scitex = fig -# return fig, axes -# -# def _attach_figrecipe_recorder(self, fig_wrapper): -# """Attach figrecipe recorder to FigWrapper for recipe export. -# -# This creates a RecordingFigure layer that wraps the underlying -# matplotlib figure, enabling save_recipe() on the FigWrapper. -# """ -# try: -# from figrecipe._recorder import Recorder -# -# # Get the underlying matplotlib figure -# mpl_fig = fig_wrapper._fig_mpl -# -# # Create recorder -# recorder = Recorder() -# figsize = mpl_fig.get_size_inches() -# dpi_val = mpl_fig.dpi -# recorder.start_figure(figsize=tuple(figsize), dpi=int(dpi_val)) -# -# # Store recorder on FigWrapper for later recipe export -# fig_wrapper._figrecipe_recorder = recorder -# fig_wrapper._figrecipe_enabled = True -# -# # Store style info from scitex in the recipe -# if hasattr(mpl_fig, "_scitex_theme"): -# recorder.figure_record.style = {"theme": mpl_fig._scitex_theme} -# -# except Exception: -# # Silently fail - figrecipe is optional -# fig_wrapper._figrecipe_enabled = False -# -# def __dir__(self): -# """Provide combined directory for tab completion.""" -# local_attrs = set(super().__dir__()) -# try: -# counterpart_attrs = set(dir(self._counter_part)) -# except Exception: -# counterpart_attrs = set() -# return sorted(local_attrs.union(counterpart_attrs)) -# -# -# # Instantiate the wrapper -# subplots = SubplotsWrapper() -# -# -# if __name__ == "__main__": -# import matplotlib -# -# import scitex -# -# matplotlib.use("TkAgg") -# -# fig, ax = subplots() -# ax.plot([1, 2, 3], [4, 5, 6], id="plot1") -# ax.plot([4, 5, 6], [1, 2, 3], id="plot2") -# scitex.io.save(fig, "/tmp/subplots_demo/plots.png") -# -# print(ax.export_as_csv()) -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_SubplotsWrapper.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/test__export_as_csv.py b/tests/scitex/plt/_subplots/test__export_as_csv.py deleted file mode 100644 index 37ea7c3e0..000000000 --- a/tests/scitex/plt/_subplots/test__export_as_csv.py +++ /dev/null @@ -1,1155 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-06-11 03:20:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/_subplots/test__export_as_csv.py -# ---------------------------------------- -"""Comprehensive tests for export_as_csv functionality.""" - -import os -import warnings - -__FILE__ = "./tests/scitex/plt/_subplots/test__export_as_csv.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 06:05:04 (ywatanabe)" -# # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/tests/scitex/plt/_subplots/test__export_as_csv.py -# # ---------------------------------------- -# import os -# __FILE__ = ( -# "./tests/scitex/plt/_subplots/test__export_as_csv.py" -# ) -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- - -from unittest.mock import MagicMock, patch - -import numpy as np -import pandas as pd -import pytest -import xarray as xr - -# Try direct import -try: - from scitex.plt._subplots import _format_imshow2d, export_as_csv, format_record -except ImportError: - # Skip tests if module not properly available - pytest.skip( - "Module scitex.plt._subplots._export_as_csv not available", - allow_module_level=True, - ) - - -class TestExportAsCSV: - """Test suite for export_as_csv function.""" - - def test_empty_history(self): - """Test export with empty history.""" - history = {} - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert result.empty - assert len(w) == 1 - assert "Plotting records not found" in str(w[0].message) - - def test_simple_plot(self): - """Test export with simple plot.""" - history = {"plot1": ("plot1", "plot", ([1, 2, 3], [4, 5, 6]), {})} - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert "plot1_plot_x" in result.columns - assert "plot1_plot_y" in result.columns - assert result["plot1_plot_x"].tolist() == [1, 2, 3] - assert result["plot1_plot_y"].tolist() == [4, 5, 6] - - def test_multiple_plots(self): - """Test export with multiple plots.""" - history = { - "plot1": ("plot1", "plot", ([1, 2, 3], [4, 5, 6]), {}), - "plot2": ("plot2", "plot", ([4, 5, 6], [1, 2, 3]), {}), - } - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert list(result.columns) == [ - "plot1_plot_x", - "plot1_plot_y", - "plot2_plot_x", - "plot2_plot_y", - ] - - def test_export_concat_failure(self): - """Test export when concat fails.""" - # Create a mock that raises exception - with patch("pandas.concat", side_effect=ValueError("Test error")): - history = { - "plot1": ("plot1", "plot", ([1, 2], [3, 4]), {}), - "plot2": ( - "plot2", - "plot", - ([5, 6, 7], [8, 9, 10]), - {}, - ), # Different length - } - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert result.empty - assert len(w) == 1 - assert "Plotting records not combined" in str(w[0].message) - - def test_scatter_plot(self): - """Test export with scatter plot.""" - history = {"scatter1": ("scatter1", "scatter", ([1, 2, 3], [4, 5, 6]), {})} - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert "scatter1_scatter_x" in result.columns - assert "scatter1_scatter_y" in result.columns - assert result["scatter1_scatter_x"].tolist() == [1, 2, 3] - assert result["scatter1_scatter_y"].tolist() == [4, 5, 6] - - def test_bar_plot(self): - """Test export with bar plot.""" - history = {"bar1": ("bar1", "bar", (["A", "B", "C"], [4, 5, 6]), {})} - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert "bar1_bar_x" in result.columns - assert "bar1_bar_y" in result.columns - assert result["bar1_bar_x"].tolist() == ["A", "B", "C"] - assert result["bar1_bar_y"].tolist() == [4, 5, 6] - - def test_bar_plot_with_yerr(self): - """Test export with bar plot including error bars.""" - history = { - "bar1": ( - "bar1", - "bar", - (["A", "B", "C"], [4, 5, 6]), - {"yerr": [0.1, 0.2, 0.3]}, - ) - } - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert "bar1_bar_yerr" in result.columns - assert result["bar1_bar_yerr"].tolist() == [0.1, 0.2, 0.3] - - def test_histogram_plot(self): - """Test export with histogram.""" - history = {"hist1": ("hist1", "hist", [1, 2, 2, 3, 3, 3], {})} - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert "hist1_hist_x" in result.columns - assert result["hist1_hist_x"].tolist() == [1, 2, 2, 3, 3, 3] - - def test_mixed_plot_types(self): - """Test export with mixed plot types.""" - history = { - "plot1": ("plot1", "plot", ([1, 2], [3, 4]), {}), - "scatter1": ("scatter1", "scatter", ([5, 6], [7, 8]), {}), - "bar1": ("bar1", "bar", (["X", "Y"], [9, 10]), {}), - } - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert result.shape[1] == 6 # 2 cols per plot type - assert all( - col in result.columns - for col in [ - "plot1_plot_x", - "plot1_plot_y", - "scatter1_scatter_x", - "scatter1_scatter_y", - "bar1_bar_x", - "bar1_bar_y", - ] - ) - - -class TestFormatRecord: - """Test suite for format_record function.""" - - def test_imshow2d_format(self): - """Test formatting imshow2d data.""" - df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) - record = ("img1", "imshow2d", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, df) - - def test_plot_with_single_array(self): - """Test plot formatting with single 2D array.""" - record = ("plot1", "plot", [np.array([[1, 4], [2, 5], [3, 6]])], {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "plot1_plot_x" in result.columns - assert result["plot1_plot_x"].tolist() == [1, 2, 3] - - def test_plot_with_separate_arrays(self): - """Test plot formatting with separate x and y arrays.""" - record = ( - "plot1", - "plot", - [np.array([1, 2, 3]), np.array([4, 5, 6])], - {}, - ) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "plot1_plot_x" in result.columns - assert "plot1_plot_y" in result.columns - assert result["plot1_plot_x"].tolist() == [1, 2, 3] - assert result["plot1_plot_y"].tolist() == [4, 5, 6] - - def test_plot_with_2d_y_array(self): - """Test plot formatting with 2D y array (multiple lines).""" - record = ( - "plot1", - "plot", - [np.array([1, 2, 3]), np.array([[4, 7], [5, 8], [6, 9]])], - {}, - ) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "plot1_plot_x00" in result.columns - assert "plot1_plot_y00" in result.columns - assert "plot1_plot_y01" in result.columns - assert result["plot1_plot_y00"].tolist() == [4, 5, 6] - assert result["plot1_plot_y01"].tolist() == [7, 8, 9] - - def test_plot_with_dataframe_y(self): - """Test plot formatting with DataFrame as y values.""" - y_df = pd.DataFrame({"col1": [4, 5, 6], "col2": [7, 8, 9]}) - record = ("plot1", "plot", [np.array([1, 2, 3]), y_df], {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "plot1_plot_x" in result.columns - assert "plot1_plot_y00" in result.columns - assert "plot1_plot_y01" in result.columns - - def test_plot_with_xarray(self): - """Test plot formatting with xarray DataArray.""" - y_xr = xr.DataArray([[4, 7], [5, 8], [6, 9]], dims=["x", "y"]) - record = ("plot1", "plot", [np.array([1, 2, 3]), y_xr], {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "plot1_plot_x00" in result.columns - assert "plot1_plot_y00" in result.columns - assert "plot1_plot_y01" in result.columns - - def test_plot_with_list_y(self): - """Test plot formatting with list as y values.""" - record = ("plot1", "plot", [np.array([1, 2, 3]), [4, 5, 6]], {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "plot1_plot_x" in result.columns - assert "plot1_plot_y" in result.columns - - def test_bar_with_scalar_values(self): - """Test bar formatting with scalar x and y.""" - record = ("bar1", "bar", (1, 5), {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "bar1_bar_x" in result.columns - assert "bar1_bar_y" in result.columns - assert len(result) == 1 - - def test_bar_with_scalar_yerr(self): - """Test bar formatting with scalar error value.""" - record = ("bar1", "bar", (["A"], [5]), {"yerr": 0.5}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "bar1_bar_yerr" in result.columns - assert result["bar1_bar_yerr"].iloc[0] == 0.5 - - def test_boxplot(self): - """Test boxplot formatting with single box.""" - record = ("box1", "boxplot", [[1, 2, 3, 4, 5]], {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "box1_boxplot_0_x" in result.columns - assert result["box1_boxplot_0_x"].tolist() == [1, 2, 3, 4, 5] - - def test_boxplot_multiple(self): - """Test boxplot formatting with multiple boxes.""" - record = ("box1", "boxplot", [[[1, 2, 3], [4, 5, 6, 7]]], {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "box1_boxplot_0_x" in result.columns - assert "box1_boxplot_1_x" in result.columns - # Check dropna behavior - assert len(result["box1_boxplot_0_x"].dropna()) == 3 - assert len(result["box1_boxplot_1_x"].dropna()) == 4 - - def test_boxplot_with_numpy(self): - """Test boxplot formatting with numpy array.""" - record = ("box1", "boxplot", [np.array([1.5, 2.5, 3.5])], {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert result.shape[1] == 1 - - def test_plot_fillv(self): - """Test plot_fillv formatting.""" - record = ("fill1", "plot_fillv", ([1, 3, 5], [2, 4, 6]), {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "fill1_plot_fillv_starts" in result.columns - assert "fill1_plot_fillv_ends" in result.columns - assert result["fill1_plot_fillv_starts"].tolist() == [1, 3, 5] - assert result["fill1_plot_fillv_ends"].tolist() == [2, 4, 6] - - def test_plot_raster(self): - """Test plot_raster formatting.""" - df = pd.DataFrame({"spike_times": [0.1, 0.5, 1.2]}) - record = ("raster1", "plot_raster", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, df) - - def test_plot_ecdf(self): - """Test plot_ecdf formatting.""" - df = pd.DataFrame({"values": [1, 2, 3, 4, 5]}) - record = ("ecdf1", "plot_ecdf", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, df) - - def test_plot_kde(self): - """Test plot_kde formatting.""" - df = pd.DataFrame({"density": [0.1, 0.3, 0.5, 0.3, 0.1]}) - record = ("kde1", "plot_kde", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "kde1_plot_kde_density" in result.columns - - def test_plot_kde_no_id(self): - """Test plot_kde formatting without ID.""" - df = pd.DataFrame({"density": [0.1, 0.3, 0.5]}) - record = (None, "plot_kde", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert result.columns[0] == "density" # Original column name preserved - - def test_sns_barplot(self): - """Test seaborn barplot formatting.""" - df = pd.DataFrame({"A": [1, 2, 3], "B": [2, 4, 6], "C": [3, 6, 9]}) - record = ("sns_bar1", "sns_barplot", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert result.shape == (1, 3) # Diagonal values - - def test_sns_boxplot(self): - """Test seaborn boxplot formatting.""" - df = pd.DataFrame({"group1": [1, 2, 3, 4], "group2": [5, 6, 7, 8]}) - record = ("sns_box1", "sns_boxplot", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert "sns_box1_sns_boxplot_group1" in result.columns - assert "sns_box1_sns_boxplot_group2" in result.columns - - def test_sns_boxplot_no_id(self): - """Test seaborn boxplot formatting without ID.""" - df = pd.DataFrame({"data": [1, 2, 3]}) - record = (None, "sns_boxplot", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - assert result.columns[0] == "data" - - def test_sns_heatmap(self): - """Test seaborn heatmap formatting.""" - df = pd.DataFrame(np.random.rand(3, 3)) - record = ("heatmap1", "sns_heatmap", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, df) - - def test_sns_histplot(self): - """Test seaborn histplot formatting.""" - df = pd.DataFrame({"values": np.random.randn(100)}) - record = ("hist1", "sns_histplot", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, df) - - def test_sns_violinplot(self): - """Test seaborn violinplot formatting.""" - df = pd.DataFrame({"A": np.random.randn(50), "B": np.random.randn(50)}) - record = ("violin1", "sns_violinplot", df, {}) - result = format_record(record) - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, df) - - def test_unsupported_method(self): - """Test formatting with unsupported method.""" - record = ("unknown1", "unknown_method", [1, 2, 3], {}) - result = format_record(record) - assert result is None - - def test_set_method_ignored(self): - """Test that set_ methods are ignored.""" - record = ("set1", "set_xlabel", ["X Label"], {}) - result = format_record(record) - assert result is None - - -class TestFormatImshow2D: - """Test suite for _format_imshow2d function.""" - - def test_basic_imshow2d(self): - """Test basic imshow2d formatting.""" - df = pd.DataFrame(np.random.rand(5, 5)) - record = ("img1", "imshow2d", df, {}) - result = _format_imshow2d(record) - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, df) - - def test_imshow2d_preserves_structure(self): - """Test that imshow2d preserves DataFrame structure.""" - df = pd.DataFrame( - np.arange(9).reshape(3, 3), - index=["row1", "row2", "row3"], - columns=["col1", "col2", "col3"], - ) - record = ("img1", "imshow2d", df, {}) - result = _format_imshow2d(record) - pd.testing.assert_frame_equal(result, df) - assert list(result.index) == ["row1", "row2", "row3"] - assert list(result.columns) == ["col1", "col2", "col3"] - - -class TestEdgeCases: - """Test edge cases and error conditions.""" - - def test_none_values_in_history(self): - """Test handling of None values in history.""" - history = { - "plot1": ("plot1", "plot", ([1, 2, None], [4, None, 6]), {}), - } - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert pd.isna(result["plot1_plot_x"].iloc[2]) - assert pd.isna(result["plot1_plot_y"].iloc[1]) - - def test_empty_arrays(self): - """Test handling of empty arrays.""" - history = { - "plot1": ("plot1", "plot", ([], []), {}), - } - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert len(result) == 0 - - def test_mismatched_array_lengths(self): - """Test handling of mismatched array lengths in plots.""" - # This should be handled by the plotting function, but test robustness - record = ("plot1", "plot", ([1, 2, 3], [4, 5]), {}) - # Format record should handle this gracefully - result = format_record(record) - assert isinstance(result, pd.DataFrame) - - def test_unicode_in_labels(self): - """Test handling of unicode characters in labels.""" - history = {"plot1": ("plot1", "bar", (["α", "β", "γ"], [1, 2, 3]), {})} - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert result["plot1_bar_x"].tolist() == ["α", "β", "γ"] - - def test_very_long_ids(self): - """Test handling of very long plot IDs.""" - long_id = "a" * 100 - history = {long_id: (long_id, "plot", ([1, 2], [3, 4]), {})} - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert f"{long_id}_plot_x" in result.columns - - def test_special_characters_in_id(self): - """Test handling of special characters in plot IDs.""" - special_id = "plot-1_test@#$" - history = {special_id: (special_id, "plot", ([1, 2], [3, 4]), {})} - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert f"{special_id}_plot_x" in result.columns - - -class TestWarningSystem: - """Test suite for the improved warning system.""" - - def test_warn_once_single_warning(self): - """Test that _warn_once shows a warning only once.""" - from scitex.plt._subplots._export_as_csv import _warn_once, _warning_registry - - # Clear registry for clean test - _warning_registry.clear() - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - # First call should warn - _warn_once("Test warning message") - assert len(w) == 1 - assert "Test warning message" in str(w[0].message) - - # Second call should NOT warn - _warn_once("Test warning message") - assert len(w) == 1 # Still only 1 warning - - def test_warn_once_different_warnings(self): - """Test that different warnings are shown separately.""" - from scitex.plt._subplots._export_as_csv import _warn_once, _warning_registry - - _warning_registry.clear() - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - _warn_once("First warning") - _warn_once("Second warning") - _warn_once("First warning") # Duplicate, shouldn't show - - assert len(w) == 2 # Only 2 unique warnings - assert "First warning" in str(w[0].message) - assert "Second warning" in str(w[1].message) - - def test_helpful_warning_for_imshow(self): - """Test that imshow without tracking shows helpful warning.""" - from scitex.plt._subplots._export_as_csv import _warning_registry - - _warning_registry.clear() - - # Create history with imshow (no data tracked) - history = {"img1": ("img1", "imshow", {"args": (np.random.rand(10, 10),)}, {})} - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - result = export_as_csv(history) - - # Should get helpful warning - assert len(w) >= 1 - warning_text = str(w[0].message) - assert "imshow" in warning_text - assert "plot_imshow" in warning_text - assert "Consider using" in warning_text - - def test_helpful_warning_for_unknown_method(self): - """Test that unknown methods show generic helpful warning.""" - from scitex.plt._subplots._export_as_csv import _warning_registry - - _warning_registry.clear() - - history = {"unknown1": ("unknown1", "custom_plot", {"args": ([1, 2, 3],)}, {})} - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - result = export_as_csv(history) - - # Should get generic warning - assert len(w) >= 1 - warning_text = str(w[0].message) - assert "custom_plot" in warning_text - assert "scitex plot methods" in warning_text - - def test_no_duplicate_warnings_multiple_axes(self): - """Test that using same method on multiple axes warns only once.""" - from scitex.plt._subplots._export_as_csv import _warning_registry - - _warning_registry.clear() - - # Simulate multiple axes using imshow - history = { - "img1": ("img1", "imshow", {"args": (np.random.rand(5, 5),)}, {}), - "img2": ("img2", "imshow", {"args": (np.random.rand(5, 5),)}, {}), - "img3": ("img3", "imshow", {"args": (np.random.rand(5, 5),)}, {}), - } - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - result = export_as_csv(history) - - # Should only warn once for imshow despite 3 uses - imshow_warnings = [msg for msg in w if "imshow" in str(msg.message)] - assert len(imshow_warnings) == 1 - - def test_method_alternatives_coverage(self): - """Test that _METHOD_ALTERNATIVES includes common methods.""" - from scitex.plt._subplots._export_as_csv import _METHOD_ALTERNATIVES - - # Matplotlib methods - assert "imshow" in _METHOD_ALTERNATIVES - assert "boxplot" in _METHOD_ALTERNATIVES - assert "violinplot" in _METHOD_ALTERNATIVES - assert "fill_between" in _METHOD_ALTERNATIVES - - # Seaborn methods - assert "scatterplot" in _METHOD_ALTERNATIVES - assert "barplot" in _METHOD_ALTERNATIVES - assert "histplot" in _METHOD_ALTERNATIVES - - # Check alternatives are meaningful - assert _METHOD_ALTERNATIVES["imshow"] == "plot_imshow" - assert "sns_" in _METHOD_ALTERNATIVES["scatterplot"] - - -class TestPlotImshowExport: - """Test suite for plot_imshow CSV export functionality.""" - - def test_plot_imshow_basic_export(self): - """Test basic plot_imshow export.""" - # Create sample image data - img_data = np.random.rand(5, 5) - df = pd.DataFrame(img_data, columns=[f"col_{i}" for i in range(5)]) - - history = {"img1": ("img1", "plot_imshow", {"imshow_df": df}, {})} - - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert not result.empty - - # Check column naming - assert any("img1_plot_imshow" in col for col in result.columns) - - def test_plot_imshow_multiple_images(self): - """Test export with multiple plot_imshow calls.""" - img1 = np.random.rand(3, 3) - img2 = np.random.rand(3, 3) - - df1 = pd.DataFrame(img1, columns=[f"col_{i}" for i in range(3)]) - df2 = pd.DataFrame(img2, columns=[f"col_{i}" for i in range(3)]) - - history = { - "img1": ("img1", "plot_imshow", {"imshow_df": df1}, {}), - "img2": ("img2", "plot_imshow", {"imshow_df": df2}, {}), - } - - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert not result.empty - - # Should have columns from both images - img1_cols = [col for col in result.columns if "img1_plot_imshow" in col] - img2_cols = [col for col in result.columns if "img2_plot_imshow" in col] - - assert len(img1_cols) == 3 - assert len(img2_cols) == 3 - - def test_plot_imshow_with_none_id(self): - """Test plot_imshow export without explicit ID.""" - img_data = np.random.rand(4, 4) - df = pd.DataFrame(img_data, columns=[f"col_{i}" for i in range(4)]) - - history = {None: (None, "plot_imshow", {"imshow_df": df}, {})} - - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert not result.empty - - def test_plot_imshow_empty_data(self): - """Test plot_imshow with empty tracked dict.""" - history = {"img1": ("img1", "plot_imshow", {}, {})} - - result = export_as_csv(history) - # Should handle gracefully with empty DataFrame - assert isinstance(result, pd.DataFrame) - - def test_plot_imshow_mixed_with_plot_image(self): - """Test mix of plot_imshow and plot_image in same export.""" - img1 = np.random.rand(3, 3) - img2 = np.random.rand(3, 3) - - df1 = pd.DataFrame(img1, columns=[f"col_{i}" for i in range(3)]) - df2 = pd.DataFrame(img2) - - history = { - "imshow1": ("imshow1", "plot_imshow", {"imshow_df": df1}, {}), - "image1": ("image1", "plot_image", {"image_df": df2}, {}), - } - - result = export_as_csv(history) - assert isinstance(result, pd.DataFrame) - assert not result.empty - - # Should have data from both methods - imshow_cols = [col for col in result.columns if "plot_imshow" in col] - assert len(imshow_cols) > 0 - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-09-21 01:52:22 (ywatanabe)" -# # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/_subplots/_export_as_csv.py -# # ---------------------------------------- -# from __future__ import annotations -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# import pandas as pd -# from scitex.pd import to_xyz -# -# from scitex import logging -# -# logger = logging.getLogger(__name__) -# -# # Global warning registry to track which warnings have been shown -# _warning_registry = set() -# -# # Mapping of matplotlib/seaborn methods to their scitex equivalents -# _METHOD_ALTERNATIVES = { -# # Matplotlib methods -# "imshow": "plot_imshow", -# "plot": "plot", # already tracked -# "scatter": "plot_scatter", # already tracked -# "bar": "plot_bar", # already tracked -# "barh": "plot_barh", # already tracked -# "hist": "hist", # already tracked -# "boxplot": "stx_box or plot_boxplot", -# "violinplot": "stx_violin or plot_violinplot", -# "fill_between": "plot_fill_between", -# "errorbar": "plot_errorbar", -# "contour": "plot_contour", -# "heatmap": "stx_heatmap", -# # Seaborn methods (accessed via ax.sns_*) -# "scatterplot": "sns_scatterplot", -# "lineplot": "sns_lineplot", -# "barplot": "sns_barplot", -# "boxplot_sns": "sns_boxplot", -# "violinplot_sns": "sns_violinplot", -# "stripplot": "sns_stripplot", -# "swarmplot": "sns_swarmplot", -# "histplot": "sns_histplot", -# "kdeplot": "sns_kdeplot", -# "heatmap_sns": "sns_heatmap", -# "jointplot": "sns_jointplot", -# "pairplot": "sns_pairplot", -# } -# -# -# def _warn_once(message, category=UserWarning): -# """Show a warning only once per runtime. -# -# Args: -# message: Warning message to display -# category: Warning category (default: UserWarning) -# """ -# if message not in _warning_registry: -# _warning_registry.add(message) -# logger.warning(message) -# -# -# from ._export_as_csv_formatters import ( -# # Standard matplotlib formatters -# _format_annotate, -# _format_bar, -# _format_barh, -# _format_boxplot, -# _format_contour, -# _format_contourf, -# _format_errorbar, -# _format_eventplot, -# _format_fill, -# _format_fill_between, -# _format_stackplot, -# _format_pcolormesh, -# _format_hexbin, -# _format_hist, -# _format_hist2d, -# _format_imshow, -# _format_imshow2d, -# _format_matshow, -# _format_pie, -# _format_plot, -# _format_quiver, -# _format_scatter, -# _format_stem, -# _format_step, -# _format_streamplot, -# _format_text, -# _format_violin, -# _format_violinplot, -# # Custom scitex formatters -# _format_plot_box, -# _format_plot_conf_mat, -# _format_stx_contour, -# _format_plot_ecdf, -# _format_plot_fillv, -# _format_plot_heatmap, -# _format_plot_image, -# _format_plot_imshow, -# _format_stx_imshow, -# _format_plot_joyplot, -# _format_plot_kde, -# _format_plot_line, -# _format_plot_mean_ci, -# _format_plot_mean_std, -# _format_plot_median_iqr, -# _format_plot_raster, -# _format_plot_rectangle, -# _format_plot_scatter, -# _format_plot_scatter_hist, -# _format_plot_shaded_line, -# _format_plot_violin, -# # stx_ aliases formatters -# _format_stx_scatter, -# _format_stx_bar, -# _format_stx_barh, -# _format_stx_errorbar, -# # Seaborn formatters -# _format_sns_barplot, -# _format_sns_boxplot, -# _format_sns_heatmap, -# _format_sns_histplot, -# _format_sns_jointplot, -# _format_sns_kdeplot, -# _format_sns_lineplot, -# _format_sns_pairplot, -# _format_sns_scatterplot, -# _format_sns_stripplot, -# _format_sns_swarmplot, -# _format_sns_violinplot, -# ) -# -# # Registry mapping method names to their formatter functions -# _FORMATTER_REGISTRY = { -# # Standard matplotlib methods -# "annotate": _format_annotate, -# "bar": _format_bar, -# "barh": _format_barh, -# "boxplot": _format_boxplot, -# "contour": _format_contour, -# "contourf": _format_contourf, -# "errorbar": _format_errorbar, -# "eventplot": _format_eventplot, -# "fill": _format_fill, -# "fill_between": _format_fill_between, -# "stackplot": _format_stackplot, -# "pcolormesh": _format_pcolormesh, -# "pcolor": _format_pcolormesh, -# "hexbin": _format_hexbin, -# "hist": _format_hist, -# "hist2d": _format_hist2d, -# "imshow": _format_imshow, -# "imshow2d": _format_imshow2d, -# "matshow": _format_matshow, -# "pie": _format_pie, -# "plot": _format_plot, -# "quiver": _format_quiver, -# "scatter": _format_scatter, -# "stem": _format_stem, -# "step": _format_step, -# "streamplot": _format_streamplot, -# "text": _format_text, -# "violin": _format_violin, -# "violinplot": _format_violinplot, -# # Custom scitex methods -# "stx_box": _format_plot_box, -# "stx_conf_mat": _format_plot_conf_mat, -# "stx_contour": _format_stx_contour, -# "stx_ecdf": _format_plot_ecdf, -# "stx_fillv": _format_plot_fillv, -# "stx_heatmap": _format_plot_heatmap, -# "stx_image": _format_plot_image, -# "plot_imshow": _format_plot_imshow, -# "stx_imshow": _format_stx_imshow, -# "stx_joyplot": _format_plot_joyplot, -# "stx_kde": _format_plot_kde, -# "stx_line": _format_plot_line, -# "stx_mean_ci": _format_plot_mean_ci, -# "stx_mean_std": _format_plot_mean_std, -# "stx_median_iqr": _format_plot_median_iqr, -# "stx_raster": _format_plot_raster, -# "stx_rectangle": _format_plot_rectangle, -# "plot_scatter": _format_plot_scatter, -# "stx_scatter_hist": _format_plot_scatter_hist, -# "stx_shaded_line": _format_plot_shaded_line, -# "stx_violin": _format_plot_violin, -# # stx_ aliases -# "stx_scatter": _format_stx_scatter, -# "stx_bar": _format_stx_bar, -# "stx_barh": _format_stx_barh, -# "stx_errorbar": _format_stx_errorbar, -# # Seaborn methods (sns_ prefix) -# "sns_barplot": _format_sns_barplot, -# "sns_boxplot": _format_sns_boxplot, -# "sns_heatmap": _format_sns_heatmap, -# "sns_histplot": _format_sns_histplot, -# "sns_jointplot": _format_sns_jointplot, -# "sns_kdeplot": _format_sns_kdeplot, -# "sns_lineplot": _format_sns_lineplot, -# "sns_pairplot": _format_sns_pairplot, -# "sns_scatterplot": _format_sns_scatterplot, -# "sns_stripplot": _format_sns_stripplot, -# "sns_swarmplot": _format_sns_swarmplot, -# "sns_violinplot": _format_sns_violinplot, -# } -# -# -# def _to_numpy(data): -# """Convert various data types to numpy array. -# -# Handles torch tensors, pandas Series/DataFrame, and other array-like objects. -# -# Parameters -# ---------- -# data : array-like -# Data to convert to numpy array -# -# Returns -# ------- -# numpy.ndarray -# Data as numpy array -# """ -# if hasattr(data, "numpy"): # torch tensor -# return data.detach().numpy() if hasattr(data, "detach") else data.numpy() -# elif hasattr(data, "values"): # pandas series/dataframe -# return data.values -# else: -# return np.asarray(data) -# -# -# def export_as_csv(history_records): -# """Convert plotting history records to a combined DataFrame suitable for CSV export. -# -# Args: -# history_records (dict): Dictionary of plotting records. -# -# Returns: -# pd.DataFrame: Combined DataFrame containing all plotting data. -# -# Raises: -# ValueError: If no plotting records are found or they cannot be combined. -# """ -# if len(history_records) <= 0: -# logger.warning("Plotting records not found. Cannot export empty data.") -# return pd.DataFrame() # Return empty DataFrame instead of None -# -# dfs = [] -# failed_methods = set() # Track failed methods for helpful warnings -# -# for record_index, record in enumerate(list(history_records.values())): -# try: -# formatted_df = format_record(record, record_index=record_index) -# if formatted_df is not None and not formatted_df.empty: -# dfs.append(formatted_df) -# else: -# # Track the method that failed to format -# method_name = record[1] if len(record) > 1 else "unknown" -# failed_methods.add(method_name) -# except Exception as e: -# method_name = record[1] if len(record) > 1 else "unknown" -# failed_methods.add(method_name) -# -# # If no valid dataframes were created, provide helpful suggestions -# if not dfs and failed_methods: -# for method in failed_methods: -# if method in _METHOD_ALTERNATIVES: -# alternative = _METHOD_ALTERNATIVES[method] -# message = ( -# f"Matplotlib method '{method}()' does not support full data tracking for CSV export. " -# f"Consider using 'ax.{alternative}()' instead for better data export support." -# ) -# else: -# message = ( -# f"Method '{method}()' does not support data tracking for CSV export. " -# f"Consider using scitex plot methods (e.g., stx_image, plot_imshow) for data export support." -# ) -# _warn_once(message) -# return pd.DataFrame() -# -# try: -# # Reset index for each dataframe to avoid alignment issues -# dfs_reset = [df.reset_index(drop=True) for df in dfs] -# df = pd.concat(dfs_reset, axis=1) -# return df -# except Exception as e: -# logger.warning(f"Failed to combine plotting records: {e}") -# # Return a DataFrame with metadata about what records were attempted -# meta_df = pd.DataFrame( -# { -# "record_id": [r[0] for r in history_records.values()], -# "method": [r[1] for r in history_records.values()], -# "has_data": [ -# "Yes" if r[2] and r[2] != {} else "No" -# for r in history_records.values() -# ], -# } -# ) -# return meta_df -# -# -# def format_record(record, record_index=0): -# """Route record to the appropriate formatting function based on plot method. -# -# Args: -# record (tuple): Plotting record tuple (id, method, tracked_dict, kwargs). -# record_index (int): Index of this record in the history (used as fallback -# for trace_id when user doesn't provide an explicit id= kwarg). -# -# Returns: -# pd.DataFrame: Formatted data for the plot record. -# """ -# id, method, tracked_dict, kwargs = record -# -# # Basic Matplotlib functions -# if method == "plot": -# return _format_plot(id, tracked_dict, kwargs) -# elif method == "scatter": -# return _format_scatter(id, tracked_dict, kwargs) -# elif method == "bar": -# return _format_bar(id, tracked_dict, kwargs) -# elif method == "barh": -# return _format_barh(id, tracked_dict, kwargs) -# elif method == "hist": -# return _format_hist(id, tracked_dict, kwargs) -# elif method == "boxplot": -# return _format_boxplot(id, tracked_dict, kwargs) -# elif method == "contour": -# return _format_contour(id, tracked_dict, kwargs) -# elif method == "contourf": -# return _format_contourf(id, tracked_dict, kwargs) -# elif method == "errorbar": -# return _format_errorbar(id, tracked_dict, kwargs) -# elif method == "eventplot": -# return _format_eventplot(id, tracked_dict, kwargs) -# elif method == "fill": -# return _format_fill(id, tracked_dict, kwargs) -# elif method == "fill_between": -# return _format_fill_between(id, tracked_dict, kwargs) -# elif method == "stackplot": -# return _format_stackplot(id, tracked_dict, kwargs) -# elif method == "pcolormesh": -# return _format_pcolormesh(id, tracked_dict, kwargs) -# elif method == "pcolor": -# return _format_pcolormesh(id, tracked_dict, kwargs) -# elif method == "hexbin": -# return _format_hexbin(id, tracked_dict, kwargs) -# elif method == "hist2d": -# return _format_hist2d(id, tracked_dict, kwargs) -# elif method == "imshow": -# return _format_imshow(id, tracked_dict, kwargs) -# elif method == "imshow2d": -# return _format_imshow2d(id, tracked_dict, kwargs) -# elif method == "matshow": -# return _format_matshow(id, tracked_dict, kwargs) -# elif method == "pie": -# return _format_pie(id, tracked_dict, kwargs) -# elif method == "quiver": -# return _format_quiver(id, tracked_dict, kwargs) -# elif method == "stem": -# return _format_stem(id, tracked_dict, kwargs) -# elif method == "step": -# return _format_step(id, tracked_dict, kwargs) -# elif method == "streamplot": -# return _format_streamplot(id, tracked_dict, kwargs) -# elif method == "violin": -# return _format_violin(id, tracked_dict, kwargs) -# elif method == "violinplot": -# return _format_violinplot(id, tracked_dict, kwargs) -# elif method == "text": -# return _format_text(id, tracked_dict, kwargs) -# elif method == "annotate": -# return _format_annotate(id, tracked_dict, kwargs) -# -# # Custom plotting functions -# elif method == "stx_box": -# return _format_plot_box(id, tracked_dict, kwargs) -# elif method == "stx_conf_mat": -# return _format_plot_conf_mat(id, tracked_dict, kwargs) -# elif method == "stx_contour": -# return _format_stx_contour(id, tracked_dict, kwargs) -# elif method == "stx_ecdf": -# return _format_plot_ecdf(id, tracked_dict, kwargs) -# elif method == "stx_fillv": -# return _format_plot_fillv(id, tracked_dict, kwargs) -# elif method == "stx_heatmap": -# return _format_plot_heatmap(id, tracked_dict, kwargs) -# elif method == "stx_image": -# return _format_plot_image(id, tracked_dict, kwargs) -# elif method == "plot_imshow": -# return _format_plot_imshow(id, tracked_dict, kwargs) -# elif method == "stx_imshow": -# return _format_stx_imshow(id, tracked_dict, kwargs) -# elif method == "stx_joyplot": -# return _format_plot_joyplot(id, tracked_dict, kwargs) -# elif method == "stx_kde": -# return _format_plot_kde(id, tracked_dict, kwargs) -# elif method == "stx_line": -# return _format_plot_line(id, tracked_dict, kwargs) -# elif method == "stx_mean_ci": -# return _format_plot_mean_ci(id, tracked_dict, kwargs) -# elif method == "stx_mean_std": -# return _format_plot_mean_std(id, tracked_dict, kwargs) -# elif method == "stx_median_iqr": -# return _format_plot_median_iqr(id, tracked_dict, kwargs) -# elif method == "stx_raster": -# return _format_plot_raster(id, tracked_dict, kwargs) -# elif method == "stx_rectangle": -# return _format_plot_rectangle(id, tracked_dict, kwargs) -# elif method == "plot_scatter": -# return _format_plot_scatter(id, tracked_dict, kwargs) -# elif method == "stx_scatter_hist": -# return _format_plot_scatter_hist(id, tracked_dict, kwargs) -# elif method == "stx_shaded_line": -# return _format_plot_shaded_line(id, tracked_dict, kwargs) -# elif method == "stx_violin": -# return _format_plot_violin(id, tracked_dict, kwargs) -# -# # stx_ aliases -# elif method == "stx_scatter": -# return _format_stx_scatter(id, tracked_dict, kwargs) -# elif method == "stx_bar": -# return _format_stx_bar(id, tracked_dict, kwargs) -# elif method == "stx_barh": -# return _format_stx_barh(id, tracked_dict, kwargs) -# elif method == "stx_errorbar": -# return _format_stx_errorbar(id, tracked_dict, kwargs) -# -# # Seaborn functions (sns_ prefix) -# elif method == "sns_barplot": -# return _format_sns_barplot(id, tracked_dict, kwargs) -# elif method == "sns_boxplot": -# return _format_sns_boxplot(id, tracked_dict, kwargs) -# elif method == "sns_heatmap": -# return _format_sns_heatmap(id, tracked_dict, kwargs) -# elif method == "sns_histplot": -# return _format_sns_histplot(id, tracked_dict, kwargs) -# elif method == "sns_jointplot": -# return _format_sns_jointplot(id, tracked_dict, kwargs) -# elif method == "sns_kdeplot": -# return _format_sns_kdeplot(id, tracked_dict, kwargs) -# elif method == "sns_lineplot": -# return _format_sns_lineplot(id, tracked_dict, kwargs) -# elif method == "sns_pairplot": -# return _format_sns_pairplot(id, tracked_dict, kwargs) -# elif method == "sns_scatterplot": -# return _format_sns_scatterplot(id, tracked_dict, kwargs) -# elif method == "sns_stripplot": -# return _format_sns_stripplot(id, tracked_dict, kwargs) -# elif method == "sns_swarmplot": -# return _format_sns_swarmplot(id, tracked_dict, kwargs) -# elif method == "sns_violinplot": -# return _format_sns_violinplot(id, tracked_dict, kwargs) -# else: -# # Unknown or unimplemented method -# raise NotImplementedError( -# f"CSV export for plot method '{method}' is not yet implemented in the scitex.plt module. " -# f"Check the feature-request-export-as-csv-functions.md for implementation status." -# ) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_export_as_csv.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/test__fonts.py b/tests/scitex/plt/_subplots/test__fonts.py deleted file mode 100644 index 460f13016..000000000 --- a/tests/scitex/plt/_subplots/test__fonts.py +++ /dev/null @@ -1,87 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_fonts.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# """Font configuration for matplotlib figures.""" -# -# import os -# -# import matplotlib as mpl -# import matplotlib.font_manager as fm -# -# -# def configure_arial_font(): -# """Configure Arial font for matplotlib if available. -# -# Returns -# ------- -# bool -# True if Arial was successfully configured, False otherwise. -# """ -# arial_enabled = False -# -# # Try to find Arial -# try: -# fm.findfont("Arial", fallback_to_default=False) -# arial_enabled = True -# except Exception: -# # Search for Arial font files and register them -# arial_paths = [ -# f -# for f in fm.findSystemFonts() -# if os.path.basename(f).lower().startswith("arial") -# ] -# -# if arial_paths: -# for path in arial_paths: -# try: -# fm.fontManager.addfont(path) -# except Exception: -# pass -# -# # Verify Arial is now available -# try: -# fm.findfont("Arial", fallback_to_default=False) -# arial_enabled = True -# except Exception: -# pass -# -# # Configure matplotlib to use Arial if available -# if arial_enabled: -# mpl.rcParams["font.family"] = "Arial" -# mpl.rcParams["font.sans-serif"] = [ -# "Arial", -# "Helvetica", -# "DejaVu Sans", -# "Liberation Sans", -# ] -# else: -# # Warn about missing Arial -# from scitex import logging as _logging -# -# _logger = _logging.getLogger(__name__) -# _logger.warning( -# "Arial font not found. Using fallback fonts (Helvetica/DejaVu Sans). " -# "For publication figures with Arial: sudo apt-get install ttf-mscorefonts-installer && fc-cache -fv" -# ) -# -# return arial_enabled -# -# -# # Configure fonts at module import -# _arial_enabled = configure_arial_font() -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_fonts.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/_subplots/test__mm_layout.py b/tests/scitex/plt/_subplots/test__mm_layout.py deleted file mode 100644 index 3cd676b7a..000000000 --- a/tests/scitex/plt/_subplots/test__mm_layout.py +++ /dev/null @@ -1,298 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_mm_layout.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# """Millimeter-based layout control for matplotlib figures.""" -# -# import matplotlib.pyplot as plt -# import numpy as np -# -# from ._AxesWrapper import AxesWrapper -# from ._AxisWrapper import AxisWrapper -# from ._FigWrapper import FigWrapper -# -# -# def create_with_mm_control( -# *args, -# track=True, -# sharex=False, -# sharey=False, -# axes_width_mm=None, -# axes_height_mm=None, -# margin_left_mm=None, -# margin_right_mm=None, -# margin_bottom_mm=None, -# margin_top_mm=None, -# space_w_mm=None, -# space_h_mm=None, -# axes_thickness_mm=None, -# tick_length_mm=None, -# tick_thickness_mm=None, -# trace_thickness_mm=None, -# marker_size_mm=None, -# axis_font_size_pt=None, -# tick_font_size_pt=None, -# title_font_size_pt=None, -# legend_font_size_pt=None, -# suptitle_font_size_pt=None, -# label_pad_pt=None, -# tick_pad_pt=None, -# title_pad_pt=None, -# font_family=None, -# n_ticks=None, -# mode=None, -# dpi=None, -# styles=None, -# transparent=None, -# theme=None, -# **kwargs, -# ): -# """Create figure with mm-based control over axes dimensions. -# -# Returns -# ------- -# tuple -# (FigWrapper, AxisWrapper or AxesWrapper) -# """ -# from scitex.plt.utils import apply_style_mm, mm_to_inch -# -# # Parse nrows, ncols from args or kwargs -# nrows, ncols = 1, 1 -# if len(args) >= 1: -# nrows = args[0] -# elif "nrows" in kwargs: -# nrows = kwargs.pop("nrows") -# if len(args) >= 2: -# ncols = args[1] -# elif "ncols" in kwargs: -# ncols = kwargs.pop("ncols") -# -# n_axes = nrows * ncols -# -# # Apply mode-specific defaults -# if mode == "display": -# scale_factor = 3.0 -# dpi = dpi or 100 -# else: -# scale_factor = 1.0 -# dpi = dpi or 300 -# -# # Set defaults with scaling -# if axes_width_mm is None: -# axes_width_mm = 30.0 * scale_factor -# elif mode == "display": -# axes_width_mm = axes_width_mm * scale_factor -# -# if axes_height_mm is None: -# axes_height_mm = 21.0 * scale_factor -# elif mode == "display": -# axes_height_mm = axes_height_mm * scale_factor -# -# margin_left_mm = ( -# margin_left_mm if margin_left_mm is not None else (5.0 * scale_factor) -# ) -# margin_right_mm = ( -# margin_right_mm if margin_right_mm is not None else (2.0 * scale_factor) -# ) -# margin_bottom_mm = ( -# margin_bottom_mm if margin_bottom_mm is not None else (5.0 * scale_factor) -# ) -# margin_top_mm = margin_top_mm if margin_top_mm is not None else (2.0 * scale_factor) -# space_w_mm = space_w_mm if space_w_mm is not None else (3.0 * scale_factor) -# space_h_mm = space_h_mm if space_h_mm is not None else (3.0 * scale_factor) -# -# # Handle list vs scalar for axes dimensions -# if isinstance(axes_width_mm, (list, tuple)): -# ax_widths_mm = list(axes_width_mm) -# if len(ax_widths_mm) != n_axes: -# raise ValueError( -# f"axes_width_mm list length ({len(ax_widths_mm)}) " -# f"must match nrows*ncols ({n_axes})" -# ) -# else: -# ax_widths_mm = [axes_width_mm] * n_axes -# -# if isinstance(axes_height_mm, (list, tuple)): -# ax_heights_mm = list(axes_height_mm) -# if len(ax_heights_mm) != n_axes: -# raise ValueError( -# f"axes_height_mm list length ({len(ax_heights_mm)}) " -# f"must match nrows*ncols ({n_axes})" -# ) -# else: -# ax_heights_mm = [axes_height_mm] * n_axes -# -# # Calculate figure size from axes grid -# ax_widths_2d = np.array(ax_widths_mm).reshape(nrows, ncols) -# ax_heights_2d = np.array(ax_heights_mm).reshape(nrows, ncols) -# -# max_widths_per_col = ax_widths_2d.max(axis=0) -# max_heights_per_row = ax_heights_2d.max(axis=1) -# -# total_width_mm = ( -# margin_left_mm -# + max_widths_per_col.sum() -# + (ncols - 1) * space_w_mm -# + margin_right_mm -# ) -# total_height_mm = ( -# margin_bottom_mm -# + max_heights_per_row.sum() -# + (nrows - 1) * space_h_mm -# + margin_top_mm -# ) -# -# # Create figure -# figsize_inch = (mm_to_inch(total_width_mm), mm_to_inch(total_height_mm)) -# if transparent: -# fig_mpl = plt.figure(figsize=figsize_inch, dpi=dpi, facecolor="none") -# else: -# fig_mpl = plt.figure(figsize=figsize_inch, dpi=dpi) -# -# # Store theme on figure -# if theme is not None: -# fig_mpl._scitex_theme = theme -# -# # Create axes array and position each one manually -# axes_mpl_list = [] -# ax_idx = 0 -# -# for row in range(nrows): -# for col in range(ncols): -# # Calculate position -# left_mm = margin_left_mm + max_widths_per_col[:col].sum() + col * space_w_mm -# rows_below = nrows - row - 1 -# bottom_mm = ( -# margin_bottom_mm -# + max_heights_per_row[row + 1 :].sum() -# + rows_below * space_h_mm -# ) -# -# # Convert to figure coordinates [0-1] -# left = left_mm / total_width_mm -# bottom = bottom_mm / total_height_mm -# width = ax_widths_mm[ax_idx] / total_width_mm -# height = ax_heights_mm[ax_idx] / total_height_mm -# -# # Create axes -# ax_mpl = fig_mpl.add_axes([left, bottom, width, height]) -# if transparent: -# ax_mpl.patch.set_alpha(0.0) -# axes_mpl_list.append(ax_mpl) -# -# # Tag with metadata -# ax_mpl._scitex_metadata = { -# "created_with": "scitex.plt.subplots", -# "mode": mode or "publication", -# "axes_size_mm": (ax_widths_mm[ax_idx], ax_heights_mm[ax_idx]), -# "position_in_grid": (row, col), -# } -# ax_idx += 1 -# -# # Apply styling to each axes -# suptitle_font_size_pt_value = None -# for i, ax_mpl in enumerate(axes_mpl_list): -# # Determine which style dict to use -# if styles is not None: -# if isinstance(styles, list): -# if len(styles) != n_axes: -# raise ValueError( -# f"styles list length ({len(styles)}) " -# f"must match nrows*ncols ({n_axes})" -# ) -# style_dict = styles[i] -# else: -# style_dict = styles -# else: -# # Build style dict from individual parameters -# style_dict = {} -# if axes_thickness_mm is not None: -# style_dict["axis_thickness_mm"] = axes_thickness_mm -# if tick_length_mm is not None: -# style_dict["tick_length_mm"] = tick_length_mm -# if tick_thickness_mm is not None: -# style_dict["tick_thickness_mm"] = tick_thickness_mm -# if trace_thickness_mm is not None: -# style_dict["trace_thickness_mm"] = trace_thickness_mm -# if marker_size_mm is not None: -# style_dict["marker_size_mm"] = marker_size_mm -# if axis_font_size_pt is not None: -# style_dict["axis_font_size_pt"] = axis_font_size_pt -# if tick_font_size_pt is not None: -# style_dict["tick_font_size_pt"] = tick_font_size_pt -# if title_font_size_pt is not None: -# style_dict["title_font_size_pt"] = title_font_size_pt -# if legend_font_size_pt is not None: -# style_dict["legend_font_size_pt"] = legend_font_size_pt -# if suptitle_font_size_pt is not None: -# style_dict["suptitle_font_size_pt"] = suptitle_font_size_pt -# if label_pad_pt is not None: -# style_dict["label_pad_pt"] = label_pad_pt -# if tick_pad_pt is not None: -# style_dict["tick_pad_pt"] = tick_pad_pt -# if title_pad_pt is not None: -# style_dict["title_pad_pt"] = title_pad_pt -# if font_family is not None: -# style_dict["font_family"] = font_family -# if n_ticks is not None: -# style_dict["n_ticks"] = n_ticks -# -# # Always add theme to style_dict -# if theme is not None: -# style_dict["theme"] = theme -# -# # Extract suptitle font size if available -# if "suptitle_font_size_pt" in style_dict: -# suptitle_font_size_pt_value = style_dict["suptitle_font_size_pt"] -# -# # Apply style if not empty -# if style_dict: -# apply_style_mm(ax_mpl, style_dict) -# ax_mpl._scitex_metadata["style_mm"] = style_dict -# -# # Store suptitle font size in figure metadata -# if suptitle_font_size_pt_value is not None: -# fig_mpl._scitex_suptitle_font_size_pt = suptitle_font_size_pt_value -# -# # Wrap the figure -# fig_scitex = FigWrapper(fig_mpl) -# -# # Reshape axes list -# axes_array_mpl = np.array(axes_mpl_list).reshape(nrows, ncols) -# -# # Handle single axis case -# if n_axes == 1: -# ax_mpl_scalar = axes_array_mpl.item() -# axis_scitex = AxisWrapper(fig_scitex, ax_mpl_scalar, track) -# fig_scitex.axes = [axis_scitex] -# ax_mpl_scalar._scitex_wrapper = axis_scitex -# return fig_scitex, axis_scitex -# -# # Handle multiple axes case -# axes_flat_scitex_list = [] -# for ax_mpl in axes_mpl_list: -# ax_scitex = AxisWrapper(fig_scitex, ax_mpl, track) -# ax_mpl._scitex_wrapper = ax_scitex -# axes_flat_scitex_list.append(ax_scitex) -# -# axes_array_scitex = np.array(axes_flat_scitex_list).reshape(nrows, ncols) -# axes_scitex = AxesWrapper(fig_scitex, axes_array_scitex) -# fig_scitex.axes = axes_scitex -# -# return fig_scitex, axes_scitex -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/_subplots/_mm_layout.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__add_fitted_line.py b/tests/scitex/plt/ax/_plot/test__add_fitted_line.py deleted file mode 100644 index 447dd432a..000000000 --- a/tests/scitex/plt/ax/_plot/test__add_fitted_line.py +++ /dev/null @@ -1,168 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_add_fitted_line.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-11-19 15:52:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_plot/_add_fitted_line.py -# -# """ -# Add fitted regression line to scatter plots. -# """ -# -# import numpy as np -# from typing import Optional, Tuple, Dict -# -# -# def add_fitted_line( -# ax, -# x, -# y, -# color: str = "black", -# linestyle: str = "--", -# linewidth_mm: float = 0.2, -# label: Optional[str] = None, -# degree: int = 1, -# show_stats: bool = True, -# stats_position: float = 0.75, -# stats_fontsize: int = 6, -# ) -> Tuple: -# """ -# Add a fitted polynomial line to a scatter plot with optional R² and p-value. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# Axes to plot on -# x : array-like -# X data -# y : array-like -# Y data -# color : str, optional -# Line color (default: 'black') -# linestyle : str, optional -# Line style (default: '--' for dashed) -# linewidth_mm : float, optional -# Line thickness in millimeters (default: 0.2mm) -# label : str, optional -# Label for the fitted line (default: None) -# degree : int, optional -# Polynomial degree for fitting (default: 1 for linear) -# show_stats : bool, optional -# Whether to display R² and p-value near the line (default: True) -# Only applicable for linear fits (degree=1) -# stats_position : float, optional -# Position along x-axis (0-1 scale) for stats text (default: 0.75) -# stats_fontsize : int, optional -# Font size for statistics text in points (default: 6) -# -# Returns -# ------- -# line : Line2D -# The fitted line object -# coeffs : np.ndarray -# Polynomial coefficients from np.polyfit -# stats : StatResult or None -# StatResult instance with correlation statistics (only for degree=1). -# Use .to_dict() for dictionary format. -# -# Examples -# -------- -# >>> fig, ax = stx.plt.subplots(**stx.plt.presets.SCITEX_STYLE) -# >>> scatter = ax.scatter(x, y) -# >>> stx.plt.ax.add_fitted_line(ax, x, y) # Auto-shows R² and p -# -# >>> # Without statistics -# >>> line, coeffs, stats = stx.plt.ax.add_fitted_line( -# ... ax, x, y, show_stats=False -# ... ) -# -# >>> # Custom position for stats -# >>> line, coeffs, stats = stx.plt.ax.add_fitted_line( -# ... ax, x, y, stats_position=0.5 -# ... ) -# """ -# from scitex.plt.utils import mm_to_pt -# -# # Convert data to numpy arrays -# x = np.asarray(x) -# y = np.asarray(y) -# -# # Fit polynomial -# coeffs = np.polyfit(x, y, degree) -# poly_fn = np.poly1d(coeffs) -# -# # Generate fitted line points -# x_fit = np.linspace(x.min(), x.max(), 100) -# y_fit = poly_fn(x_fit) -# -# # Convert linewidth to points -# lw_pt = mm_to_pt(linewidth_mm) -# -# # Plot fitted line -# line = ax.plot( -# x_fit, -# y_fit, -# color=color, -# linestyle=linestyle, -# linewidth=lw_pt, -# label=label, -# )[0] -# -# # Calculate and display statistics for linear regression (degree=1) -# stats_result = None -# if degree == 1 and show_stats: -# # Import scitex.stats correlation test -# from scitex.stats.tests.correlation import test_pearson -# -# # Calculate correlation statistics using scitex.stats -# stats_result = test_pearson(x, y) -# -# # Position for text annotation -# x_pos = x.min() + stats_position * (x.max() - x.min()) -# y_pos = poly_fn(x_pos) -# -# # Format statistics text with R² and significance stars -# r_squared = stats_result.effect_size["value"] # r_squared from effect_size -# stars = stats_result.stars -# -# if stars and stars != "ns": # Only show if significant -# stats_text = f"$R^2$ = {r_squared:.3f}{stars}" -# else: # Not significant -# stats_text = f"$R^2$ = {r_squared:.3f} (ns)" -# -# # Add text annotation near the line -# ax.text( -# x_pos, -# y_pos, -# stats_text, -# verticalalignment="bottom", -# fontsize=stats_fontsize, -# ) -# -# # Store stats in axes metadata for embedding in saved figures -# if not hasattr(ax, "_scitex_metadata"): -# ax._scitex_metadata = {} -# if "stats" not in ax._scitex_metadata: -# ax._scitex_metadata["stats"] = [] -# -# # Add this StatResult to the stats list -# ax._scitex_metadata["stats"].append(stats_result.to_dict()) -# -# return line, coeffs, stats_result -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_add_fitted_line.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__plot_circular_hist.py b/tests/scitex/plt/ax/_plot/test__plot_circular_hist.py deleted file mode 100644 index 5505ac8f1..000000000 --- a/tests/scitex/plt/ax/_plot/test__plot_circular_hist.py +++ /dev/null @@ -1,399 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-03 15:13:17 (ywatanabe)" -# File: /home/ywatanabe/proj/_scitex_repo/tests/scitex/plt/ax/_plot/test__plot_circular_hist.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_plot/test__plot_circular_hist.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._plot import plot_circular_hist - -matplotlib.use("Agg") - - -class TestPlotCircularHist: - def setup_method(self): - # Setup test fixtures - polar axes required for circular histogram - self.fig = plt.figure() - self.ax = self.fig.add_subplot(111, projection="polar") - # Create sample radians data (0 to 2pi) - self.rads = np.random.uniform(0, 2 * np.pi, 1000) - # Create output directory if it doesn't exist - self.out_dir = __file__.replace(".py", "_out") - os.makedirs(self.out_dir, exist_ok=True) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def save_test_figure(self, method_name): - """Helper method to save figure using method name""" - from scitex.io import save - - spath = f".{method_name}.jpg" - save(self.fig, spath) - # Check saved file - actual_spath = os.path.join(self.out_dir, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - def test_basic_functionality(self): - # Test with default parameters - n, bins, patches = plot_circular_hist(self.ax, self.rads) - self.ax.set_title("Basic Circular Histogram") - - # Save figure - self.save_test_figure("test_basic_functionality") - - # Check return values - assert isinstance(n, np.ndarray) - assert isinstance(bins, np.ndarray) - assert len(n) == len(bins) - 1 - assert len(n) == 16 # Default bin count - # Check that patches were added to the plot - assert len(patches) == 16 - - def test_with_custom_bins(self): - # Test with custom number of bins - bin_count = 24 - n, bins, patches = plot_circular_hist(self.ax, self.rads, bins=bin_count) - self.ax.set_title("Circular Histogram with Custom Bins") - - # Save figure - self.save_test_figure("test_with_custom_bins") - - # Check correct number of bins - assert len(n) == bin_count - assert len(patches) == bin_count - - def test_with_no_gaps(self): - # Test with gaps=False - n, bins, patches = plot_circular_hist(self.ax, self.rads, gaps=False) - self.ax.set_title("Circular Histogram with No Gaps") - - # Save figure - self.save_test_figure("test_with_no_gaps") - - # Check that bins span the entire circle - assert np.isclose(bins[0], -np.pi) - assert np.isclose(bins[-1], np.pi) - - def test_with_custom_color(self): - # Test with custom color - color = "red" - n, bins, patches = plot_circular_hist(self.ax, self.rads, color=color) - self.ax.set_title("Circular Histogram with Custom Color") - - # Save figure - self.save_test_figure("test_with_custom_color") - - # Check that patches have the correct color - for patch in patches: - assert patch.get_edgecolor()[0:3] == matplotlib.colors.to_rgb(color) - - def test_with_non_density(self): - # Test with density=False - n, bins, patches = plot_circular_hist(self.ax, self.rads, density=False) - self.ax.set_title("Circular Histogram with Non-Density") - - # Save figure - self.save_test_figure("test_with_non_density") - - # Check that y-ticks are visible - assert len(self.ax.get_yticks()) > 0 - - def test_with_offset(self): - # Test with custom offset - offset = np.pi / 4 # 45 degrees - n, bins, patches = plot_circular_hist(self.ax, self.rads, offset=offset) - self.ax.set_title("Circular Histogram with Offset") - - # Save figure - self.save_test_figure("test_with_offset") - - # Check that theta offset was set correctly - assert self.ax.get_theta_offset() == offset - - def test_with_range_bias(self): - # Test with range_bias - range_bias = 0.5 - n, bins, patches = plot_circular_hist(self.ax, self.rads, range_bias=range_bias) - self.ax.set_title("Circular Histogram with Range Bias") - - # Save figure - self.save_test_figure("test_with_range_bias") - - # Check that histogram is biased as expected - assert np.isclose(bins[0], -np.pi + range_bias, atol=1e-5) - - def test_plot_circular_hist_savefig(self): - n, bins, patches = plot_circular_hist(self.ax, self.rads, color="green") - self.ax.set_title("Circular Histogram Test") - - # Saving - from scitex.io import save - - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -# class TestMainFunctionality: -# def setup_method(self): -# # Setup test fixtures - polar axes required for circular histogram -# self.fig = plt.figure() -# self.ax = self.fig.add_subplot(111, projection="polar") - -# # Create sample radians data (0 to 2pi) -# self.rads = np.random.uniform(0, 2 * np.pi, 1000) - -# def teardown_method(self): -# # Clean up after tests -# plt.close(self.fig) - -# def test_basic_functionality(self): -# # Test with default parameters -# n, bins, patches = plot_circular_hist(self.ax, self.rads) - -# # Check return values -# assert isinstance(n, np.ndarray) -# assert isinstance(bins, np.ndarray) -# assert len(n) == len(bins) - 1 -# assert len(n) == 16 # Default bin count - -# # Check that patches were added to the plot -# assert len(patches) == 16 - -# def test_with_custom_bins(self): -# # Test with custom number of bins -# bin_count = 24 -# n, bins, patches = plot_circular_hist( -# self.ax, self.rads, bins=bin_count -# ) - -# # Check correct number of bins -# assert len(n) == bin_count -# assert len(patches) == bin_count - -# def test_with_no_gaps(self): -# # Test with gaps=False -# n, bins, patches = plot_circular_hist(self.ax, self.rads, gaps=False) - -# # Check that bins span the entire circle -# assert np.isclose(bins[0], -np.pi) -# assert np.isclose(bins[-1], np.pi) - -# def test_with_custom_color(self): -# # Test with custom color -# color = "red" -# n, bins, patches = plot_circular_hist(self.ax, self.rads, color=color) - -# # Check that patches have the correct color -# for patch in patches: -# assert patch.get_edgecolor()[0:3] == matplotlib.colors.to_rgb( -# color -# ) - -# def test_with_non_density(self): -# # Test with density=False -# n, bins, patches = plot_circular_hist( -# self.ax, self.rads, density=False -# ) - -# # Check that y-ticks are visible -# assert len(self.ax.get_yticks()) > 0 - -# def test_with_offset(self): -# # Test with custom offset -# offset = np.pi / 4 # 45 degrees -# n, bins, patches = plot_circular_hist( -# self.ax, self.rads, offset=offset -# ) - -# # Check that theta offset was set correctly -# assert self.ax.get_theta_offset() == offset - -# def test_with_range_bias(self): -# # Test with range_bias -# range_bias = 0.5 -# n, bins, patches = plot_circular_hist( -# self.ax, self.rads, range_bias=range_bias -# ) - -# # Check that histogram is biased as expected -# assert np.isclose(bins[0], -np.pi + range_bias, atol=1e-5) - -# def test_plot_circular_hist_savefig(self): - -# # fig = plt.figure() -# # ax = fig.add_subplot(111, projection="polar") -# # rads = np.random.uniform(0, 2 * np.pi, 1000) -# n, bins, patches = plot_circular_hist( -# self.ax, self.rads, color="green" -# ) - -# # Saving -# from scitex.io import save - -# spath = f"./{os.path.basename(__file__)}.jpg" -# save(self.fig, spath) - -# # Check saved file -# ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") -# actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) -# assert os.path.exists( -# actual_spath -# ), f"Failed to save figure to {spath}" - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_plot_circular_hist.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-01 15:21:48 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_circular_hist.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_circular_hist.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# # Time-stamp: "2024-02-03 13:10:50 (ywatanabe)" -# import matplotlib -# import numpy as np -# from ....plt.utils import assert_valid_axis -# -# -# def plot_circular_hist( -# axis, -# radians, -# bins=16, -# density=True, -# offset=0, -# gaps=True, -# color=None, -# range_bias=0, -# ): -# """ -# Example: -# fig, ax = plt.subplots(subplot_kw=dict(projection="polar")) -# ax = scitex.plt.plot_circular_hist(ax, radians) -# Produce a circular histogram of angles on ax. -# -# Parameters -# ---------- -# ax : matplotlib.axes._subplots.PolarAxesSubplot or scitex.plt._subplots.AxisWrapper -# axis instance created with subplot_kw=dict(projection='polar'). -# -# radians : array -# Angles to plot, expected in units of radians. -# -# bins : int, optional -# Defines the number of equal-width bins in the range. The default is 16. -# -# density : bool, optional -# If True plot frequency proportional to area. If False plot frequency -# proportional to radius. The default is True. -# -# offset : float, optional -# Sets the offset for the location of the 0 direction in units of -# radians. The default is 0. -# -# gaps : bool, optional -# Whether to allow gaps between bins. When gaps = False the bins are -# forced to partition the entire [-pi, pi] range. The default is True. -# -# Returns -# ------- -# n : array or list of arrays -# The number of values in each bin. -# -# bins : array -# The edges of the bins. -# -# patches : `.BarContainer` or list of a single `.Polygon` -# Container of individual artists used to create the histogram -# or list of such containers if there are multiple input datasets. -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# # Wrap angles to [-pi, pi) -# radians = (radians + np.pi) % (2 * np.pi) - np.pi -# -# # Force bins to partition entire circle -# if not gaps: -# bins = np.linspace(-np.pi, np.pi, num=bins + 1) -# -# # Bin data and record counts -# n, bins = np.histogram( -# radians, bins=bins, range=(-np.pi + range_bias, np.pi + range_bias) -# ) -# -# # Compute width of each bin -# widths = np.diff(bins) -# -# # By default plot frequency proportional to area -# if density: -# # Area to assign each bin -# area = n / radians.size -# # Calculate corresponding bin radius -# radius = (area / np.pi) ** 0.5 -# # Otherwise plot frequency proportional to radius -# else: -# radius = n -# -# mean_val = np.nanmean(radians) -# std_val = np.nanstd(radians) -# axis.axvline(mean_val, color=color) -# axis.text(mean_val, 1, std_val) -# -# # Plot data on ax -# patches = axis.bar( -# bins[:-1], -# radius, -# zorder=1, -# align="edge", -# width=widths, -# edgecolor=color, -# alpha=0.9, -# fill=False, -# linewidth=1, -# ) -# -# # Set the direction of the zero angle -# axis.set_theta_offset(offset) -# -# # Remove ylabels for area plots (they are mostly obstructive) -# if density: -# axis.set_yticks([]) -# -# return n, bins, patches -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_plot_circular_hist.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__plot_cube.py b/tests/scitex/plt/ax/_plot/test__plot_cube.py deleted file mode 100644 index a3920468e..000000000 --- a/tests/scitex/plt/ax/_plot/test__plot_cube.py +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 10:31:25 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_plot/test__plot_cube.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_plot/test__plot_cube.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib.pyplot as plt -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._plot import plot_cube - - -class TestPlotCube: - def setup_method(self): - # Create output directory if it doesn't exist - self.out_dir = __file__.replace(".py", "_out") - os.makedirs(self.out_dir, exist_ok=True) - - def save_test_figure(self, fig, method_name): - """Helper method to save figure using method name""" - from scitex.io import save - - spath = f"./{os.path.basename(__file__).replace('.py', '')}_{method_name}.jpg" - save(fig, spath) - # Check saved file - actual_spath = os.path.join(self.out_dir, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - def test_plot_cube_creates_12_edges(self): - fig = plt.figure() - ax = fig.add_subplot(projection="3d") - r1 = [0, 1] - r2 = [0, 1] - r3 = [0, 1] - plot_cube(ax, r1, r2, r3) - ax.set_title("3D Cube with 12 Edges") - - # Save figure - self.save_test_figure(fig, "test_plot_cube_creates_12_edges") - - # Clean up - plt.close(fig) - - def test_plot_cube_with_custom_color(self): - fig = plt.figure() - ax = fig.add_subplot(projection="3d") - r1 = [0, 1] - r2 = [0, 1] - r3 = [0, 1] - plot_cube(ax, r1, r2, r3, c="red") - ax.set_title("3D Cube with Custom Color") - - # Save figure - self.save_test_figure(fig, "test_plot_cube_with_custom_color") - - # Clean up - plt.close(fig) - - def test_plot_cube_with_custom_alpha(self): - fig = plt.figure() - ax = fig.add_subplot(projection="3d") - r1 = [0, 1] - r2 = [0, 1] - r3 = [0, 1] - plot_cube(ax, r1, r2, r3, alpha=0.5) - ax.set_title("3D Cube with Custom Alpha") - - # Save figure - self.save_test_figure(fig, "test_plot_cube_with_custom_alpha") - - # Clean up - plt.close(fig) - - def test_plot_cube_savefig(self): - fig = plt.figure() - ax = fig.add_subplot(projection="3d") - ax = plot_cube(ax, [0, 1], [0, 1], [0, 1], c="red") - ax.set_title("3D Cube Plot") - - # Saving - from scitex.io import save - - spath = f"./{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - # Clean up - plt.close(fig) - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_plot_cube.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-01 15:21:37 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_cube.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_cube.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# from itertools import combinations, product -# -# import numpy as np -# -# -# def plot_cube(ax, xlim, ylim, zlim, c="blue", alpha=1.0): -# """ -# Plot a 3D cube on the given axis. -# -# Args: -# ax: Matplotlib 3D axis -# xlim: Range for x-axis as a tuple (min, max) -# ylim: Range for y-axis as a tuple (min, max) -# zlim: Range for z-axis as a tuple (min, max) -# c: Color of the cube edges (default: 'blue') -# alpha: Transparency of the cube edges (default: 1.0) -# -# Returns: -# Matplotlib axis with the cube plotted -# """ -# # Validate inputs -# assert hasattr(ax, "plot3D"), "The axis must be a 3D axis with plot3D method" -# assert len(xlim) == 2, "xlim must be a tuple of (min, max)" -# assert len(ylim) == 2, "ylim must be a tuple of (min, max)" -# assert len(zlim) == 2, "zlim must be a tuple of (min, max)" -# assert xlim[0] < xlim[1], "xlim[0] must be less than xlim[1]" -# assert ylim[0] < ylim[1], "ylim[0] must be less than ylim[1]" -# assert zlim[0] < zlim[1], "zlim[0] must be less than zlim[1]" -# -# # Get all corners of the cube -# corners = np.array(list(product(xlim, ylim, zlim))) -# -# # Draw edges between corners -# for start, end in combinations(corners, 2): -# # Check if the points form an edge (differ in exactly one dimension) -# if np.sum(np.abs(start - end)) == xlim[1] - xlim[0]: -# ax.plot3D(*zip(start, end), c=c, linewidth=3, alpha=alpha) -# if np.sum(np.abs(start - end)) == ylim[1] - ylim[0]: -# ax.plot3D(*zip(start, end), c=c, linewidth=3, alpha=alpha) -# if np.sum(np.abs(start - end)) == zlim[1] - zlim[0]: -# ax.plot3D(*zip(start, end), c=c, linewidth=3, alpha=alpha) -# -# return ax -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_plot_cube.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__plot_statistical_shaded_line.py b/tests/scitex/plt/ax/_plot/test__plot_statistical_shaded_line.py deleted file mode 100644 index 2b133593f..000000000 --- a/tests/scitex/plt/ax/_plot/test__plot_statistical_shaded_line.py +++ /dev/null @@ -1,369 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-01 19:32:38 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_plot/test__plot_statistical_shaded_line.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_plot/test__plot_statistical_shaded_line.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - - -def test_plot_line_savefig(): - import matplotlib.pyplot as plt - import numpy as np - - from scitex.plt.ax._plot import plot_line - - fig, ax = plt.subplots() - data = np.sin(np.linspace(0, 10, 100)) - ax, df = plot_line(ax, data) - - # Saving - from scitex.io import save - - spath = f"./{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -def test_plot_mean_std_savefig(): - import matplotlib.pyplot as plt - import numpy as np - - from scitex.plt.ax._plot import plot_mean_std - - fig, ax = plt.subplots() - data = np.random.normal(0, 1, (10, 100)) - ax, df = plot_mean_std(ax, data, label="Test") - - # Saving - from scitex.io import save - - spath = f"./{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -def test_plot_mean_ci_savefig(): - import matplotlib.pyplot as plt - import numpy as np - - from scitex.plt.ax._plot import plot_mean_ci - - fig, ax = plt.subplots() - data = np.random.normal(0, 1, (10, 100)) - ax, df = plot_mean_ci(ax, data, label="Test") - - # Saving - from scitex.io import save - - spath = f"./{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -def test_plot_median_iqr_savefig(): - import matplotlib.pyplot as plt - import numpy as np - - from scitex.plt.ax._plot import plot_median_iqr - - fig, ax = plt.subplots() - data = np.random.normal(0, 1, (10, 100)) - ax, df = plot_median_iqr(ax, data, label="Test") - - # Saving - from scitex.io import save - - spath = f"./{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_plot_statistical_shaded_line.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-04-30 20:50:45 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot_statistical_shaded_line.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot_statistical_shaded_line.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# import numpy as np -# import pandas as pd -# from ....plt.utils import assert_valid_axis -# -# from ._stx_shaded_line import stx_shaded_line as scitex_plt_plot_shaded_line -# -# -# def _format_sample_size(values_2d): -# """Format sample size string, showing range if variable due to NaN. -# -# Parameters -# ---------- -# values_2d : np.ndarray, shape (n_samples, n_points) -# 2D array where sample count may vary per column due to NaN. -# -# Returns -# ------- -# str -# Formatted sample size string, e.g., "20" or "18-20". -# """ -# if values_2d.ndim == 1: -# return "1" -# -# # Count non-NaN values per column (timepoint) -# n_per_point = np.sum(~np.isnan(values_2d), axis=0) -# n_min, n_max = int(n_per_point.min()), int(n_per_point.max()) -# -# if n_min == n_max: -# return str(n_min) -# else: -# return f"{n_min}-{n_max}" -# -# -# def stx_line(axis, values_1d, xx=None, **kwargs): -# """ -# Plot a simple line. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis to plot on -# values_1d : array-like, shape (n_points,) -# 1D array of y-values to plot -# xx : array-like, shape (n_points,), optional -# X coordinates for the data. If None, will use np.arange(len(values_1d)) -# **kwargs -# Additional keyword arguments passed to axis.plot() -# -# Returns -# ------- -# axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis with the plot -# df : pandas.DataFrame -# DataFrame with x and y values -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# values_1d = np.asarray(values_1d) -# assert values_1d.ndim <= 2, f"Data must be 1D or 2D, got {values_1d.ndim}D" -# if xx is None: -# xx = np.arange(len(values_1d)) -# else: -# xx = np.asarray(xx) -# assert len(xx) == len(values_1d), ( -# f"xx length ({len(xx)}) must match values_1d length ({len(values_1d)})" -# ) -# -# axis.plot(xx, values_1d, **kwargs) -# return axis, pd.DataFrame({"x": xx, "y": values_1d}) -# -# -# def stx_mean_std(axis, values_2d, xx=None, sd=1, **kwargs): -# """ -# Plot mean line with standard deviation shading. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis to plot on -# values_2d : array-like, shape (n_samples, n_points) or (n_points,) -# 2D array where mean and std are calculated across axis=0 (samples). -# Can also be 1D for a single line without shading. -# xx : array-like, shape (n_points,), optional -# X coordinates for the data. If None, will use np.arange(n_points) -# sd : float, optional -# Number of standard deviations for the shaded region. Default is 1 -# **kwargs -# Additional keyword arguments passed to stx_shaded_line() -# -# Returns -# ------- -# axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis with the plot -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# assert isinstance(sd, (int, float)), f"sd must be a number, got {type(sd)}" -# assert sd >= 0, f"sd must be non-negative, got {sd}" -# values_2d = np.asarray(values_2d) -# assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D" -# if xx is None: -# xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)) -# else: -# xx = np.asarray(xx) -# expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d) -# assert len(xx) == expected_len, ( -# f"xx length ({len(xx)}) must match values_2d length ({expected_len})" -# ) -# -# if values_2d.ndim == 1: -# central = values_2d -# error = np.zeros_like(central) -# else: -# central = np.nanmean(values_2d, axis=0) -# error = np.nanstd(values_2d, axis=0) * sd -# -# y_lower = central - error -# y_upper = central + error -# -# if "label" in kwargs and kwargs["label"]: -# n_str = _format_sample_size(values_2d) -# kwargs["label"] = f"{kwargs['label']} ($n$={n_str})" -# -# return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs) -# -# -# def stx_mean_ci(axis, values_2d, xx=None, perc=95, **kwargs): -# """ -# Plot mean line with confidence interval shading. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis to plot on -# values_2d : array-like, shape (n_samples, n_points) or (n_points,) -# 2D array where mean and percentiles are calculated across axis=0 (samples). -# Can also be 1D for a single line without shading. -# xx : array-like, shape (n_points,), optional -# X coordinates for the data. If None, will use np.arange(n_points) -# perc : float, optional -# Confidence interval percentage (0-100). Default is 95 -# **kwargs -# Additional keyword arguments passed to stx_shaded_line() -# -# Returns -# ------- -# axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis with the plot -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# assert isinstance(perc, (int, float)), f"perc must be a number, got {type(perc)}" -# assert 0 <= perc <= 100, f"perc must be between 0 and 100, got {perc}" -# values_2d = np.asarray(values_2d) -# assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D" -# -# if xx is None: -# xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)) -# else: -# xx = np.asarray(xx) -# -# expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d) -# assert len(xx) == expected_len, ( -# f"xx length ({len(xx)}) must match values_2d length ({expected_len})" -# ) -# -# if values_2d.ndim == 1: -# central = values_2d -# y_lower = central -# y_upper = central -# else: -# central = np.nanmean(values_2d, axis=0) -# # Calculate CI bounds -# alpha = 1 - perc / 100 -# y_lower_perc = alpha / 2 * 100 -# y_upper_perc = (1 - alpha / 2) * 100 -# y_lower = np.nanpercentile(values_2d, y_lower_perc, axis=0) -# y_upper = np.nanpercentile(values_2d, y_upper_perc, axis=0) -# -# if "label" in kwargs and kwargs["label"]: -# n_str = _format_sample_size(values_2d) -# kwargs["label"] = f"{kwargs['label']} ($n$={n_str}, CI={perc}%)" -# -# return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs) -# -# -# def stx_median_iqr(axis, values_2d, xx=None, **kwargs): -# """ -# Plot median line with interquartile range shading. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis to plot on -# values_2d : array-like, shape (n_samples, n_points) or (n_points,) -# 2D array where median and IQR are calculated across axis=0 (samples). -# Can also be 1D for a single line without shading. -# xx : array-like, shape (n_points,), optional -# X coordinates for the data. If None, will use np.arange(n_points) -# **kwargs -# Additional keyword arguments passed to stx_shaded_line() -# -# Returns -# ------- -# axis : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis with the plot -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# values_2d = np.asarray(values_2d) -# assert values_2d.ndim <= 2, f"Data must be 1D or 2D, got {values_2d.ndim}D" -# -# if xx is None: -# xx = np.arange(values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d)) -# else: -# xx = np.asarray(xx) -# -# expected_len = values_2d.shape[1] if values_2d.ndim > 1 else len(values_2d) -# assert len(xx) == expected_len, ( -# f"xx length ({len(xx)}) must match values_2d length ({expected_len})" -# ) -# -# if values_2d.ndim == 1: -# central = values_2d -# y_lower = central -# y_upper = central -# else: -# central = np.nanmedian(values_2d, axis=0) -# y_lower = np.nanpercentile(values_2d, 25, axis=0) -# y_upper = np.nanpercentile(values_2d, 75, axis=0) -# -# if "label" in kwargs and kwargs["label"]: -# n_str = _format_sample_size(values_2d) -# kwargs["label"] = f"{kwargs['label']} ($n$={n_str}, IQR)" -# -# return scitex_plt_plot_shaded_line(axis, xx, y_lower, central, y_upper, **kwargs) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_plot_statistical_shaded_line.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_conf_mat.py b/tests/scitex/plt/ax/_plot/test__stx_conf_mat.py deleted file mode 100644 index bf9e0a593..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_conf_mat.py +++ /dev/null @@ -1,155 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_conf_mat.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-18 15:08:16 (ywatanabe)" -# # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_conf_mat.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_conf_mat.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# from typing import List, Optional, Tuple, Union -# -# import matplotlib -# import matplotlib.pyplot as plt -# import numpy as np -# import pandas as pd -# import seaborn as sns -# -# from scitex.plt.utils._calc_bacc_from_conf_mat import calc_bacc_from_conf_mat -# from scitex.plt.utils import assert_valid_axis -# from .._style._extend import extend as scitex_plt_extend -# -# -# def stx_conf_mat( -# axis: plt.Axes, -# conf_mat_2d: Union[np.ndarray, pd.DataFrame], -# x_labels: Optional[List[str]] = None, -# y_labels: Optional[List[str]] = None, -# title: str = "Confusion Matrix", -# cmap: str = "Blues", -# cbar: bool = True, -# cbar_kw: dict = {}, -# label_rotation_xy: Tuple[float, float] = (15, 15), -# x_extend_ratio: float = 1.0, -# y_extend_ratio: float = 1.0, -# calc_bacc: bool = False, -# **kwargs, -# ) -> Union[plt.Axes, Tuple[plt.Axes, float]]: -# """Creates a confusion matrix heatmap with optional balanced accuracy. -# -# Parameters -# ---------- -# axis : plt.Axes or scitex.plt._subplots.AxisWrapper -# Matplotlib axes or scitex axis wrapper to plot on -# conf_mat_2d : Union[np.ndarray, pd.DataFrame], shape (n_classes, n_classes) -# 2D confusion matrix data (true labels × predicted labels) -# x_labels : Optional[List[str]], optional -# Labels for predicted classes -# y_labels : Optional[List[str]], optional -# Labels for true classes -# title : str, optional -# Plot title -# cmap : str, optional -# Colormap name -# cbar : bool, optional -# Whether to show colorbar -# cbar_kw : dict, optional -# Colorbar parameters -# label_rotation_xy : Tuple[float, float], optional -# (x,y) label rotation angles -# x_extend_ratio : float, optional -# X-axis extension ratio -# y_extend_ratio : float, optional -# Y-axis extension ratio -# calc_bacc : bool, optional -# Calculate Balanced Accuracy from Confusion Matrix -# -# Returns -# ------- -# Union[plt.Axes, Tuple[plt.Axes, float]] or Union[scitex.plt._subplots.AxisWrapper, Tuple[scitex.plt._subplots.AxisWrapper, float]] -# Axes object and optionally balanced accuracy -# -# Example -# ------- -# >>> data = np.array([[10, 2, 0], [1, 15, 3], [0, 2, 20]]) -# >>> fig, ax = plt.subplots() -# >>> ax, bacc = stx_conf_mat(ax, data, x_labels=['A','B','C'], -# ... y_labels=['X','Y','Z'], calc_bacc=True) -# >>> print(f"Balanced Accuracy: {bacc:.3f}") -# Balanced Accuracy: 0.889 -# """ -# -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# if not isinstance(conf_mat_2d, pd.DataFrame): -# conf_mat_2d = pd.DataFrame(conf_mat_2d) -# -# bacc_val = calc_bacc_from_conf_mat(conf_mat_2d.values) -# title = f"{title} (bACC = {bacc_val:.3f})" -# -# res = sns.heatmap( -# conf_mat_2d, -# ax=axis, -# cmap=cmap, -# annot=True, -# fmt=",d", -# cbar=False, -# vmin=0, -# **kwargs, -# ) -# -# res.invert_yaxis() -# -# for _, spine in res.spines.items(): -# spine.set_visible(False) -# -# axis.set_xlabel("Predicted label") -# axis.set_ylabel("True label") -# axis.set_title(title) -# -# if x_labels is not None: -# axis.set_xticklabels(x_labels) -# if y_labels is not None: -# axis.set_yticklabels(y_labels) -# -# axis = scitex_plt_extend(axis, x_extend_ratio, y_extend_ratio) -# if conf_mat_2d.shape[0] == conf_mat_2d.shape[1]: -# axis.set_box_aspect(1) -# axis.set_xticklabels( -# axis.get_xticklabels(), -# rotation=label_rotation_xy[0], -# fontdict={"verticalalignment": "top"}, -# ) -# axis.set_yticklabels( -# axis.get_yticklabels(), -# rotation=label_rotation_xy[1], -# fontdict={"horizontalalignment": "right"}, -# ) -# -# if calc_bacc: -# return axis, bacc_val -# else: -# return axis, None -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_conf_mat.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_ecdf.py b/tests/scitex/plt/ax/_plot/test__stx_ecdf.py deleted file mode 100644 index b4e731feb..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_ecdf.py +++ /dev/null @@ -1,129 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_ecdf.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 14:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_plot/_plot_ecdf.py -# -# """Empirical Cumulative Distribution Function (ECDF) plotting.""" -# -# from typing import Any, Tuple, Union -# -# import numpy as np -# import pandas as pd -# from matplotlib.axes import Axes -# -# from scitex import logging -# from scitex.pd._force_df import force_df as scitex_pd_force_df -# from ....plt.utils import assert_valid_axis, mm_to_pt -# -# logger = logging.getLogger(__name__) -# -# -# # Default line width (0.2mm for publication) -# DEFAULT_LINE_WIDTH_MM = 0.2 -# -# -# def stx_ecdf( -# axis: Union[Axes, "AxisWrapper"], -# values_1d: np.ndarray, -# **kwargs: Any, -# ) -> Tuple[Union[Axes, "AxisWrapper"], pd.DataFrame]: -# """Plot Empirical Cumulative Distribution Function (ECDF). -# -# The ECDF shows the proportion of data points less than or equal to each -# value, representing the empirical estimate of the cumulative distribution -# function. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes or AxisWrapper -# Matplotlib axis or scitex axis wrapper to plot on. -# values_1d : array-like, shape (n_samples,) -# 1D array of values to compute and plot ECDF for. NaN values are automatically ignored. -# **kwargs : dict -# Additional arguments passed to plot function. -# -# Returns -# ------- -# axis : matplotlib.axes.Axes or AxisWrapper -# The axes with the ECDF plot. -# df : pd.DataFrame -# DataFrame containing ECDF data with columns: -# - x: sorted data values -# - y: cumulative percentages (0-100) -# - n: total number of data points -# - x_step, y_step: step plot coordinates -# -# Examples -# -------- -# >>> import numpy as np -# >>> import scitex as stx -# >>> data = np.random.randn(100) -# >>> fig, ax = stx.plt.subplots() -# >>> ax, df = stx.plt.ax.stx_ecdf(ax, data) -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# # Flatten and remove NaN values -# values_1d = np.hstack(values_1d) -# -# # Warnings -# if np.isnan(values_1d).any(): -# logger.warning("NaN values are ignored for ECDF plot.") -# values_1d = values_1d[~np.isnan(values_1d)] -# nn = len(values_1d) -# -# # Sort the data and compute the ECDF values -# data_sorted = np.sort(values_1d) -# ecdf_perc = 100 * np.arange(1, len(data_sorted) + 1) / len(data_sorted) -# -# # Create the pseudo x-axis for step plotting -# x_step = np.repeat(data_sorted, 2)[1:] -# y_step = np.repeat(ecdf_perc, 2)[:-1] -# -# # Apply default linewidth if not specified -# if "linewidth" not in kwargs and "lw" not in kwargs: -# kwargs["linewidth"] = mm_to_pt(DEFAULT_LINE_WIDTH_MM) -# -# # Add sample size to label if provided -# if "label" in kwargs and kwargs["label"]: -# kwargs["label"] = f"{kwargs['label']} ($n$={nn})" -# -# # Plot the ECDF using steps (no markers - clean line only) -# axis.plot(x_step, y_step, drawstyle="steps-post", **kwargs) -# -# # Set ylim (xlim is auto-scaled based on data) -# axis.set_ylim(0, 100) -# -# # Create a DataFrame to hold the ECDF data -# df = scitex_pd_force_df( -# { -# "x": data_sorted, -# "y": ecdf_perc, -# "n": nn, -# "x_step": x_step, -# "y_step": y_step, -# } -# ) -# -# return axis, df -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_ecdf.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_fillv.py b/tests/scitex/plt/ax/_plot/test__stx_fillv.py deleted file mode 100644 index c963671dc..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_fillv.py +++ /dev/null @@ -1,73 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_fillv.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-04-30 21:26:45 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot_fillv.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot_fillv.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# import numpy as np -# from ....plt.utils import assert_valid_axis -# -# -# def stx_fillv(axes, starts_1d, ends_1d, color="red", alpha=0.2): -# """ -# Fill between specified start and end intervals on an axis or array of axes. -# -# Parameters -# ---------- -# axes : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper or numpy.ndarray of axes -# The axis object(s) to fill intervals on. -# starts_1d : array-like, shape (n_regions,) -# 1D array of start x-positions for vertical fill regions. -# ends_1d : array-like, shape (n_regions,) -# 1D array of end x-positions for vertical fill regions. -# color : str, optional -# The color to use for the filled regions. Default is "red". -# alpha : float, optional -# The alpha blending value, between 0 (transparent) and 1 (opaque). Default is 0.2. -# -# Returns -# ------- -# list -# List of axes with filled intervals. -# """ -# -# is_axes = isinstance(axes, np.ndarray) -# -# axes = axes if isinstance(axes, np.ndarray) else [axes] -# -# for ax in axes: -# assert_valid_axis( -# ax, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# for start, end in zip(starts_1d, ends_1d): -# ax.axvspan(start, end, facecolor=color, edgecolor="none", alpha=alpha) -# -# if not is_axes: -# return axes[0] -# else: -# return axes -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_fillv.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_heatmap.py b/tests/scitex/plt/ax/_plot/test__stx_heatmap.py deleted file mode 100644 index ad368f154..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_heatmap.py +++ /dev/null @@ -1,385 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_heatmap.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 13:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_plot/_plot_heatmap.py -# -# """Heatmap plotting with automatic annotation color switching.""" -# -# from typing import Any, List, Optional, Tuple, Union -# -# import matplotlib -# import matplotlib.pyplot as plt -# import numpy as np -# from matplotlib.axes import Axes -# from matplotlib.colorbar import Colorbar -# from matplotlib.image import AxesImage -# -# -# def stx_heatmap( -# ax: Union[Axes, "AxisWrapper"], -# values_2d: np.ndarray, -# x_labels: Optional[List[str]] = None, -# y_labels: Optional[List[str]] = None, -# cmap: str = "viridis", -# cbar_label: str = "ColorBar Label", -# annot_format: str = "{x:.1f}", -# show_annot: bool = True, -# annot_color_lighter: str = "black", -# annot_color_darker: str = "white", -# **kwargs: Any, -# ) -> Tuple[Union[Axes, "AxisWrapper"], AxesImage, Colorbar]: -# """Plot a heatmap on the given axes with automatic annotation colors. -# -# Creates a heatmap visualization with optional cell annotations. Annotation -# text colors are automatically switched based on background brightness for -# optimal readability. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or AxisWrapper -# The axes to plot on. -# values_2d : np.ndarray, shape (n_rows, n_cols) -# 2D array of data to display as heatmap. -# x_labels : list of str, optional -# Labels for the x-axis (columns). -# y_labels : list of str, optional -# Labels for the y-axis (rows). -# cmap : str, default "viridis" -# Colormap name to use. -# cbar_label : str, default "ColorBar Label" -# Label for the colorbar. -# annot_format : str, default "{x:.1f}" -# Format string for cell annotations. -# show_annot : bool, default True -# Whether to annotate the heatmap with values. -# annot_color_lighter : str, default "black" -# Text color for annotations on lighter backgrounds. -# annot_color_darker : str, default "white" -# Text color for annotations on darker backgrounds. -# **kwargs : dict -# Additional keyword arguments passed to imshow(). -# -# Returns -# ------- -# ax : matplotlib.axes.Axes or AxisWrapper -# The axes with the heatmap. -# im : matplotlib.image.AxesImage -# The image object created by imshow. -# cbar : matplotlib.colorbar.Colorbar -# The colorbar object. -# -# Examples -# -------- -# >>> import numpy as np -# >>> import scitex as stx -# >>> data = np.random.rand(5, 10) -# >>> fig, ax = stx.plt.subplots() -# >>> ax, im, cbar = stx.plt.ax.stx_heatmap( -# ... ax, data, -# ... x_labels=[f"X{i}" for i in range(10)], -# ... y_labels=[f"Y{i}" for i in range(5)], -# ... cmap="Blues" -# ... ) -# """ -# -# im, cbar = _mpl_heatmap( -# values_2d, -# x_labels, -# y_labels, -# ax=ax, -# cmap=cmap, -# cbarlabel=cbar_label, -# ) -# -# if show_annot: -# textcolors = _switch_annot_colors(cmap, annot_color_lighter, annot_color_darker) -# texts = _mpl_annotate_heatmap( -# im, -# valfmt=annot_format, -# textcolors=textcolors, -# ) -# -# return ax, im, cbar -# -# -# def _switch_annot_colors( -# cmap: str, -# annot_color_lighter: str, -# annot_color_darker: str, -# ) -> Tuple[str, str]: -# """Determine annotation text colors based on colormap brightness. -# -# Uses perceived brightness (ITU-R BT.709) to select appropriate text -# colors for light vs dark backgrounds in the colormap. -# -# Parameters -# ---------- -# cmap : str -# Colormap name. -# annot_color_lighter : str -# Color to use on lighter backgrounds. -# annot_color_darker : str -# Color to use on darker backgrounds. -# -# Returns -# ------- -# tuple of str -# (color_for_dark_bg, color_for_light_bg) text colors. -# """ -# cmap_obj = plt.cm.get_cmap(cmap) -# -# # Sample colormap at extremes (avoiding edge effects) -# dark_color = cmap_obj(0.1) -# light_color = cmap_obj(0.9) -# -# # Calculate perceived brightness using ITU-R BT.709 coefficients -# dark_brightness = ( -# 0.2126 * dark_color[0] + 0.7152 * dark_color[1] + 0.0722 * dark_color[2] -# ) -# -# # Choose text colors based on background brightness -# if dark_brightness < 0.5: -# return (annot_color_lighter, annot_color_darker) -# else: -# return (annot_color_darker, annot_color_lighter) -# -# -# def _mpl_heatmap( -# data: np.ndarray, -# row_labels: Optional[List[str]], -# col_labels: Optional[List[str]], -# ax: Optional[Axes] = None, -# cbar_kw: Optional[dict] = None, -# cbarlabel: str = "", -# **kwargs: Any, -# ) -> Tuple[AxesImage, Colorbar]: -# """Create a heatmap with imshow and add a colorbar. -# -# Parameters -# ---------- -# data : np.ndarray -# 2D array of data to display. -# row_labels : list of str or None -# Labels for the rows (y-axis). -# col_labels : list of str or None -# Labels for the columns (x-axis). -# ax : matplotlib.axes.Axes, optional -# Axes to plot on. If None, uses current axes. -# cbar_kw : dict, optional -# Keyword arguments for colorbar creation. -# cbarlabel : str, default "" -# Label for the colorbar. -# **kwargs : dict -# Additional keyword arguments passed to imshow(). -# -# Returns -# ------- -# im : matplotlib.image.AxesImage -# The image object. -# cbar : matplotlib.colorbar.Colorbar -# The colorbar object. -# """ -# -# if ax is None: -# ax = plt.gca() -# -# if cbar_kw is None: -# cbar_kw = {} -# -# # Plot the heatmap -# im = ax.imshow(data, **kwargs) -# -# # Create colorbar with proper formatting -# cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) -# cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") -# -# # Set colorbar border width to match axes spines -# cbar.outline.set_linewidth(0.2 * 2.83465) # 0.2mm in points -# -# # Format colorbar ticks -# from matplotlib.ticker import MaxNLocator -# -# cbar.ax.yaxis.set_major_locator(MaxNLocator(nbins=4, min_n_ticks=3)) -# cbar.ax.tick_params(width=0.2 * 2.83465, length=0.8 * 2.83465) # Match tick styling -# -# # Show all ticks and label them with the respective list entries. -# ax.set_xticks( -# range(data.shape[1]), -# labels=col_labels, -# # rotation=45, -# # ha="right", -# # rotation_mode="anchor", -# ) -# ax.set_yticks(range(data.shape[0]), labels=row_labels) -# -# # Let the horizontal axes labeling appear on top. -# ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True) -# -# # Show all 4 spines for heatmap -# ax.spines[:].set_visible(True) -# -# # Set aspect ratio to 'equal' for square cells (1:1) -# ax.set_aspect("equal", adjustable="box") -# -# ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True) -# ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True) -# ax.tick_params(which="minor", bottom=False, left=False) -# -# return im, cbar -# -# -# def _calc_annot_fontsize(n_rows: int, n_cols: int) -> float: -# """Calculate dynamic annotation font size based on cell count. -# -# Uses a base size of 6pt for small heatmaps and scales down for larger ones. -# -# Parameters -# ---------- -# n_rows : int -# Number of rows in the heatmap. -# n_cols : int -# Number of columns in the heatmap. -# -# Returns -# ------- -# float -# Font size in points. -# """ -# # Base font size for small heatmaps (e.g., 5x5) -# BASE_FONTSIZE = 6.0 -# BASE_CELLS = 5 # Reference dimension -# -# # Use the larger dimension to scale -# max_dim = max(n_rows, n_cols) -# -# if max_dim <= BASE_CELLS: -# return BASE_FONTSIZE -# elif max_dim <= 10: -# # Linear interpolation: 6pt at 5 cells, 5pt at 10 cells -# return BASE_FONTSIZE - (max_dim - BASE_CELLS) * 0.2 -# elif max_dim <= 20: -# # 5pt at 10 cells, 4pt at 20 cells -# return 5.0 - (max_dim - 10) * 0.1 -# else: -# # Minimum 3pt for very large heatmaps -# return max(3.0, 4.0 - (max_dim - 20) * 0.05) -# -# -# def _mpl_annotate_heatmap( -# im: AxesImage, -# data: Optional[np.ndarray] = None, -# valfmt: str = "{x:.2f}", -# textcolors: Tuple[str, str] = ("lightgray", "black"), -# threshold: Optional[float] = None, -# fontsize: Optional[float] = None, -# **textkw: Any, -# ) -> List: -# """Annotate a heatmap with cell values. -# -# Parameters -# ---------- -# im : matplotlib.image.AxesImage -# The image to be annotated. -# data : np.ndarray, optional -# Data used to annotate. If None, uses the image's array. -# valfmt : str, default "{x:.2f}" -# Format string for the annotations. -# textcolors : tuple of str, default ("lightgray", "black") -# Colors for annotations. First color for values below threshold, -# second for values above. -# threshold : float, optional -# Value in normalized colormap space (0 to 1) above which the -# second color is used. If None, uses 0.7 * max(data). -# fontsize : float, optional -# Font size in points. If None, dynamically calculated based on -# cell count (6pt base, scaling down for larger heatmaps). -# **textkw : dict -# Additional keyword arguments passed to ax.text(). -# -# Returns -# ------- -# texts : list of matplotlib.text.Text -# The annotation text objects. -# """ -# -# if not isinstance(data, (list, np.ndarray)): -# data = im.get_array() -# -# # Calculate dynamic font size if not specified -# if fontsize is None: -# fontsize = _calc_annot_fontsize(data.shape[0], data.shape[1]) -# -# # Normalize the threshold to the images color range. -# if threshold is not None: -# threshold = im.norm(threshold) -# else: -# # Use 0.7 instead of 0.5 for better visibility with most colormaps -# threshold = im.norm(data.max()) * 0.7 -# -# # Set default alignment to center, but allow it to be -# # overwritten by textkw. -# kw = dict( -# horizontalalignment="center", verticalalignment="center", fontsize=fontsize -# ) -# kw.update(textkw) -# -# # Get the formatter in case a string is supplied -# if isinstance(valfmt, str): -# valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) -# -# # Loop over the data and create a `Text` for each "pixel". -# # Change the text's color depending on the data. -# texts = [] -# for ii in range(data.shape[0]): -# for jj in range(data.shape[1]): -# kw.update(color=textcolors[int(im.norm(data[ii, jj]) > threshold)]) -# text = im.axes.text(jj, ii, valfmt(data[ii, jj], None), **kw) -# texts.append(text) -# -# return texts -# -# -# if __name__ == "__main__": -# import matplotlib -# import matplotlib as mpl -# import matplotlib.pyplot as plt -# import numpy as np -# -# data = np.random.rand(5, 10) -# x_labels = [f"X{ii + 1}" for ii in range(5)] -# y_labels = [f"Y{ii + 1}" for ii in range(10)] -# -# fig, ax = plt.subplots() -# -# im, cbar = stx_heatmap( -# ax, -# data, -# x_labels=x_labels, -# y_labels=y_labels, -# show_annot=True, -# annot_color_lighter="white", -# annot_color_darker="black", -# cmap="Blues", -# ) -# -# fig.tight_layout() -# plt.show() -# # EOF -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_heatmap.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_image.py b/tests/scitex/plt/ax/_plot/test__stx_image.py deleted file mode 100644 index db4427fe5..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_image.py +++ /dev/null @@ -1,112 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_image.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-01 08:39:46 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_image2d.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_image2d.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# from scitex.plt.utils import assert_valid_axis -# -# -# def stx_image( -# ax, -# arr_2d, -# cbar=True, -# cbar_label=None, -# cbar_shrink=1.0, -# cbar_fraction=0.046, -# cbar_pad=0.04, -# cmap="viridis", -# aspect="auto", -# vmin=None, -# vmax=None, -# **kwargs, -# ): -# """ -# Imshows an two-dimensional array with theese two conditions: -# 1) The first dimension represents the x dim, from left to right. -# 2) The second dimension represents the y dim, from bottom to top -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis to plot on -# arr_2d : numpy.ndarray -# The 2D array to display -# cbar : bool, optional -# Whether to show colorbar, by default True -# cbar_label : str, optional -# Label for the colorbar, by default None -# cbar_shrink : float, optional -# Shrink factor for the colorbar, by default 1.0 -# cbar_fraction : float, optional -# Fraction of original axes to use for colorbar, by default 0.046 -# cbar_pad : float, optional -# Padding between the image axes and colorbar axes, by default 0.04 -# cmap : str, optional -# Colormap name, by default "viridis" -# aspect : str, optional -# Aspect ratio adjustment, by default "auto" -# vmin : float, optional -# Minimum data value for colormap scaling, by default None -# vmax : float, optional -# Maximum data value for colormap scaling, by default None -# **kwargs -# Additional keyword arguments passed to ax.imshow() -# -# Returns -# ------- -# matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axis with the image plotted -# """ -# assert_valid_axis( -# ax, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# assert arr_2d.ndim == 2, "Input array must be 2-dimensional" -# -# if kwargs.get("xyz"): -# kwargs.pop("xyz") -# -# # Transposes arr_2d for correct orientation -# arr_2d = arr_2d.T -# -# # Cals the original ax.imshow() method on the transposed array -# im = ax.imshow(arr_2d, cmap=cmap, vmin=vmin, vmax=vmax, aspect=aspect, **kwargs) -# -# # Color bar -# if cbar: -# fig = ax.get_figure() -# _cbar = fig.colorbar( -# im, ax=ax, shrink=cbar_shrink, fraction=cbar_fraction, pad=cbar_pad -# ) -# if cbar_label: -# _cbar.set_label(cbar_label) -# -# # Invert y-axis to match typical image orientation -# ax.invert_yaxis() -# -# return ax -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_image.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_joyplot.py b/tests/scitex/plt/ax/_plot/test__stx_joyplot.py deleted file mode 100644 index 4a72ebe03..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_joyplot.py +++ /dev/null @@ -1,151 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_joyplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-02 09:03:23 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_joyplot.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_joyplot.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# from scipy import stats -# -# from ....plt.utils import assert_valid_axis -# -# -# def stx_joyplot( -# ax, arrays, overlap=0.5, fill_alpha=0.7, line_alpha=1.0, colors=None, **kwargs -# ): -# """ -# Create a joyplot (ridgeline plot) on the provided axes. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The axes to plot on -# arrays : list of array-like -# List of 1D arrays for each ridge -# overlap : float, default 0.5 -# Amount of overlap between ridges (0 = no overlap, 1 = full overlap) -# fill_alpha : float, default 0.7 -# Alpha for the filled KDE area -# line_alpha : float, default 1.0 -# Alpha for the KDE line -# colors : list, optional -# Colors for each ridge. If None, uses scitex palette. -# **kwargs -# Additional keyword arguments -# -# Returns -# ------- -# matplotlib.axes.Axes -# The axes with the joyplot -# """ -# assert_valid_axis( -# ax, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# # Convert dict to list of arrays (values only) -# if isinstance(arrays, dict): -# arrays = list(arrays.values()) -# -# # Add sample size per distribution to label if provided (show range if variable) -# if kwargs.get("label"): -# n_per_dist = [len(arr) for arr in arrays] -# n_min, n_max = min(n_per_dist), max(n_per_dist) -# n_str = str(n_min) if n_min == n_max else f"{n_min}-{n_max}" -# kwargs["label"] = f"{kwargs['label']} ($n$={n_str})" -# -# # Import scitex colors -# from scitex.plt.color._PARAMS import HEX -# -# # Default colors from scitex palette -# if colors is None: -# colors = [ -# HEX["blue"], -# HEX["red"], -# HEX["green"], -# HEX["yellow"], -# HEX["purple"], -# HEX["orange"], -# HEX["lightblue"], -# HEX["pink"], -# ] -# -# n_ridges = len(arrays) -# -# # Calculate global x range -# all_data = np.concatenate([np.asarray(arr) for arr in arrays]) -# x_min, x_max = np.min(all_data), np.max(all_data) -# x_range = x_max - x_min -# x_padding = x_range * 0.1 -# x = np.linspace(x_min - x_padding, x_max + x_padding, 200) -# -# # Calculate KDEs and find max density for scaling -# kdes = [] -# max_density = 0 -# for arr in arrays: -# arr = np.asarray(arr) -# if len(arr) > 1: -# kde = stats.gaussian_kde(arr) -# density = kde(x) -# kdes.append(density) -# max_density = max(max_density, np.max(density)) -# else: -# kdes.append(np.zeros_like(x)) -# -# # Scale factor for ridge height -# ridge_height = 1.0 / (1.0 - overlap * 0.5) if overlap < 1 else 2.0 -# -# # Plot each ridge from back to front -# for i in range(n_ridges - 1, -1, -1): -# color = colors[i % len(colors)] -# baseline = i * (1.0 - overlap) -# -# # Scale density to fit nicely -# scaled_density = ( -# kdes[i] / max_density * ridge_height if max_density > 0 else kdes[i] -# ) -# -# # Fill -# ax.fill_between( -# x, -# baseline, -# baseline + scaled_density, -# facecolor=color, -# edgecolor="none", -# alpha=fill_alpha, -# ) -# # Line on top -# ax.plot( -# x, baseline + scaled_density, color=color, alpha=line_alpha, linewidth=1.0 -# ) -# -# # Set y limits -# ax.set_ylim(-0.1, n_ridges * (1.0 - overlap) + ridge_height) -# -# # Hide y-axis ticks for cleaner look (joyplots typically don't show y values) -# ax.set_yticks([]) -# -# return ax -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_joyplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_raster.py b/tests/scitex/plt/ax/_plot/test__stx_raster.py deleted file mode 100644 index 6b697c53d..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_raster.py +++ /dev/null @@ -1,215 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_raster.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-01 15:23:01 (ywatanabe)" -# # File: /home/ywatanabe/proj/_scitex_repo/src/scitex/plt/ax/_plot/_plot_raster.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_raster.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# from bisect import bisect_left -# -# import matplotlib.pyplot as plt -# import numpy as np -# import pandas as pd -# from ....plt.utils import assert_valid_axis -# -# -# def stx_raster( -# ax, -# spike_times_list, -# time=None, -# labels=None, -# colors=None, -# orientation="horizontal", -# y_offset=None, -# lineoffsets=None, -# linelengths=None, -# apply_set_n_ticks=True, -# n_xticks=4, -# n_yticks=None, -# **kwargs, -# ): -# """ -# Create a raster plot using eventplot with custom labels and colors. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axes on which to draw the raster plot. -# spike_times_list : list of array-like, shape (n_trials,) where each element is (n_spikes,) -# List of spike/event time arrays, one per trial/channel -# time : array-like, optional -# The time indices for the events (default: np.linspace(0, max(event_times))). -# labels : list, optional -# Labels for each channel/trial. -# colors : list, optional -# Colors for each channel/trial. -# orientation: str, optional -# Orientation of raster plot (default: horizontal). -# y_offset : float, optional -# Vertical spacing between trials/channels (default: 1.0). -# lineoffsets : array-like, optional -# Y-positions for each trial/channel (overrides automatic positioning). -# linelengths : float, optional -# Height of each spike mark (default: 0.8, slightly less than y_offset to prevent overlap). -# apply_set_n_ticks : bool, optional -# Whether to apply set_n_ticks for cleaner axis (default: True). -# n_xticks : int, optional -# Number of x-axis ticks (default: 4). -# n_yticks : int or None, optional -# Number of y-axis ticks (default: None, auto-determined). -# **kwargs : dict -# Additional keyword arguments for eventplot. -# -# Returns -# ------- -# ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axes with the raster plot. -# df : pandas.DataFrame -# DataFrame with time indices and channel events. -# """ -# assert_valid_axis( -# ax, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# # Format spike_times_list data -# spike_times_list = _ensure_list(spike_times_list) -# -# # Add sample size (number of trials) to label if provided -# if kwargs.get("label"): -# n_trials = len(spike_times_list) -# kwargs["label"] = f"{kwargs['label']} ($n$={n_trials})" -# -# # Handle colors and labels -# colors = _handle_colors(colors, spike_times_list) -# -# # Handle lineoffsets for positioning between trials/channels -# if y_offset is None: -# y_offset = 1.0 # Default spacing -# if lineoffsets is None: -# lineoffsets = np.arange(len(spike_times_list)) * y_offset -# -# # Set linelengths to prevent overlap (80% of y_offset by default) -# if linelengths is None: -# linelengths = y_offset * 0.8 -# -# # Ensure lineoffsets is iterable and matches spike_times_list length -# if np.isscalar(lineoffsets): -# lineoffsets = [lineoffsets] -# if len(lineoffsets) < len(spike_times_list): -# lineoffsets = list(lineoffsets) + list( -# range(len(lineoffsets), len(spike_times_list)) -# ) -# -# # Plotting as eventplot using spike_times_list with proper positioning -# for ii, (pos, color, offset) in enumerate( -# zip(spike_times_list, colors, lineoffsets) -# ): -# label = _define_label(labels, ii) -# ax.eventplot( -# pos, -# lineoffsets=offset, -# linelengths=linelengths, -# orientation=orientation, -# colors=color, -# label=label, -# **kwargs, -# ) -# -# # Apply set_n_ticks for cleaner axes if requested -# if apply_set_n_ticks: -# from scitex.plt.ax._style._set_n_ticks import set_n_ticks -# -# # For categorical y-axis (trials/channels), use appropriate tick count -# if n_yticks is None: -# n_yticks = min(len(spike_times_list), 8) # Max 8 ticks for readability -# -# # Only apply if we have reasonable numeric ranges -# try: -# x_range = ax.get_xlim() -# y_range = ax.get_ylim() -# -# # Apply x-ticks if we have a reasonable numeric range -# if x_range[1] - x_range[0] > 0: -# set_n_ticks(ax, n_xticks=n_xticks, n_yticks=None) -# -# # Apply y-ticks only if we don't have categorical labels -# if labels is None and y_range[1] - y_range[0] > 0: -# set_n_ticks(ax, n_xticks=None, n_yticks=n_yticks) -# -# except Exception: -# # Skip set_n_ticks if there are issues (e.g., categorical data) -# pass -# -# # Legend -# if labels is not None: -# ax.legend() -# -# # Return spike_times in a useful format -# spike_times_digital_df = _event_times_to_digital_df( -# spike_times_list, time, lineoffsets -# ) -# -# return ax, spike_times_digital_df -# -# -# def _ensure_list(event_times): -# return [[pos] if isinstance(pos, (int, float)) else pos for pos in event_times] -# -# -# def _define_label(labels, ii): -# if (labels is not None) and (ii < len(labels)): -# return labels[ii] -# else: -# return None -# -# -# def _handle_colors(colors, event_times_list): -# if colors is None: -# colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] -# if len(colors) < len(event_times_list): -# colors = colors * (len(event_times_list) // len(colors) + 1) -# return colors -# -# -# def _event_times_to_digital_df(event_times_list, time, lineoffsets=None): -# if time is None: -# time = np.linspace(0, np.max([np.max(pos) for pos in event_times_list]), 1000) -# -# digi = np.full((len(event_times_list), len(time)), np.nan, dtype=float) -# -# for i_ch, posis_ch in enumerate(event_times_list): -# for posi_ch in posis_ch: -# i_insert = bisect_left(time, posi_ch) -# if i_insert == len(time): -# i_insert -= 1 -# # Use lineoffset position if available, otherwise use channel index -# if lineoffsets is not None and i_ch < len(lineoffsets): -# digi[i_ch, i_insert] = lineoffsets[i_ch] -# else: -# digi[i_ch, i_insert] = i_ch -# -# return pd.DataFrame(digi.T, index=time) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_raster.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_rectangle.py b/tests/scitex/plt/ax/_plot/test__stx_rectangle.py deleted file mode 100644 index 9f04e2512..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_rectangle.py +++ /dev/null @@ -1,86 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_rectangle.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-01 08:45:44 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_rectangle.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_rectangle.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# from matplotlib.patches import Rectangle -# -# -# def stx_rectangle(ax, xx, yy, ww, hh, **kwargs): -# """Add a rectangle patch to an axes. -# -# Convenience function for adding rectangular patches to plots, useful for -# highlighting regions, creating box annotations, or drawing geometric shapes. -# By default, rectangles have no edge (border) for cleaner publication figures. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The axes to add the rectangle to. -# xx : float -# X-coordinate of the rectangle's bottom-left corner. -# yy : float -# Y-coordinate of the rectangle's bottom-left corner. -# ww : float -# Width of the rectangle. -# hh : float -# Height of the rectangle. -# **kwargs : dict -# Additional keyword arguments passed to matplotlib.patches.Rectangle. -# Common options include: -# - facecolor/fc : fill color -# - edgecolor/ec : edge color (default: 'none') -# - linewidth/lw : edge line width -# - alpha : transparency (0-1) -# - linestyle/ls : edge line style -# -# Returns -# ------- -# matplotlib.axes.Axes -# The axes with the rectangle added. -# -# Examples -# -------- -# >>> fig, ax = plt.subplots() -# >>> ax.plot([0, 10], [0, 10]) -# >>> # Highlight a region (no border by default) -# >>> stx_rectangle(ax, 2, 3, 4, 3, facecolor='yellow', alpha=0.3) -# -# >>> # Draw a box with explicit edge -# >>> stx_rectangle(ax, 5, 5, 2, 2, facecolor='none', edgecolor='red', linewidth=2) -# -# See Also -# -------- -# matplotlib.patches.Rectangle : The underlying Rectangle class -# matplotlib.axes.Axes.add_patch : Method used to add the patch -# """ -# # Default to no edge for cleaner publication figures -# if "edgecolor" not in kwargs and "ec" not in kwargs: -# kwargs["edgecolor"] = "none" -# ax.add_patch(Rectangle((xx, yy), ww, hh, **kwargs)) -# return ax -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_rectangle.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_scatter_hist.py b/tests/scitex/plt/ax/_plot/test__stx_scatter_hist.py deleted file mode 100644 index 60d99dd0a..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_scatter_hist.py +++ /dev/null @@ -1,149 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_scatter_hist.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-02 18:14:56 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_scatter_hist.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_scatter_hist.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import numpy as np -# -# -# def stx_scatter_hist( -# ax, -# x, -# y, -# fig=None, -# hist_bins: int = 20, -# scatter_alpha: float = 0.6, -# scatter_size: float = 20, -# scatter_color: str = "blue", -# hist_color_x: str = "blue", -# hist_color_y: str = "red", -# hist_alpha: float = 0.5, -# scatter_ratio: float = 0.8, -# **kwargs, -# ): -# """ -# Plot a scatter plot with histograms on the x and y axes. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The main scatter plot axes -# x : array-like -# x data for scatter plot and histogram -# y : array-like -# y data for scatter plot and histogram -# fig : matplotlib.figure.Figure, optional -# Figure to create axes in. If None, uses ax.figure -# hist_bins : int, optional -# Number of bins for histograms, default 20 -# scatter_alpha : float, optional -# Alpha value for scatter points, default 0.6 -# scatter_size : float, optional -# Size of scatter points, default 20 -# scatter_color : str, optional -# Color of scatter points, default "blue" -# hist_color_x : str, optional -# Color of x-axis histogram, default "blue" -# hist_color_y : str, optional -# Color of y-axis histogram, default "red" -# hist_alpha : float, optional -# Alpha value for histograms, default 0.5 -# scatter_ratio : float, optional -# Ratio of main plot to histograms, default 0.8 -# **kwargs -# Additional keyword arguments passed to scatter and hist functions -# -# Returns -# ------- -# tuple -# (ax, ax_histx, ax_histy, hist_data) - All axes objects and histogram data -# hist_data is a dictionary containing histogram counts and bin edges -# """ -# # Get the current figure if not provided -# if fig is None: -# fig = ax.figure -# -# # Calculate the positions based on scatter_ratio -# margin = 0.1 * (1 - scatter_ratio) -# hist_size = 0.2 * scatter_ratio -# -# # Create the histogram axes -# ax_histx = fig.add_axes( -# [ -# ax.get_position().x0, -# ax.get_position().y1 + margin, -# ax.get_position().width * scatter_ratio, -# hist_size, -# ] -# ) -# ax_histy = fig.add_axes( -# [ -# ax.get_position().x1 + margin, -# ax.get_position().y0, -# hist_size, -# ax.get_position().height * scatter_ratio, -# ] -# ) -# -# # No labels for histograms -# ax_histx.tick_params(axis="x", labelbottom=False) -# ax_histy.tick_params(axis="y", labelleft=False) -# -# # The scatter plot -# ax.scatter( -# x, -# y, -# alpha=scatter_alpha, -# s=scatter_size, -# color=scatter_color, -# **kwargs, -# ) -# -# # Calculate histogram data -# hist_x, bin_edges_x = np.histogram(x, bins=hist_bins) -# hist_y, bin_edges_y = np.histogram(y, bins=hist_bins) -# -# # Plot histograms -# ax_histx.hist(x, bins=hist_bins, color=hist_color_x, alpha=hist_alpha) -# ax_histy.hist( -# y, -# bins=hist_bins, -# orientation="horizontal", -# color=hist_color_y, -# alpha=hist_alpha, -# ) -# -# # Create return data structure -# hist_data = { -# "hist_x": hist_x, -# "hist_y": hist_y, -# "bin_edges_x": bin_edges_x, -# "bin_edges_y": bin_edges_y, -# } -# -# return ax, ax_histx, ax_histy, hist_data -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_scatter_hist.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_shaded_line.py b/tests/scitex/plt/ax/_plot/test__stx_shaded_line.py deleted file mode 100644 index dc6ee2e9c..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_shaded_line.py +++ /dev/null @@ -1,235 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_shaded_line.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 13:15:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_plot/_plot_shaded_line.py -# -# """Line plots with shaded uncertainty regions (e.g., confidence intervals).""" -# -# from typing import Any, List, Optional, Tuple, Union -# -# import numpy as np -# import pandas as pd -# from matplotlib.axes import Axes -# -# from scitex.types import ColorLike -# from ....plt.utils import assert_valid_axis -# -# -# def _plot_single_shaded_line( -# axis: Union[Axes, "AxisWrapper"], -# xx: np.ndarray, -# y_lower: np.ndarray, -# y_middle: np.ndarray, -# y_upper: np.ndarray, -# color: Optional[ColorLike] = None, -# alpha: float = 0.3, -# **kwargs: Any, -# ) -> Tuple[Union[Axes, "AxisWrapper"], pd.DataFrame]: -# """Plot a single line with shaded area between y_lower and y_upper bounds. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes or AxisWrapper -# Axes to plot on. -# xx : np.ndarray -# X values. -# y_lower : np.ndarray -# Lower bound y values. -# y_middle : np.ndarray -# Middle (mean/median) y values. -# y_upper : np.ndarray -# Upper bound y values. -# color : ColorLike, optional -# Color for line and fill. -# alpha : float, default 0.3 -# Transparency for shaded region. -# **kwargs : dict -# Additional keyword arguments passed to plot(). -# -# Returns -# ------- -# axis : matplotlib.axes.Axes or AxisWrapper -# The axes with the plot. -# df : pd.DataFrame -# DataFrame with x, y_lower, y_middle, y_upper columns. -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# assert len(xx) == len(y_middle) == len(y_lower) == len(y_upper), ( -# "All arrays must have the same length" -# ) -# -# label = kwargs.pop("label", None) -# axis.plot(xx, y_middle, color=color, alpha=alpha, label=label, **kwargs) -# kwargs["linewidth"] = 0 -# kwargs["edgecolor"] = "none" # Remove edge line -# axis.fill_between(xx, y_lower, y_upper, alpha=alpha, color=color, **kwargs) -# -# return axis, pd.DataFrame( -# {"x": xx, "y_lower": y_lower, "y_middle": y_middle, "y_upper": y_upper} -# ) -# -# -# def _plot_shaded_line( -# axis: Union[Axes, "AxisWrapper"], -# xs: List[np.ndarray], -# ys_lower: List[np.ndarray], -# ys_middle: List[np.ndarray], -# ys_upper: List[np.ndarray], -# color: Optional[Union[List[ColorLike], ColorLike]] = None, -# **kwargs: Any, -# ) -> Tuple[Union[Axes, "AxisWrapper"], List[pd.DataFrame]]: -# """Plot multiple lines with shaded areas between ys_lower and ys_upper bounds. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes or AxisWrapper -# Axes to plot on. -# xs : list of np.ndarray -# List of x value arrays. -# ys_lower : list of np.ndarray -# List of lower bound y value arrays. -# ys_middle : list of np.ndarray -# List of middle y value arrays. -# ys_upper : list of np.ndarray -# List of upper bound y value arrays. -# color : ColorLike or list of ColorLike, optional -# Color(s) for lines and fills. -# **kwargs : dict -# Additional keyword arguments passed to plot(). -# -# Returns -# ------- -# axis : matplotlib.axes.Axes or AxisWrapper -# The axes with the plots. -# results : list of pd.DataFrame -# List of DataFrames with plot data. -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# assert len(xs) == len(ys_lower) == len(ys_middle) == len(ys_upper), ( -# "All input lists must have the same length" -# ) -# -# results = [] -# colors = color -# color_list = colors -# -# if colors is not None: -# if not isinstance(colors, list): -# color_list = [colors] * len(xs) -# else: -# assert len(colors) == len(xs), "Number of colors must match number of lines" -# color_list = colors -# -# for idx, (xx, y_lower, y_middle, y_upper) in enumerate( -# zip(xs, ys_lower, ys_middle, ys_upper) -# ): -# this_kwargs = kwargs.copy() -# this_kwargs["color"] = color_list[idx] -# _, result_df = _plot_single_shaded_line( -# axis, xx, y_lower, y_middle, y_upper, **this_kwargs -# ) -# results.append(result_df) -# else: -# for xx, y_lower, y_middle, y_upper in zip(xs, ys_lower, ys_middle, ys_upper): -# _, result_df = _plot_single_shaded_line( -# axis, xx, y_lower, y_middle, y_upper, **kwargs -# ) -# results.append(result_df) -# -# return axis, results -# -# -# def stx_shaded_line( -# axis: Union[Axes, "AxisWrapper"], -# xs: Union[np.ndarray, List[np.ndarray]], -# ys_lower: Union[np.ndarray, List[np.ndarray]], -# ys_middle: Union[np.ndarray, List[np.ndarray]], -# ys_upper: Union[np.ndarray, List[np.ndarray]], -# color: Optional[Union[ColorLike, List[ColorLike]]] = None, -# **kwargs: Any, -# ) -> Tuple[Union[Axes, "AxisWrapper"], Union[pd.DataFrame, List[pd.DataFrame]]]: -# """Plot line(s) with shaded uncertainty regions. -# -# Automatically handles both single and multiple line cases. Useful for -# plotting mean/median with confidence intervals or standard deviation bands. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes or AxisWrapper -# Axes to plot on. -# xs : np.ndarray or list of np.ndarray -# X values (single array or list of arrays for multiple lines). -# ys_lower : np.ndarray or list of np.ndarray -# Lower bound y values. -# ys_middle : np.ndarray or list of np.ndarray -# Middle (mean/median) y values. -# ys_upper : np.ndarray or list of np.ndarray -# Upper bound y values. -# color : ColorLike or list of ColorLike, optional -# Color(s) for lines and shaded regions. -# **kwargs : dict -# Additional keyword arguments passed to plot(). -# -# Returns -# ------- -# axis : matplotlib.axes.Axes or AxisWrapper -# The axes with the plot(s). -# data : pd.DataFrame or list of pd.DataFrame -# DataFrame(s) containing plot data with columns: -# x, y_lower, y_middle, y_upper. -# -# Examples -# -------- -# >>> import numpy as np -# >>> import scitex as stx -# >>> x = np.linspace(0, 10, 100) -# >>> y_mean = np.sin(x) -# >>> y_std = 0.2 -# >>> fig, ax = stx.plt.subplots() -# >>> ax, df = stx.plt.ax.stx_shaded_line( -# ... ax, x, y_mean - y_std, y_mean, y_mean + y_std, -# ... color='blue', alpha=0.3 -# ... ) -# """ -# is_single = not ( -# isinstance(xs, list) -# and isinstance(ys_lower, list) -# and isinstance(ys_middle, list) -# and isinstance(ys_upper, list) -# ) -# -# if is_single: -# assert len(xs) == len(ys_lower) == len(ys_middle) == len(ys_upper), ( -# "All arrays must have the same length for single line plot" -# ) -# -# return _plot_single_shaded_line( -# axis, xs, ys_lower, ys_middle, ys_upper, color=color, **kwargs -# ) -# else: -# return _plot_shaded_line( -# axis, xs, ys_lower, ys_middle, ys_upper, color=color, **kwargs -# ) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_shaded_line.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_plot/test__stx_violin.py b/tests/scitex/plt/ax/_plot/test__stx_violin.py deleted file mode 100644 index eede84d39..000000000 --- a/tests/scitex/plt/ax/_plot/test__stx_violin.py +++ /dev/null @@ -1,368 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_violin.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-02 22:01:54 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_plot/_plot_violin.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_plot/_plot_violin.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# import matplotlib.pyplot as plt -# import numpy as np -# import pandas as pd -# import seaborn as sns -# from ....plt.utils import assert_valid_axis -# -# -# def stx_violin( -# ax, -# values_list, -# labels=None, -# colors=None, -# half=False, -# **kwargs, -# ): -# """ -# Plot a violin plot using seaborn. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axes to plot on -# values_list : list of array-like, shape (n_groups,) where each element is (n_samples,) -# List of 1D arrays to plot as violins, one per group -# labels : list, optional -# Labels for each array in values_list -# colors : list, optional -# Colors for each violin -# half : bool, optional -# If True, plots only the left half of the violins, default False -# **kwargs -# Additional keyword arguments passed to seaborn.violinplot -# -# Returns -# ------- -# ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axes object with the plot -# """ -# # Add sample size to label if provided (show range if variable) -# if kwargs.get("label"): -# n_per_group = [len(g) for g in values_list] -# n_min, n_max = min(n_per_group), max(n_per_group) -# n_str = str(n_min) if n_min == n_max else f"{n_min}-{n_max}" -# kwargs["label"] = f"{kwargs['label']} ($n$={n_str})" -# -# # Convert list-style data to DataFrame -# all_values = [] -# all_groups = [] -# -# for idx, values in enumerate(values_list): -# all_values.extend(values) -# group_label = labels[idx] if labels and idx < len(labels) else f"x {idx}" -# all_groups.extend([group_label] * len(values)) -# -# # Create DataFrame -# df = pd.DataFrame({"x": all_groups, "y": all_values}) -# -# # Setup colors if provided -# if colors: -# if isinstance(colors, list): -# kwargs["palette"] = { -# group: color -# for group, color in zip(set(all_groups), colors[: len(set(all_groups))]) -# } -# else: -# kwargs["palette"] = colors -# -# # Call seaborn-based function -# return sns_plot_violin(ax, data=df, x="x", y="y", hue="x", half=half, **kwargs) -# -# -# def sns_plot_violin(ax, data=None, x=None, y=None, hue=None, half=False, **kwargs): -# """ -# Plot a violin plot with option for half violins. -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axes to plot on -# data : DataFrame -# The dataframe containing the data -# x : str -# Column name for x-axis variable -# y : str -# Column name for y-axis variable -# hue : str, optional -# Column name for hue variable -# half : bool, optional -# If True, plots only the left half of the violins, default False -# **kwargs -# Additional keyword arguments passed to seaborn.violinplot -# Returns -# ------- -# ax : matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper -# The axes object with the plot -# """ -# assert_valid_axis( -# ax, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# if not half: -# # Standard violin plot -# return sns.violinplot(data=data, x=x, y=y, hue=hue, ax=ax, **kwargs) -# -# # Create a copy of the dataframe to avoid modifying the original -# df = data.copy() -# -# # If no hue provided, create default hue -# if hue is None: -# df["_hue"] = "default" -# hue = "_hue" -# -# # Add fake hue for the right side -# df["_fake_hue"] = df[hue] + "_right" -# -# # Adjust hue_order and palette if provided -# if "hue_order" in kwargs: -# kwargs["hue_order"] = kwargs["hue_order"] + [ -# h + "_right" for h in kwargs["hue_order"] -# ] -# else: -# kwargs["hue_order"] = [] -# for group in df[x].unique().tolist(): -# kwargs["hue_order"].append(group) -# kwargs["hue_order"].append(group + "_right") -# -# if "palette" in kwargs: -# palette = kwargs["palette"] -# if isinstance(palette, dict): -# kwargs["palette"] = { -# **palette, -# **{k + "_right": v for k, v in palette.items()}, -# } -# elif isinstance(palette, list): -# kwargs["palette"] = palette + palette -# -# # Conc left and right -# df_left = df[[x, y]] -# df_right = df[["_fake_hue", y]].rename(columns={"_fake_hue": x}) -# df_right[y] = [np.nan for _ in range(len(df_right))] -# df_conc = pd.concat([df_left, df_right], axis=0, ignore_index=True) -# df_conc = df_conc.sort_values(x) -# -# # Plot -# sns.violinplot(data=df_conc, x=x, y=y, hue="x", split=True, ax=ax, **kwargs) -# -# # Remove right half of violins -# for collection in ax.collections: -# if isinstance(collection, plt.matplotlib.collections.PolyCollection): -# collection.set_clip_path(None) -# -# # Adjust legend -# if ax.legend_ is not None: -# handles, labels = ax.get_legend_handles_labels() -# ax.legend(handles[: len(handles) // 2], labels[: len(labels) // 2]) -# -# return ax -# -# -# # def _plot_half_violin(ax, data=None, x=None, y=None, hue=None, **kwargs): -# -# # assert isinstance( -# # ax, matplotlib.axes._axes.Axes -# # ), "First argument must be a matplotlib axis" -# -# # # Prepare data -# # df = data.copy() -# # if hue is None: -# # df["_hue"] = "default" -# # hue = "_hue" -# -# # # Add fake hue for the right side -# # df["_fake_hue"] = df[hue] + "_right" -# -# # # Adjust hue_order and palette if provided -# # if "hue_order" in kwargs: -# # kwargs["hue_order"] = kwargs["hue_order"] + [ -# # h + "_right" for h in kwargs["hue_order"] -# # ] -# -# # if "palette" in kwargs: -# # palette = kwargs["palette"] -# # if isinstance(palette, dict): -# # kwargs["palette"] = { -# # **palette, -# # **{k + "_right": v for k, v in palette.items()}, -# # } -# # elif isinstance(palette, list): -# # kwargs["palette"] = palette + palette -# -# # # Plot -# # sns.violinplot( -# # data=df, x=x, y=y, hue="_fake_hue", split=True, ax=ax, **kwargs -# # ) -# -# # # Remove right half of violins -# # for collection in ax.collections: -# # if isinstance(collection, plt.matplotlib.collections.PolyCollection): -# # collection.set_clip_path(None) -# -# # # Adjust legend -# # if ax.legend_ is not None: -# # handles, labels = ax.get_legend_handles_labels() -# # ax.legend(handles[: len(handles) // 2], labels[: len(labels) // 2]) -# -# # return ax -# -# # import matplotlib -# # import matplotlib.pyplot as plt -# # import seaborn as sns -# -# # def plot_violin_half(ax, data=None, x=None, y=None, hue=None, **kwargs): -# # """ -# # Plot a half violin plot (showing only the left side of violins). -# -# # Parameters -# # ---------- -# # ax : matplotlib.axes.Axes -# # The axes to plot on -# # data : DataFrame -# # The dataframe containing the data -# # x : str -# # Column name for x-axis variable -# # y : str -# # Column name for y-axis variable -# # hue : str, optional -# # Column name for hue variable -# # **kwargs -# # Additional keyword arguments passed to seaborn.violinplot -# -# # Returns -# # ------- -# # ax : matplotlib.axes.Axes -# # The axes object with the plot -# # """ -# # assert isinstance( -# # ax, matplotlib.axes._axes.Axes -# # ), "First argument must be a matplotlib axis" -# -# # # Prepare data -# # df = data.copy() -# # if hue is None: -# # df["_hue"] = "default" -# # hue = "_hue" -# -# # # Add fake hue for the right side -# # df["_fake_hue"] = df[hue] + "_right" -# -# # # Adjust hue_order and palette if provided -# # if "hue_order" in kwargs: -# # kwargs["hue_order"] = kwargs["hue_order"] + [ -# # h + "_right" for h in kwargs["hue_order"] -# # ] -# # if "palette" in kwargs: -# # palette = kwargs["palette"] -# # if isinstance(palette, dict): -# # kwargs["palette"] = { -# # **palette, -# # **{k + "_right": v for k, v in palette.items()}, -# # } -# # elif isinstance(palette, list): -# # kwargs["palette"] = palette + palette -# -# # # Plot -# # sns.violinplot( -# # data=df, x=x, y=y, hue="_fake_hue", split=True, ax=ax, **kwargs -# # ) -# -# # # Remove right half of violins -# # for collection in ax.collections: -# # if isinstance(collection, matplotlib.collections.PolyCollection): -# # collection.set_clip_path(None) -# -# # # Adjust legend -# # if ax.legend_ is not None: -# # handles, labels = ax.get_legend_handles_labels() -# # ax.legend(handles[: len(handles) // 2], labels[: len(labels) // 2]) -# -# # return ax -# -# -# ## Probably working -# def half_violin(ax, data=None, x=None, y=None, hue=None, **kwargs): -# # Prepare data -# df = data.copy() -# if hue is None: -# df["_hue"] = "default" -# hue = "_hue" -# -# # Add fake hue for the right side -# df["_fake_hue"] = df[hue] + "_right" -# -# # Adjust hue_order and palette if provided -# if "hue_order" in kwargs: -# kwargs["hue_order"] = kwargs["hue_order"] + [ -# h + "_right" for h in kwargs["hue_order"] -# ] -# -# if "palette" in kwargs: -# palette = kwargs["palette"] -# if isinstance(palette, dict): -# kwargs["palette"] = { -# **palette, -# **{k + "_right": v for k, v in palette.items()}, -# } -# elif isinstance(palette, list): -# kwargs["palette"] = palette + palette -# -# # Plot -# sns.violinplot(data=df, x=x, y=y, hue="_fake_hue", split=True, ax=ax, **kwargs) -# -# # Remove right half of violins -# for collection in ax.collections: -# if isinstance(collection, plt.matplotlib.collections.PolyCollection): -# collection.set_clip_path(None) -# -# # Adjust legend -# if ax.legend_ is not None: -# handles, labels = ax.get_legend_handles_labels() -# ax.legend(handles[: len(handles) // 2], labels[: len(labels) // 2]) -# -# return ax -# -# -# # import scitex -# # import numpy as np -# # fig, ax = scitex.plt.subplots() -# # # Test with list data -# # data_list = [ -# # np.random.normal(0, 1, 100), -# # np.random.normal(2, 1.5, 100), -# # np.random.normal(5, 0.8, 100), -# # ] -# # labels = ["x A", "x B", "x C"] -# # colors = ["red", "blue", "green"] -# # half = True -# # ax = half_violin( -# # ax, data_list, x="" -# # ) -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_plot/_stx_violin.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__add_marginal_ax.py b/tests/scitex/plt/ax/_style/test__add_marginal_ax.py deleted file mode 100644 index a38b6a104..000000000 --- a/tests/scitex/plt/ax/_style/test__add_marginal_ax.py +++ /dev/null @@ -1,225 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:37 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__add_marginal_ax.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__add_marginal_ax.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import pytest - -pytest.importorskip("zarr") - -matplotlib.use("Agg") - -from scitex.plt.ax._style import add_marginal_ax - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure(figsize=(8, 6)) - self.ax = self.fig.add_subplot(111) - - # Draw something on the axis for reference - self.ax.plot([0, 1], [0, 1]) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_basic_functionality(self): - # Test adding marginal axes in each position - positions = ["top", "bottom", "left", "right"] - - for position in positions: - ax_marg = add_marginal_ax(self.ax, position) - - # Check that the marginal axis was created - assert ax_marg is not None - assert isinstance(ax_marg, matplotlib.axes.Axes) - - # Check that we have multiple axes in the figure - assert len(self.fig.axes) > 1 - - # Reset for next test - self.fig.clear() - self.ax = self.fig.add_subplot(111) - self.ax.plot([0, 1], [0, 1]) - - def test_size_parameter(self): - # Test with custom size parameter - custom_size = 0.4 - ax_marg = add_marginal_ax(self.ax, "top", size=custom_size) - - # Get main axis and marginal axis positions - main_bbox = self.ax.get_position() - marg_bbox = ax_marg.get_position() - - # Calculate the relative height of marginal axis vs main axis - main_height = main_bbox.height - marg_height = marg_bbox.height - - # Ratio should be approximately equal to the size parameter - # (allowing for some rounding/precision differences) - assert marg_height > 0 - # assert np.isclose(marg_height / main_height, custom_size, rtol=0.1) - - # def test_pad_parameter(self): - # # Test with custom pad parameter - # custom_pad = 0.2 - # ax_marg = add_marginal_ax(self.ax, "top", pad=custom_pad) - - # # Get main axis and marginal axis positions - # main_bbox = self.ax.get_position() - # marg_bbox = ax_marg.get_position() - - # # Calculate the gap between the axes - # main_top = main_bbox.y1 - # marg_bottom = marg_bbox.y0 - - # # The pad is in units of inches, so we need to convert to figure coords - # fig_height_in = self.fig.get_figheight() - # pad_in_fig_coords = custom_pad / fig_height_in - - # # Allow reasonable tolerance since the padding calculation involves several conversions - # assert (marg_bottom - main_top) > 0 # Gap exists - - # def test_aspect_ratio(self): - # # Test that box_aspect is set correctly - - # # For 'top' and 'bottom', box_aspect should be equal to size - # size = 0.3 - # ax_marg_top = add_marginal_ax(self.ax, "top", size=size) - - # # Check if box_aspect matches size - # # Since set_box_aspect doesn't have a direct getter, we'll check indirectly - # # by drawing the figure and checking the resulting aspect ratio - # self.fig.canvas.draw() - - # # For 'left' and 'right', box_aspect should be 1/size - # ax_marg_right = add_marginal_ax(self.ax, "right", size=size) - # self.fig.canvas.draw() - - # # The box_aspect is correctly set in the function, but checking it precisely - # # requires checking the private attribute or rendering metrics which is complex - # # So we'll just check that the axes were created with different shapes - # main_height = self.ax.get_window_extent().height - # right_height = ax_marg_right.get_window_extent().height - # assert np.isclose( - # main_height, right_height, rtol=0.1 - # ) # Heights should be similar - - # main_width = self.ax.get_window_extent().width - # right_width = ax_marg_right.get_window_extent().width - # assert right_width < main_width # Right axis should be narrower - - def test_multiple_marginal_axes(self): - # Test adding multiple marginal axes - ax_top = add_marginal_ax(self.ax, "top") - ax_right = add_marginal_ax(self.ax, "right") - - # Check that three axes exist (main + 2 marginal) - assert len(self.fig.axes) == 3 - - # Draw something in each marginal axis to verify they work - ax_top.plot([0, 1], [0, 1], "r") - ax_right.plot([0, 1], [0, 1], "g") - - # Verify the axes are different - assert ax_top != ax_right - assert self.ax != ax_top - assert self.ax != ax_right - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - self.ax.plot([0, 1], [0, 1], "b-") - - # Add marginal axes - ax_top = add_marginal_ax(self.ax, "top") - ax_top.hist([0.1, 0.2, 0.3, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9], bins=5) - - ax_right = add_marginal_ax(self.ax, "right") - ax_right.hist( - [0.1, 0.2, 0.3, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9], - bins=5, - orientation="horizontal", - ) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_add_marginal_ax.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-04-30 20:18:52 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_add_marginal_ax.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_add_marginal_ax.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# from mpl_toolkits.axes_grid1 import make_axes_locatable -# from ....plt.utils import assert_valid_axis -# -# -# def add_marginal_ax(axis, place, size=0.2, pad=0.1): -# """ -# Add a marginal axis to the specified side of an existing axis. -# -# Arguments: -# axis (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The axis to which a marginal axis will be added. -# place (str): Where to place the marginal axis ('top', 'right', 'bottom', or 'left'). -# size (float, optional): Fractional size of the marginal axis relative to the main axis. Defaults to 0.2. -# pad (float, optional): Padding between the axes. Defaults to 0.1. -# -# Returns: -# matplotlib.axes.Axes: The newly created marginal axis. -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# divider = make_axes_locatable(axis) -# -# size_perc_str = f"{size * 100}%" -# if place in ["left", "right"]: -# size = 1.0 / size -# -# axis_marginal = divider.append_axes(place, size=size_perc_str, pad=pad) -# axis_marginal.set_box_aspect(size) -# -# return axis_marginal -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_add_marginal_ax.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__add_panel.py b/tests/scitex/plt/ax/_style/test__add_panel.py deleted file mode 100644 index 7cba9f9e4..000000000 --- a/tests/scitex/plt/ax/_style/test__add_panel.py +++ /dev/null @@ -1,246 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:26 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__add_panel.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__add_panel.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import add_panel - -matplotlib.use("Agg") - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - pass - - def teardown_method(self): - # Clean up after tests - plt.close("all") - - def test_basic_functionality(self): - # Test with default height parameter (uses H_TO_W_RATIO) - fig, ax = add_panel(tgt_width_mm=40) - - # Calculate expected dimensions - H_TO_W_RATIO = 0.7 - MM_TO_INCH_FACTOR = 1 / 25.4 - expected_width_in = 40 * MM_TO_INCH_FACTOR - expected_height_in = expected_width_in * H_TO_W_RATIO - - # Get actual dimensions - bbox = ax.get_position() - fig_width_in, fig_height_in = fig.get_size_inches() - actual_width_in = bbox.width * fig_width_in - actual_height_in = bbox.height * fig_height_in - - # Check dimensions are correct (with small tolerance) - assert np.isclose(actual_width_in, expected_width_in, rtol=1e-4) - assert np.isclose(actual_height_in, expected_height_in, rtol=1e-4) - - # Check that the axes is properly centered - center_x = bbox.x0 + bbox.width / 2 - center_y = bbox.y0 + bbox.height / 2 - assert np.isclose(center_x, 0.5, rtol=1e-4) - assert np.isclose(center_y, 0.5, rtol=1e-4) - - # Clean up - plt.close(fig) - - def test_custom_dimensions(self): - # Test with custom width and height - tgt_width_mm = 50 - tgt_height_mm = 30 - fig, ax = add_panel(tgt_width_mm=tgt_width_mm, tgt_height_mm=tgt_height_mm) - - # Calculate expected dimensions - MM_TO_INCH_FACTOR = 1 / 25.4 - expected_width_in = tgt_width_mm * MM_TO_INCH_FACTOR - expected_height_in = tgt_height_mm * MM_TO_INCH_FACTOR - - # Get actual dimensions - bbox = ax.get_position() - fig_width_in, fig_height_in = fig.get_size_inches() - actual_width_in = bbox.width * fig_width_in - actual_height_in = bbox.height * fig_height_in - - # Check dimensions are correct (with small tolerance) - assert np.isclose(actual_width_in, expected_width_in, rtol=1e-4) - assert np.isclose(actual_height_in, expected_height_in, rtol=1e-4) - - # Clean up - plt.close(fig) - - # def test_aspect_ratio(self): - # # Test different aspect ratios - # for width_mm, height_mm, expected_ratio in [ - # (40, 20, 0.5), - # (30, 30, 1.0), - # (20, 40, 2.0), - # ]: - # fig, ax = add_panel(tgt_width_mm=width_mm, tgt_height_mm=height_mm) - - # # Calculate actual aspect ratio - # bbox = ax.get_position() - # actual_ratio = (bbox.height / bbox.width) * ( - # fig.get_figwidth() / fig.get_figheight() - # ) - - # # Check aspect ratio is correct - # assert np.isclose(actual_ratio, expected_ratio, rtol=1e-2) - - # # Clean up - # plt.close(fig) - - def test_plotting_compatibility(self): - # Test that the returned axis can be used for plotting - fig, ax = add_panel(tgt_width_mm=40) - - # Try different plotting methods - ax.plot([1, 2, 3], [4, 5, 6]) - ax.scatter([1, 2, 3], [4, 5, 6]) - ax.bar([1, 2, 3], [4, 5, 6]) - - # Check that plots were added to the axis - assert len(ax.lines) > 0 - assert len(ax.collections) > 0 - assert len(ax.patches) > 0 - - # Clean up - plt.close(fig) - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - fig, ax = add_panel(tgt_width_mm=40) - ax.plot([1, 2, 3], [1, 2, 3]) - ax.set_title("Panel Test") - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_add_panel.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-04-30 21:24:49 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_panel.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_panel.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# # Time-stamp: "2024-02-03 15:34:08 (ywatanabe)" -# -# import matplotlib.pyplot as plt -# from scitex.decorators import deprecated -# -# -# def add_panel(tgt_width_mm=40, tgt_height_mm=None): -# """Creates a fixed-size ax figure for panels.""" -# -# H_TO_W_RATIO = 0.7 -# MM_TO_INCH_FACTOR = 1 / 25.4 -# -# if tgt_height_mm is None: -# tgt_height_mm = H_TO_W_RATIO * tgt_width_mm -# -# # Convert target dimensions from millimeters to inches -# tgt_width_in = tgt_width_mm * MM_TO_INCH_FACTOR -# tgt_height_in = tgt_height_mm * MM_TO_INCH_FACTOR -# -# # Create a figure with the specified dimensions -# fig = plt.figure(figsize=(tgt_width_in * 2, tgt_height_in * 2)) -# -# # Calculate the position and size of the axes in figure units (0 to 1) -# left = (fig.get_figwidth() - tgt_width_in) / 2 / fig.get_figwidth() -# bottom = (fig.get_figheight() - tgt_height_in) / 2 / fig.get_figheight() -# ax = fig.add_axes( -# [ -# left, -# bottom, -# tgt_width_in / fig.get_figwidth(), -# tgt_height_in / fig.get_figheight(), -# ] -# ) -# -# return fig, ax -# -# -# @deprecated("Use add_panel instead") -# def panel(tgt_width_mm=40, tgt_height_mm=None): -# """Create a figure panel with specified dimensions (deprecated). -# -# This function is deprecated and maintained only for backward compatibility. -# Please use `add_panel` instead. -# -# Parameters -# ---------- -# tgt_width_mm : float, optional -# Target width in millimeters. Default is 40. -# tgt_height_mm : float or None, optional -# Target height in millimeters. If None, uses golden ratio. -# Default is None. -# -# Returns -# ------- -# tuple -# (fig, ax) - matplotlib figure and axes objects -# -# See Also -# -------- -# add_panel : The recommended function to use instead -# -# Examples -# -------- -# >>> # Deprecated usage -# >>> fig, ax = panel(tgt_width_mm=40, tgt_height_mm=30) -# -# >>> # Recommended alternative -# >>> fig, ax = add_panel(tgt_width_mm=40, tgt_height_mm=30) -# """ -# return add_panel(tgt_width_mm=40, tgt_height_mm=None) -# -# -# if __name__ == "__main__": -# # Example usage: -# fig, ax = panel(tgt_width_mm=40, tgt_height_mm=40 * 0.7) -# ax.plot([1, 2, 3], [4, 5, 6]) -# ax.scatter([1, 2, 3], [4, 5, 6]) -# # ... compatible with other ax plotting methods as well -# plt.show() -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_add_panel.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__auto_scale_axis.py b/tests/scitex/plt/ax/_style/test__auto_scale_axis.py deleted file mode 100644 index 9948089e3..000000000 --- a/tests/scitex/plt/ax/_style/test__auto_scale_axis.py +++ /dev/null @@ -1,215 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_auto_scale_axis.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2025-11-19 18:45:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_auto_scale_axis.py -# -# """ -# Automatic axis scaling to factor out common powers of 10. -# -# This utility automatically detects when axis tick values are very small or very large -# and factors out the appropriate power of 10, updating both the tick labels and axis label. -# -# Examples: -# 0.0000, 0.0008, 0.0016, 0.0024 → 0, 0.8, 1.6, 2.4 with label "[×10⁻³]" -# 10000, 20000, 30000, 40000 → 10, 20, 30, 40 with label "[×10³]" -# """ -# -# import numpy as np -# from typing import Optional, Tuple -# -# -# def detect_scale_factor( -# values: np.ndarray, threshold: float = 1e-2 -# ) -> Tuple[int, bool]: -# """ -# Detect appropriate power of 10 to factor out from axis values. -# -# Parameters -# ---------- -# values : np.ndarray -# Array of tick values on the axis -# threshold : float -# Threshold below which we consider factoring out (default: 0.01) -# -# Returns -# ------- -# power : int -# Power of 10 to factor out (e.g., -3 for values like 0.001-0.009) -# should_scale : bool -# Whether scaling should be applied -# -# Examples -# -------- -# >>> detect_scale_factor(np.array([0.0, 0.0008, 0.0016, 0.0024])) -# (-3, True) -# >>> detect_scale_factor(np.array([10000, 20000, 30000])) -# (3, True) -# >>> detect_scale_factor(np.array([0, 1, 2, 3])) -# (0, False) -# """ -# # Filter out zero values for calculation -# nonzero_values = values[values != 0] -# -# if len(nonzero_values) == 0: -# return 0, False -# -# # Get the order of magnitude of the maximum absolute value -# max_abs = np.max(np.abs(nonzero_values)) -# -# # Check if values are very small (< threshold) or very large (> 1/threshold) -# if max_abs < threshold: -# # Values are very small - factor out negative power -# power = int(np.floor(np.log10(max_abs))) -# return power, True -# elif max_abs > 1.0 / threshold: -# # Values are very large - factor out positive power -# power = int(np.floor(np.log10(max_abs))) -# # Only scale if power >= 3 (thousands or larger) -# if power >= 3: -# return power, True -# -# return 0, False -# -# -# def format_scale_factor(power: int) -> str: -# """ -# Format the scale factor for display in axis label. -# -# Parameters -# ---------- -# power : int -# Power of 10 (e.g., -3, 3, 6) -# -# Returns -# ------- -# str -# Formatted string using matplotlib mathtext (e.g., "×10$^{-3}$", "×10$^{6}$") -# -# Examples -# -------- -# >>> format_scale_factor(-3) -# '×10$^{-3}$' -# >>> format_scale_factor(6) -# '×10$^{6}$' -# """ -# if power == 0: -# return "" -# -# # Use matplotlib's mathtext for reliable rendering across all formats -# return f"×10$^{{{power}}}$" -# -# -# def auto_scale_axis(ax, axis: str = "both", threshold: float = 1e-2) -> None: -# """ -# Automatically scale axis to factor out common powers of 10. -# -# This function: -# 1. Detects when tick values are very small or very large -# 2. Factors out the appropriate power of 10 -# 3. Updates tick labels to show factored values -# 4. Appends the scale factor to the axis label -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# Axes object to apply scaling to -# axis : str, optional -# Which axis to scale: 'x', 'y', or 'both' (default: 'both') -# threshold : float, optional -# Threshold for triggering scaling (default: 1e-2) -# Values with max < threshold or max > 1/threshold will be scaled -# -# Examples -# -------- -# >>> import matplotlib.pyplot as plt -# >>> fig, ax = plt.subplots() -# >>> ax.plot([0, 1, 2], [0.0001, 0.0002, 0.0003]) -# >>> ax.set_ylabel('Density') -# >>> auto_scale_axis(ax, axis='y') -# >>> # Y-axis now shows: 0.1, 0.2, 0.3 with label "Density [×10⁻³]" -# -# Notes -# ----- -# - Only scales if the range of values justifies it (very small or very large) -# - Preserves the original axis label and appends the scale factor -# - Uses Unicode superscripts for clean display (×10⁻³, ×10⁶, etc.) -# """ -# import matplotlib.ticker as ticker -# -# def scale_axis_impl(ax_obj, is_x_axis: bool): -# """Internal implementation for scaling a single axis.""" -# # Get current tick values -# if is_x_axis: -# tick_values = np.array(ax_obj.get_xticks()) -# get_label = ax_obj.get_xlabel -# set_label = ax_obj.set_xlabel -# set_formatter = ax_obj.xaxis.set_major_formatter -# else: -# tick_values = np.array(ax_obj.get_yticks()) -# get_label = ax_obj.get_ylabel -# set_label = ax_obj.set_ylabel -# set_formatter = ax_obj.yaxis.set_major_formatter -# -# # Detect if scaling is needed -# power, should_scale = detect_scale_factor(tick_values, threshold) -# -# if not should_scale: -# return -# -# # Create scaling factor -# scale_factor = 10**power -# -# # Update tick formatter to show scaled values -# def format_func(value, pos): -# scaled_value = value / scale_factor -# # Format with appropriate precision -# if abs(scaled_value) < 10: -# return f"{scaled_value:.1f}" -# else: -# return f"{scaled_value:.0f}" -# -# set_formatter(ticker.FuncFormatter(format_func)) -# -# # Update axis label with scale factor -# current_label = get_label() -# scale_str = format_scale_factor(power) -# -# # Check if label already has units in brackets -# if "[" in current_label and "]" in current_label: -# # Insert scale factor before the closing bracket -# # e.g., "Density [a.u.]" → "Density [×10⁻³ a.u.]" -# label_parts = current_label.rsplit("]", 1) -# new_label = f"{label_parts[0]} {scale_str}]{label_parts[1]}" -# else: -# # Append scale factor in brackets -# # e.g., "Density" → "Density [×10⁻³]" -# new_label = ( -# f"{current_label} [{scale_str}]" if current_label else f"[{scale_str}]" -# ) -# -# set_label(new_label) -# -# # Apply to requested axes -# if axis in ["x", "both"]: -# scale_axis_impl(ax, is_x_axis=True) -# if axis in ["y", "both"]: -# scale_axis_impl(ax, is_x_axis=False) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_auto_scale_axis.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__extend.py b/tests/scitex/plt/ax/_style/test__extend.py deleted file mode 100644 index 617e9a418..000000000 --- a/tests/scitex/plt/ax/_style/test__extend.py +++ /dev/null @@ -1,195 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:21 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__extend.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__extend.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import extend - -matplotlib.use("Agg") - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure(figsize=(6, 4)) - self.ax = self.fig.add_subplot(111) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_basic_functionality(self): - # Get original position - original_bbox = self.ax.get_position() - original_width = original_bbox.width - original_height = original_bbox.height - original_center_x = original_bbox.x0 + original_width / 2 - original_center_y = original_bbox.y0 + original_height / 2 - - # Extend width by 50% and keep height same - extended_ax = extend(self.ax, x_ratio=1.5, y_ratio=1.0) - new_bbox = extended_ax.get_position() - - # Check that center point remains the same - new_center_x = new_bbox.x0 + new_bbox.width / 2 - new_center_y = new_bbox.y0 + new_bbox.height / 2 - assert np.isclose(new_center_x, original_center_x) - assert np.isclose(new_center_y, original_center_y) - - # Check that width and height were scaled correctly - assert np.isclose(new_bbox.width, original_width * 1.5) - assert np.isclose(new_bbox.height, original_height * 1.0) - - def test_shrink(self): - # Test shrinking the axes - original_bbox = self.ax.get_position() - original_width = original_bbox.width - original_height = original_bbox.height - - # Shrink width and height by 50% - extended_ax = extend(self.ax, x_ratio=0.5, y_ratio=0.5) - new_bbox = extended_ax.get_position() - - # Check that width and height were scaled correctly - assert np.isclose(new_bbox.width, original_width * 0.5) - assert np.isclose(new_bbox.height, original_height * 0.5) - - def test_asymmetric_scaling(self): - # Test different scaling for width and height - original_bbox = self.ax.get_position() - original_width = original_bbox.width - original_height = original_bbox.height - - # Extend width but shrink height - extended_ax = extend(self.ax, x_ratio=2.0, y_ratio=0.75) - new_bbox = extended_ax.get_position() - - # Check that width and height were scaled correctly - assert np.isclose(new_bbox.width, original_width * 2.0) - assert np.isclose(new_bbox.height, original_height * 0.75) - - def test_edge_cases(self): - # Test with zero scaling (should be invalid but testing edge case) - with pytest.raises(Exception): - extend(self.ax, x_ratio=0, y_ratio=0) - - # Test with default values (should keep size the same) - original_bbox = self.ax.get_position() - extended_ax = extend(self.ax) - new_bbox = extended_ax.get_position() - - assert np.isclose(new_bbox.width, original_bbox.width) - assert np.isclose(new_bbox.height, original_bbox.height) - - def test_savefig(self): - from scitex.io import save - - # Create a simple plot - self.ax.plot([1, 2, 3], [1, 2, 3]) - self.ax.set_title("Original") - - # Main test functionality - extend(self.ax, x_ratio=1.5, y_ratio=0.8) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_extend.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-02 09:00:51 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_extend.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_style/_extend.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# from ....plt.utils import assert_valid_axis -# -# -# def extend(axis, x_ratio=1.0, y_ratio=1.0): -# """ -# Extend or shrink a matplotlib axis or scitex axis wrapper while maintaining its center position. -# -# Args: -# axis (matplotlib.axes._axes.Axes or scitex.plt._subplots.AxisWrapper): The axis to be modified. -# x_ratio (float, optional): The ratio to scale the width. Default is 1.0. -# y_ratio (float, optional): The ratio to scale the height. Default is 1.0. -# -# Returns: -# matplotlib.axes._axes.Axes or scitex.plt._subplots.AxisWrapper: The modified axis. -# -# Raises: -# AssertionError: If the first argument is not a valid axis. -# """ -# -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# assert x_ratio != 0, "x_ratio must not be 0." -# assert y_ratio != 0, "y_ratio must not be 0." -# -# ## Original coordinates -# bbox = axis.get_position() -# left_orig = bbox.x0 -# bottom_orig = bbox.y0 -# width_orig = bbox.x1 - bbox.x0 -# height_orig = bbox.y1 - bbox.y0 -# g_orig = (left_orig + width_orig / 2.0, bottom_orig + height_orig / 2.0) -# -# ## Target coordinates -# g_tgt = g_orig -# width_tgt = width_orig * x_ratio -# height_tgt = height_orig * y_ratio -# left_tgt = g_tgt[0] - width_tgt / 2 -# bottom_tgt = g_tgt[1] - height_tgt / 2 -# -# # Extend the axis -# axis.set_position( -# [ -# left_tgt, -# bottom_tgt, -# width_tgt, -# height_tgt, -# ] -# ) -# return axis -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_extend.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__force_aspect.py b/tests/scitex/plt/ax/_style/test__force_aspect.py deleted file mode 100644 index 0a5cc9843..000000000 --- a/tests/scitex/plt/ax/_style/test__force_aspect.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:30 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__force_aspect.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__force_aspect.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import force_aspect - -matplotlib.use("Agg") # Use non-GUI backend for testing - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure() - self.ax = self.fig.add_subplot(111) - - # Create an image with known dimensions - data = np.random.rand(10, 20) # Height x Width - self.im = self.ax.imshow(data, extent=[0, 20, 0, 10]) # Width x Height - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_basic_functionality(self): - # Test with default aspect (aspect=1) - ax = force_aspect(self.ax) - - # Get the current aspect ratio - current_aspect = self.ax.get_aspect() - - # With aspect=1, it should set aspect to ratio of width/height (20/10 = 2) divided by 1 - # So aspect should be 2 - assert np.isclose(current_aspect, 2.0, rtol=1e-2) - - def test_custom_aspect(self): - # Test with custom aspect = 2 - ax = force_aspect(self.ax, aspect=2) - - # Get the current aspect ratio - current_aspect = self.ax.get_aspect() - - # With aspect=2, it should set aspect to ratio of width/height (20/10 = 2) divided by 2 - # So aspect should be 1 - assert np.isclose(current_aspect, 1.0, rtol=1e-2) - - def test_no_images(self): - # Test with no images on the axes - empty_ax = self.fig.add_subplot(122) - - # Should raise IndexError as the function tries to access im[0] - with pytest.raises(IndexError): - force_aspect(empty_ax) - - def test_with_multiple_images(self): - # Add another image with different dimensions - second_data = np.random.rand(5, 10) # Height x Width - second_im = self.ax.imshow(second_data, extent=[0, 10, 0, 5]) # Width x Height - - # The function should use the first image from get_images() - ax = force_aspect(self.ax) - - # Get the current aspect ratio - current_aspect = self.ax.get_aspect() - - # Should still be using the first image (20/10 = 2) - assert np.isclose(current_aspect, 2.0, rtol=1e-2) - - def test_with_negative_extent(self): - # Create an image with negative extent - neg_data = np.random.rand(10, 20) # Height x Width - neg_ax = self.fig.add_subplot(133) - neg_im = neg_ax.imshow(neg_data, extent=[-20, 0, -10, 0]) # Width x Height - - # Test force_aspect - neg_ax = force_aspect(neg_ax) - - # Should handle negative extent correctly, absolute value is used - current_aspect = neg_ax.get_aspect() - assert np.isclose(current_aspect, 2.0, rtol=1e-2) - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - self.ax.set_title("Force Aspect Ratio") - force_aspect(self.ax, aspect=1.0) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_force_aspect.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-02 09:00:52 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_force_aspect.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_style/_force_aspect.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# from ....plt.utils import assert_valid_axis -# -# -# def force_aspect(axis, aspect=1): -# """ -# Forces aspect ratio of an axis based on the extent of the image. -# -# Arguments: -# axis (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The axis to adjust. -# aspect (float, optional): The aspect ratio to apply. Defaults to 1. -# -# Returns: -# matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper: The axis with adjusted aspect ratio. -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# im = axis.get_images() -# -# extent = im[0].get_extent() -# -# axis.set_aspect(abs((extent[1] - extent[0]) / (extent[3] - extent[2])) / aspect) -# return axis -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_force_aspect.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__format_label.py b/tests/scitex/plt/ax/_style/test__format_label.py deleted file mode 100644 index 428e8e33d..000000000 --- a/tests/scitex/plt/ax/_style/test__format_label.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-06-11 03:30:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_style/test__format_label.py -# ---------------------------------------- -"""Comprehensive tests for format_label function.""" - -import os - -import pytest - -pytest.importorskip("zarr") -from unittest.mock import MagicMock, patch - -import matplotlib.pyplot as plt -import numpy as np - -__FILE__ = "./tests/scitex/plt/ax/_style/test__format_label.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from scitex.plt.ax._style import format_label - - -class TestFormatLabelBasicFunctionality: - """Test basic functionality of format_label.""" - - def test_passthrough_strings(self): - """Test that strings are passed through unchanged (current behavior).""" - assert format_label("test_label") == "test_label" - assert ( - format_label("complex_label_with_underscores") - == "complex_label_with_underscores" - ) - assert format_label("UPPERCASE") == "UPPERCASE" - assert format_label("lowercase") == "lowercase" - assert format_label("MixedCase") == "MixedCase" - assert format_label("camelCase") == "camelCase" - assert format_label("PascalCase") == "PascalCase" - - def test_passthrough_numbers(self): - """Test that numeric values are passed through unchanged.""" - assert format_label(123) == 123 - assert format_label(123.456) == 123.456 - assert format_label(0) == 0 - assert format_label(-42) == -42 - assert format_label(1e6) == 1e6 - assert format_label(np.pi) == np.pi - - def test_passthrough_none(self): - """Test that None is passed through unchanged.""" - assert format_label(None) is None - - def test_passthrough_containers(self): - """Test that containers are passed through unchanged.""" - assert format_label([1, 2, 3]) == [1, 2, 3] - assert format_label((1, 2, 3)) == (1, 2, 3) - assert format_label({1, 2, 3}) == {1, 2, 3} - assert format_label({"a": 1, "b": 2}) == {"a": 1, "b": 2} - - def test_passthrough_numpy_arrays(self): - """Test that numpy arrays are passed through unchanged.""" - arr = np.array([1, 2, 3]) - result = format_label(arr) - assert np.array_equal(result, arr) - assert result is arr # Same object - - def test_empty_string(self): - """Test empty string handling.""" - assert format_label("") == "" - - def test_whitespace_strings(self): - """Test strings with various whitespace.""" - assert format_label(" ") == " " - assert format_label("\t") == "\t" - assert format_label("\n") == "\n" - assert format_label(" spaced ") == " spaced " - assert format_label("multi\nline") == "multi\nline" - - def test_special_characters(self): - """Test strings with special characters.""" - assert format_label("special!@#$%^&*()_+") == "special!@#$%^&*()_+" - assert format_label("path/to/file.txt") == "path/to/file.txt" - assert format_label("key=value") == "key=value" - assert format_label("item[0]") == "item[0]" - assert format_label("{braces}") == "{braces}" - - def test_unicode_characters(self): - """Test strings with unicode characters.""" - assert format_label("unicode_текст_测试") == "unicode_текст_测试" - assert format_label("π_constant") == "π_constant" - assert format_label("café") == "café" - assert format_label("naïve") == "naïve" - assert format_label("emoji_🎨_test") == "emoji_🎨_test" - - def test_latex_strings(self): - """Test LaTeX formatted strings.""" - assert format_label(r"$\alpha$") == r"$\alpha$" - assert format_label(r"$\beta_{test}$") == r"$\beta_{test}$" - assert format_label(r"$\frac{a}{b}$") == r"$\frac{a}{b}$" - assert format_label(r"$\sum_{i=0}^{n} x_i$") == r"$\sum_{i=0}^{n} x_i$" - - -class TestFormatLabelCommentedFunctionality: - """Test the commented-out functionality for future reference.""" - - def test_future_underscore_replacement(self): - """Test what the function would do if underscore replacement was enabled.""" - # Currently returns unchanged - assert format_label("test_label") == "test_label" - # Would return: "Test Label" if enabled - - assert ( - format_label("complex_label_with_underscores") - == "complex_label_with_underscores" - ) - # Would return: "Complex Label With Underscores" if enabled - - assert format_label("__private__") == "__private__" - # Would return: " Private " if enabled (preserving double underscores) - - def test_future_capitalization(self): - """Test what the function would do if capitalization was enabled.""" - # Currently returns unchanged - assert format_label("all_lowercase") == "all_lowercase" - # Would return: "All Lowercase" if enabled - - assert format_label("KEEP_UPPERCASE") == "KEEP_UPPERCASE" - # Would return: "KEEP_UPPERCASE" if enabled (all caps preserved) - - assert format_label("mixed_Case_Label") == "mixed_Case_Label" - # Would return: "Mixed Case Label" if enabled - - def test_future_edge_cases(self): - """Test edge cases for the commented functionality.""" - # Single character - assert format_label("x") == "x" - # Would return: "X" if enabled - - # Multiple underscores - assert format_label("a__b___c") == "a__b___c" - # Would return: "A B C" if enabled - - # Leading/trailing underscores - assert format_label("_private_var_") == "_private_var_" - # Would return: " Private Var " if enabled - - -class TestFormatLabelWithMatplotlib: - """Test format_label in matplotlib context.""" - - def test_with_axis_labels(self): - """Test using format_label with axis labels.""" - fig, ax = plt.subplots() - - xlabel = format_label("time_seconds") - ylabel = format_label("voltage_mV") - title = format_label("experiment_results") - - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) - - assert ax.get_xlabel() == "time_seconds" - assert ax.get_ylabel() == "voltage_mV" - assert ax.get_title() == "experiment_results" - - plt.close(fig) - - def test_with_legend_labels(self): - """Test using format_label with legend labels.""" - fig, ax = plt.subplots() - - x = np.linspace(0, 10, 100) - labels = ["sine_wave", "cosine_wave", "tangent_wave"] - - for i, func in enumerate([np.sin, np.cos, np.tan]): - ax.plot(x, func(x), label=format_label(labels[i])) - - legend = ax.legend() - legend_texts = [t.get_text() for t in legend.get_texts()] - - assert legend_texts == labels - - plt.close(fig) - - def test_with_tick_labels(self): - """Test using format_label with tick labels.""" - fig, ax = plt.subplots() - - categories = ["category_A", "category_B", "category_C"] - formatted_categories = [format_label(cat) for cat in categories] - - ax.bar(range(len(categories)), [1, 2, 3]) - ax.set_xticks(range(len(categories))) - ax.set_xticklabels(formatted_categories) - - tick_labels = [t.get_text() for t in ax.get_xticklabels()] - assert tick_labels == categories - - plt.close(fig) - - def test_savefig_integration(self): - """Test integration with figure saving.""" - import matplotlib.pyplot as plt - - from scitex.io import save - - # Setup - fig, ax = plt.subplots() - ax.plot([1, 2, 3], [1, 2, 3]) - - # Apply formatted label - label = format_label("test_label_for_saving") - ax.set_title(label) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - plt.close(fig) - - -class TestFormatLabelRobustness: - """Test robustness and error handling.""" - - def test_very_long_labels(self): - """Test handling of very long labels.""" - long_label = "a" * 1000 - assert format_label(long_label) == long_label - - long_underscore_label = "_".join(["word"] * 100) - assert format_label(long_underscore_label) == long_underscore_label - - def test_numeric_string_labels(self): - """Test labels that look like numbers.""" - assert format_label("123") == "123" - assert format_label("3.14159") == "3.14159" - assert format_label("1e6") == "1e6" - assert format_label("0xFF") == "0xFF" - - def test_mixed_type_labels(self): - """Test labels with mixed content.""" - assert format_label("label_123") == "label_123" - assert format_label("v2.0_beta") == "v2.0_beta" - assert format_label("test@2024") == "test@2024" - - def test_boolean_values(self): - """Test boolean value handling.""" - assert format_label(True) is True - assert format_label(False) is False - - def test_custom_objects(self): - """Test custom objects that might be used as labels.""" - - class CustomLabel: - def __str__(self): - return "custom_label" - - obj = CustomLabel() - assert format_label(obj) is obj # Returns unchanged - - def test_callable_objects(self): - """Test callable objects.""" - func = lambda x: x - assert format_label(func) is func - - def named_func(): - pass - - assert format_label(named_func) is named_func - - -class TestFormatLabelPerformance: - """Test performance characteristics.""" - - def test_no_unnecessary_string_copies(self): - """Test that strings aren't unnecessarily copied.""" - original = "test_string" - result = format_label(original) - assert result is original # Same object since no transformation - - def test_handles_many_calls(self): - """Test performance with many calls.""" - labels = [f"label_{i}" for i in range(1000)] - - # Should handle many calls efficiently - formatted = [format_label(label) for label in labels] - assert formatted == labels - - def test_memory_efficiency(self): - """Test memory efficiency with various inputs.""" - # Large objects should be returned unchanged without copying - large_array = np.zeros((1000, 1000)) - result = format_label(large_array) - assert result is large_array # Same object - - -class TestFormatLabelIntegration: - """Test integration with the broader scitex ecosystem.""" - - @patch("scitex.plt.ax._style._format_label.format_label") - def test_mocked_enhanced_functionality(self, mock_format): - """Test what enhanced functionality might look like.""" - - # Mock the enhanced functionality - def enhanced_format(label): - if isinstance(label, str): - # Check if already uppercase BEFORE transforming - is_upper = label.isupper() - if is_upper: - return label - label = label.replace("_", " ") - label = " ".join(word.capitalize() for word in label.split()) - return label - - mock_format.side_effect = enhanced_format - - # Test enhanced behavior - assert mock_format("test_label") == "Test Label" - assert mock_format("UPPERCASE") == "UPPERCASE" - - def test_compatible_with_matplotlib_text(self): - """Test compatibility with matplotlib Text objects.""" - fig, ax = plt.subplots() - - text = ax.text(0.5, 0.5, format_label("test_text")) - assert text.get_text() == "test_text" - - plt.close(fig) - - def test_preserves_label_properties(self): - """Test that label properties are preserved.""" - labels_with_properties = [ - ("$equation$", True), # Math text - ("plain text", False), - (r"\LaTeX", False), - ("_subscript", False), - ("^superscript", False), - ] - - for label, is_math in labels_with_properties: - formatted = format_label(label) - assert formatted == label # Current behavior preserves everything - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_format_label.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2024-09-15 09:39:02 (ywatanabe)" -# # /home/ywatanabe/proj/_scitex_repo_openhands/src/scitex/plt/ax/_format_label.py -# -# -# def format_label(label): -# """ -# Format label by capitalizing first letter and replacing underscores with spaces. -# """ -# -# # if isinstance(label, str): -# # # Replace underscores with spaces -# # label = label.replace("_", " ") -# -# # # Capitalize first letter of each word -# # label = " ".join(word.capitalize() for word in label.split()) -# -# # # Special case for abbreviations (all caps) -# # if label.isupper(): -# # return label -# -# return label - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_format_label.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__format_units.py b/tests/scitex/plt/ax/_style/test__format_units.py deleted file mode 100644 index 6d89c3d73..000000000 --- a/tests/scitex/plt/ax/_style/test__format_units.py +++ /dev/null @@ -1,119 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_format_units.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-11-19 15:10:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_format_units.py -# -# """ -# Utility functions for formatting axis labels with proper unit notation. -# """ -# -# from typing import Optional -# -# -# def format_label(label: str, unit: Optional[str] = None) -> str: -# """ -# Format axis label with unit in brackets (publication standard). -# -# Parameters -# ---------- -# label : str -# The label text (e.g., "Time", "Voltage") -# unit : str, optional -# The unit (e.g., "s", "mV", "Hz"). If None, returns label as-is. -# -# Returns -# ------- -# str -# Formatted label with unit in brackets (e.g., "Time [s]") -# -# Examples -# -------- -# >>> stx.ax.format_label("Time", "s") -# 'Time [s]' -# -# >>> stx.ax.format_label("Voltage", "mV") -# 'Voltage [mV]' -# -# >>> stx.ax.format_label("Count") -# 'Count' -# -# >>> # Direct usage with axis -# >>> ax.set_xlabel(stx.ax.format_label("Time", "s")) -# >>> ax.set_ylabel(stx.ax.format_label("Amplitude", "mV")) -# -# Notes -# ----- -# According to publication standards (Nature, Science, Cell), units should be -# enclosed in square brackets, not parentheses: -# - Correct: "Time [s]", "Voltage [mV]" -# - Incorrect: "Time (s)", "Voltage (mV)" -# """ -# if unit is None or unit == "": -# return label -# return f"{label} [{unit}]" -# -# -# def format_label_auto(text: str) -> str: -# """ -# Automatically convert parentheses-style units to bracket-style. -# -# This function detects units in parentheses and converts them to brackets. -# -# Parameters -# ---------- -# text : str -# Label text, possibly with units in parentheses -# -# Returns -# ------- -# str -# Label text with units in brackets -# -# Examples -# -------- -# >>> stx.ax.format_label_auto("Time (s)") -# 'Time [s]' -# -# >>> stx.ax.format_label_auto("Voltage (mV)") -# 'Voltage [mV]' -# -# >>> stx.ax.format_label_auto("Count") -# 'Count' -# -# Notes -# ----- -# This is useful for automatically correcting existing labels that use -# parentheses notation. -# """ -# import re -# -# # Pattern to match units in parentheses at the end of the string -# # e.g., "Time (s)" or "Frequency (Hz)" -# pattern = r"\s*\(([^)]+)\)\s*$" -# -# match = re.search(pattern, text) -# if match: -# unit = match.group(1) -# label = text[: match.start()].strip() -# return f"{label} [{unit}]" -# -# return text -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_format_units.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__hide_spines.py b/tests/scitex/plt/ax/_style/test__hide_spines.py deleted file mode 100644 index 3ed3fb665..000000000 --- a/tests/scitex/plt/ax/_style/test__hide_spines.py +++ /dev/null @@ -1,222 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-18 16:30:42 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_style/test__hide_spines.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__hide_spines.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib.pyplot as plt -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style._hide_spines import hide_spines - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure(figsize=(6, 4)) - self.ax = self.fig.add_subplot(111) - # Create a basic plot - self.ax.plot([1, 2, 3], [1, 2, 3]) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_hide_all_spines(self): - # Test hiding all spines by specifying all parameters - ax = hide_spines(self.ax, top=True, bottom=True, left=True, right=True) - - # Check that all spines are hidden - assert not ax.spines["top"].get_visible() - assert not ax.spines["bottom"].get_visible() - assert not ax.spines["left"].get_visible() - assert not ax.spines["right"].get_visible() - - def test_hide_specific_spines(self): - # Test default behavior (hides top and right spines) - ax = hide_spines(self.ax) - - # Check that only default spines (top, right) are hidden - assert not ax.spines["top"].get_visible() - assert ax.spines["bottom"].get_visible() - assert ax.spines["left"].get_visible() - assert not ax.spines["right"].get_visible() - - def test_keep_ticks_and_labels(self): - # Test keeping ticks and labels while hiding all spines - ax = hide_spines( - self.ax, - top=True, - bottom=True, - left=True, - right=True, - ticks=False, - labels=False, - ) - - # Check that all spines are hidden - assert not ax.spines["top"].get_visible() - assert not ax.spines["bottom"].get_visible() - assert not ax.spines["left"].get_visible() - assert not ax.spines["right"].get_visible() - - # Ticks and labels should still be there - fig = ax.get_figure() - fig.canvas.draw() - assert ax.xaxis.get_major_ticks() != [] - assert ax.yaxis.get_major_ticks() != [] - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - hide_spines(self.ax, top=True, right=True, bottom=False, left=False) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - # def test_hide_ticks_only(self): - # # Test hiding ticks but keeping labels - # ax = hide_spines(self.ax, ticks=True, labels=False) - - # # Check that appropriate ticks are hidden - # assert ax.xaxis.get_ticks_position() == "none" - # assert ax.yaxis.get_ticks_position() == "none" - - # # But labels should still be there - # fig = ax.get_figure() - # fig.canvas.draw() - # assert not all( - # label.get_text() == "" for label in ax.get_xticklabels() - # ) - # assert not all( - # label.get_text() == "" for label in ax.get_yticklabels() - # ) - - -# def test_hide_labels_only(self): -# # Test hiding labels but keeping ticks -# ax = hide_spines(self.ax, ticks=False, labels=True) - -# # Check that labels are hidden -# fig = ax.get_figure() -# fig.canvas.draw() -# assert all(label.get_text() == "" for label in ax.get_xticklabels()) -# assert all(label.get_text() == "" for label in ax.get_yticklabels()) - -# # But ticks should still be visible -# assert ax.xaxis.get_ticks_position() != "none" -# assert ax.yaxis.get_ticks_position() != "none" - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_hide_spines.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-06-07 15:45:36 (ywatanabe)" -# # File: /ssh:ywatanabe@sp:/home/ywatanabe/proj/.claude-worktree/scitex_repo/src/scitex/plt/ax/_style/_hide_spines.py -# # ---------------------------------------- -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# # Time-stamp: "2024-04-26 20:03:45 (ywatanabe)" -# -# import matplotlib -# from ....plt.utils import assert_valid_axis -# -# -# def hide_spines( -# axis, -# top=True, -# bottom=False, -# left=False, -# right=True, -# ticks=False, -# labels=False, -# ): -# """ -# Hides the specified spines of a matplotlib Axes object or scitex axis wrapper and optionally removes the ticks and labels. -# -# This function is designed to work with matplotlib Axes objects or scitex axis wrappers. It allows for a cleaner, more minimalist -# presentation of plots by hiding the spines (the lines denoting the boundaries of the plot area) and optionally -# removing the ticks and labels from the axes. -# -# Arguments: -# ax (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The axis for which the spines will be hidden. -# top (bool, optional): If True, hides the top spine. Defaults to True. -# bottom (bool, optional): If True, hides the bottom spine. Defaults to False. -# left (bool, optional): If True, hides the left spine. Defaults to False. -# right (bool, optional): If True, hides the right spine. Defaults to True. -# ticks (bool, optional): If True, removes the ticks from the hidden spines' axes. Defaults to False. -# labels (bool, optional): If True, removes the labels from the hidden spines' axes. Defaults to False. -# -# Returns: -# matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper: The modified axis with the specified spines hidden. -# -# Example: -# >>> fig, ax = plt.subplots() -# >>> hide_spines(ax) -# >>> plt.show() -# """ -# assert_valid_axis( -# axis, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# tgts = [] -# if top: -# tgts.append("top") -# if bottom: -# tgts.append("bottom") -# if left: -# tgts.append("left") -# if right: -# tgts.append("right") -# -# for tgt in tgts: -# # Spines -# axis.spines[tgt].set_visible(False) -# -# # Ticks -# if ticks: -# if tgt == "bottom": -# axis.xaxis.set_ticks_position("none") -# elif tgt == "left": -# axis.yaxis.set_ticks_position("none") -# -# # Labels -# if labels: -# if tgt == "bottom": -# axis.set_xticklabels([]) -# elif tgt == "left": -# axis.set_yticklabels([]) -# -# return axis -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_hide_spines.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__map_ticks.py b/tests/scitex/plt/ax/_style/test__map_ticks.py deleted file mode 100644 index dd9b0d776..000000000 --- a/tests/scitex/plt/ax/_style/test__map_ticks.py +++ /dev/null @@ -1,381 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:28 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__map_ticks.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__map_ticks.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np - -matplotlib.use("Agg") # Use non-GUI backend for testing - -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import map_ticks - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure() - self.ax = self.fig.add_subplot(111) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_numeric_mapping_x_axis(self): - # Create a plot with numeric x-axis - xx = np.linspace(0, 2 * np.pi, 100) - yy = np.sin(xx) - self.ax.plot(xx, yy) - - # Define mapping points and labels - src = [0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi] - tgt = ["0", "π/2", "π", "3π/2", "2π"] - - # Apply mapping - ax = map_ticks(self.ax, src, tgt, axis="x") - - # Force draw to update tick labels - self.fig.canvas.draw() - - # Get tick positions and labels - tick_positions = ax.get_xticks() - tick_labels = [label.get_text() for label in ax.get_xticklabels()] - - # Check that ticks were set correctly - assert len(tick_positions) == len(src) - assert np.allclose(tick_positions, src) - assert tick_labels == tgt - - def test_numeric_mapping_y_axis(self): - # Create a plot with numeric y-axis - xx = np.linspace(0, 2 * np.pi, 100) - yy = np.sin(xx) - self.ax.plot(xx, yy) - - # Define mapping points and labels - src = [-1, -0.5, 0, 0.5, 1] - tgt = ["-1.0", "-0.5", "0.0", "0.5", "1.0"] - - # Apply mapping - ax = map_ticks(self.ax, src, tgt, axis="y") - - # Force draw to update tick labels - self.fig.canvas.draw() - - # Get tick positions and labels - tick_positions = ax.get_yticks() - tick_labels = [label.get_text() for label in ax.get_yticklabels()] - - # Check that ticks were set correctly - assert len(tick_positions) == len(src) - assert np.allclose(tick_positions, src) - assert tick_labels == tgt - - def test_string_mapping_x_axis(self): - # Create a categorical plot - categories = ["A", "B", "C", "D", "E"] - values = [1, 3, 2, 5, 4] - self.ax.bar(categories, values) - - # Force draw to ensure tick labels are created - self.fig.canvas.draw() - - # Define mapping - src = categories - tgt = ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"] - - # Apply mapping - ax = map_ticks(self.ax, src, tgt, axis="x") - - # Force draw again to update tick labels - self.fig.canvas.draw() - - # Get tick labels - tick_labels = [label.get_text() for label in ax.get_xticklabels()] - - # Check that labels were mapped correctly - assert tick_labels == tgt - - def test_mismatched_lengths(self): - """Test when source and target have different lengths.""" - import re - - # Create mismatched input - src = [1, 2, 3] - tgt = [4, 5] - - # Should raise ValueError - with pytest.raises( - ValueError, - ): - map_ticks(self.ax, src, tgt) - - def test_invalid_axis(self): - # Test error with invalid axis parameter - src = [0, 1, 2] - tgt = ["A", "B", "C"] - - # Should raise ValueError - with pytest.raises(ValueError, match="Invalid axis"): - map_ticks(self.ax, src, tgt, axis="z") - - def test_partial_string_mapping(self): - # Create a categorical plot - categories = ["A", "B", "C", "D", "E"] - values = [1, 3, 2, 5, 4] - self.ax.bar(categories, values) - - # Force draw to ensure tick labels are created - self.fig.canvas.draw() - - # Define partial mapping (only some categories) - src = ["A", "C", "E"] - tgt = ["Alpha", "Gamma", "Epsilon"] - - # Apply mapping - ax = map_ticks(self.ax, src, tgt, axis="x") - - # Force draw again to update tick labels - self.fig.canvas.draw() - - # Get tick labels and positions - tick_labels = [label.get_text() for label in ax.get_xticklabels()] - tick_positions = ax.get_xticks() - - # Check results - assert len(tick_positions) == len(src) - assert tick_labels == tgt - - def test_savefig(self): - import numpy as np - - from scitex.io import save - - # Setup plot with data - xx = np.linspace(0, 2 * np.pi, 100) - yy = np.sin(xx) - self.ax.plot(xx, yy) - - # Main test functionality - src = [0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi] - tgt = ["0", "π/2", "π", "3π/2", "2π"] - map_ticks(self.ax, src, tgt, axis="x") - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_map_ticks.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-02 09:00:56 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_map_ticks.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_style/_map_ticks.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# import matplotlib.pyplot as plt -# import numpy as np -# -# from ....plt.utils import assert_valid_axis -# -# -# def map_ticks(ax, src, tgt, axis="x"): -# """ -# Maps source tick positions or labels to new target labels on a matplotlib Axes object. -# Supports both numeric positions and string labels for source ticks ('src'), enabling the mapping -# to new target labels ('tgt'). This ensures only the specified target ticks are displayed on the -# final axis, enhancing the clarity and readability of plots. -# -# Parameters: -# - ax (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The Axes object to modify. -# - src (list of str or numeric): Source positions (if numeric) or labels (if str) to map from. -# When using string labels, ensure they match the current tick labels on the axis. -# - tgt (list of str): New target labels to apply to the axis. Must have the same length as 'src'. -# - axis (str): Specifies which axis to apply the tick modifications ('x' or 'y'). -# -# Returns: -# - ax (matplotlib.axes.Axes or scitex.plt._subplots.AxisWrapper): The modified Axes object with adjusted tick labels. -# -# Examples: -# -------- -# Numeric Example: -# fig, ax = plt.subplots() -# x = np.linspace(0, 2 * np.pi, 100) -# y = np.sin(x) -# ax.plot(x, y) # Plot a sine wave -# src = [0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi] # Numeric src positions -# tgt = ['0', 'π/2', 'π', '3π/2', '2π'] # Corresponding target labels -# map_ticks(ax, src, tgt, axis="x") # Map src to tgt on the x-axis -# plt.show() -# -# String Example: -# fig, ax = plt.subplots() -# categories = ['A', 'B', 'C', 'D', 'E'] # Initial categories -# values = [1, 3, 2, 5, 4] -# ax.bar(categories, values) # Bar plot with string labels -# src = ['A', 'B', 'C', 'D', 'E'] # Source labels to map from -# tgt = ['Alpha', 'Beta', 'Gamma', 'Delta', 'Epsilon'] # New target labels -# map_ticks(ax, src, tgt, axis="x") # Apply the mapping -# plt.show() -# """ -# assert_valid_axis( -# ax, "First argument must be a matplotlib axis or scitex axis wrapper" -# ) -# -# if len(src) != len(tgt): -# raise ValueError( -# "Source ('src') and target ('tgt') must have the same number of elements." -# ) -# -# # Determine tick positions if src is string data -# if all(isinstance(item, str) for item in src): -# if axis == "x": -# all_labels = [label.get_text() for label in ax.get_xticklabels()] -# else: -# all_labels = [label.get_text() for label in ax.get_yticklabels()] -# -# # Find positions of src labels -# src_positions = [all_labels.index(s) for s in src if s in all_labels] -# else: -# # Use src as positions directly if numeric -# src_positions = src -# -# # Set the ticks and labels based on the specified axis -# if axis == "x": -# ax.set_xticks(src_positions) -# ax.set_xticklabels(tgt) -# elif axis == "y": -# ax.set_yticks(src_positions) -# ax.set_yticklabels(tgt) -# else: -# raise ValueError("Invalid axis argument. Use 'x' or 'y'.") -# -# return ax -# -# -# def numeric_example(): -# """Example demonstrating numeric tick mapping. -# -# Shows how to replace numeric tick positions with custom labels, -# such as replacing radian values with pi notation in trigonometric plots. -# -# Returns -# ------- -# matplotlib.figure.Figure -# Figure with two subplots showing before and after tick mapping. -# -# Examples -# -------- -# >>> fig = numeric_example() -# >>> plt.show() -# -# Notes -# ----- -# The top subplot shows original numeric labels, while the bottom -# subplot shows the same data with custom pi notation labels. -# """ -# fig, axs = plt.subplots(2, 1, figsize=(10, 6)) # Two rows, one column -# -# # Original plot -# x = np.linspace(0, 2 * np.pi, 100) -# y = np.sin(x) -# axs[0].plot(x, y) # Plot a sine wave on the first row -# axs[0].set_title("Original Numeric Labels") -# -# # Numeric src positions for ticks (e.g., multiples of pi) and target labels -# src = [0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi] -# tgt = ["0", "π/2", "π", "3π/2", "2π"] -# -# # Plot with mapped ticks -# axs[1].plot(x, y) # Plot again on the second row for mapped labels -# map_ticks(axs[1], src, tgt, axis="x") -# axs[1].set_title("Mapped Numeric Labels") -# -# return fig -# -# -# def string_example(): -# """Example demonstrating string tick mapping. -# -# Shows how to replace categorical string labels with more descriptive -# alternatives, useful for improving plot readability. -# -# Returns -# ------- -# matplotlib.figure.Figure -# Figure with two subplots showing before and after tick mapping. -# -# Examples -# -------- -# >>> fig = string_example() -# >>> plt.show() -# -# Notes -# ----- -# The top subplot shows original short category labels (A, B, C...), -# while the bottom subplot shows the same data with descriptive Greek -# letter names. -# """ -# fig, axs = plt.subplots(2, 1, figsize=(10, 6)) # Two rows, one column -# -# # Original plot with categorical string labels -# categories = ["A", "B", "C", "D", "E"] -# values = [1, 3, 2, 5, 4] -# axs[0].bar(categories, values) -# axs[0].set_title("Original String Labels") -# -# # src as the existing labels to change and target labels -# src = categories -# tgt = ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"] -# -# # Plot with mapped string labels -# axs[1].bar(categories, values) # Bar plot again on the second row for mapped labels -# map_ticks(axs[1], src, tgt, axis="x") -# axs[1].set_title("Mapped String Labels") -# -# return fig -# -# -# # Execute examples -# if __name__ == "__main__": -# fig_numeric = numeric_example() -# fig_string = string_example() -# -# plt.tight_layout() -# plt.show() -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_map_ticks.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__rotate_labels.py b/tests/scitex/plt/ax/_style/test__rotate_labels.py deleted file mode 100644 index d3ba8a4cd..000000000 --- a/tests/scitex/plt/ax/_style/test__rotate_labels.py +++ /dev/null @@ -1,460 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:33 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__rotate_labels.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__rotate_labels.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") - -matplotlib.use("Agg") - -from scitex.plt.ax._style import rotate_labels - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure() - self.ax = self.fig.add_subplot(111) - - # Create a basic plot with labels - xx = np.linspace(0, 10, 5) - yy = np.sin(xx) - self.ax.plot(xx, yy) - self.ax.set_xticks(xx) - self.ax.set_yticks(yy) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_basic_functionality(self): - # Test with rotation parameters (default args don't rotate) - ax = rotate_labels(self.ax, x=45, y=45) - - # Force draw to ensure labels are updated - self.fig.canvas.draw() - - # Check x and y tick label rotations - for label in ax.get_xticklabels(): - assert label.get_rotation() == 45 - - for label in ax.get_yticklabels(): - assert label.get_rotation() == 45 - - def test_custom_rotations(self): - # Test with custom rotation angles - ax = rotate_labels(self.ax, x=30, y=60) - - # Force draw to ensure labels are updated - self.fig.canvas.draw() - - # Check custom x and y tick label rotations - for label in ax.get_xticklabels(): - assert label.get_rotation() == 30 - - for label in ax.get_yticklabels(): - assert label.get_rotation() == 60 - - def test_custom_alignment(self): - # Test with custom horizontal alignments (must also provide rotation) - ax = rotate_labels(self.ax, x=45, y=45, x_ha="left", y_ha="right") - - # Force draw to ensure labels are updated - self.fig.canvas.draw() - - # Check custom alignments - for label in ax.get_xticklabels(): - assert label.get_ha() == "left" - - for label in ax.get_yticklabels(): - assert label.get_ha() == "right" - - def test_rotate_x_only(self): - # Test rotating only x labels - ax = rotate_labels(self.ax, x=90, y=0) - - # Force draw to ensure labels are updated - self.fig.canvas.draw() - - # Check that x labels are rotated but y labels are vertical - for label in ax.get_xticklabels(): - assert label.get_rotation() == 90 - - for label in ax.get_yticklabels(): - assert label.get_rotation() == 0 - - def test_rotate_y_only(self): - # Test rotating only y labels - ax = rotate_labels(self.ax, x=0, y=90) - - # Force draw to ensure labels are updated - self.fig.canvas.draw() - - # Check that y labels are rotated but x labels are horizontal - for label in ax.get_xticklabels(): - assert label.get_rotation() == 0 - - for label in ax.get_yticklabels(): - assert label.get_rotation() == 90 - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - rotate_labels(self.ax, x=45, y=30) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_rotate_labels.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-09-24 13:22:52 (ywatanabe)" -# # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_rotate_labels.py -# # ---------------------------------------- -# from __future__ import annotations -# import os -# -# __FILE__ = __file__ -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# # Time-stamp: "2024-10-27 13:24:32 (ywatanabe)" -# # /home/ywatanabe/proj/_scitex_repo_openhands/src/scitex/plt/ax/_rotate_labels.py -# -# """This script does XYZ.""" -# -# """Imports""" -# import numpy as np -# -# -# def rotate_labels( -# ax, -# x=None, -# y=None, -# x_ha=None, -# y_ha=None, -# x_va=None, -# y_va=None, -# auto_adjust=True, -# scientific_convention=True, -# tight_layout=False, -# ): -# """ -# Rotate x and y axis labels of a matplotlib Axes object with automatic positioning. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The Axes object to modify. -# x : float or None, optional -# Rotation angle for x-axis labels in degrees. Default is None. -# If 0 or None, x-axis labels are not rotated. -# y : float or None, optional -# Rotation angle for y-axis labels in degrees. Default is None. -# If 0 or None, y-axis labels are not rotated. -# x_ha : str, optional -# Horizontal alignment for x-axis labels. If None, automatically determined. -# y_ha : str, optional -# Horizontal alignment for y-axis labels. If None, automatically determined. -# x_va : str, optional -# Vertical alignment for x-axis labels. If None, automatically determined. -# y_va : str, optional -# Vertical alignment for y-axis labels. If None, automatically determined. -# auto_adjust : bool, optional -# Whether to automatically adjust alignment based on rotation angle. Default is True. -# scientific_convention : bool, optional -# Whether to follow scientific plotting conventions. Default is True. -# tight_layout : bool, optional -# Whether to apply tight_layout to prevent overlapping. Default is False. -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified Axes object. -# -# Example -# ------- -# fig, ax = plt.subplots() -# ax.plot([1, 2, 3], [1, 2, 3]) -# rotate_labels(ax) -# plt.show() -# -# Notes -# ----- -# Scientific conventions for label rotation: -# - X-axis labels: For angles 0-90°, use 'right' alignment; for 90-180°, use 'left' -# - Y-axis labels: For angles 0-90°, use 'center' alignment; adjust vertical as needed -# - Optimal readability maintained through automatic positioning -# """ -# # Determine which axes to rotate (skip if None or 0) -# rotate_x = x is not None and x != 0 -# rotate_y = y is not None and y != 0 -# -# # Get current tick positions -# xticks = ax.get_xticks() -# yticks = ax.get_yticks() -# -# # Set ticks explicitly -# ax.set_xticks(xticks) -# ax.set_yticks(yticks) -# -# # Auto-adjust alignment based on rotation angle and scientific conventions -# if auto_adjust: -# if rotate_x: -# x_ha, x_va = _get_optimal_alignment( -# "x", x, x_ha, x_va, scientific_convention -# ) -# if rotate_y: -# y_ha, y_va = _get_optimal_alignment( -# "y", y, y_ha, y_va, scientific_convention -# ) -# -# # Apply defaults if not auto-adjusting -# if rotate_x: -# if x_ha is None: -# x_ha = "center" -# if x_va is None: -# x_va = "center" -# if rotate_y: -# if y_ha is None: -# y_ha = "center" -# if y_va is None: -# y_va = "center" -# -# # Check if this axis is part of a shared x-axis configuration -# # If labels are already visible (bottom subplot or not shared), keep them visible -# # This preserves matplotlib's default sharex behavior -# x_labels_visible = ax.xaxis.get_tick_params()["labelbottom"] -# y_labels_visible = ax.yaxis.get_tick_params()["labelleft"] -# -# # Set labels with rotation and proper alignment -# # Only set labels if they're currently visible (respects sharex/sharey) -# if x_labels_visible and rotate_x: -# ax.set_xticklabels(ax.get_xticklabels(), rotation=x, ha=x_ha, va=x_va) -# if y_labels_visible and rotate_y: -# ax.set_yticklabels(ax.get_yticklabels(), rotation=y, ha=y_ha, va=y_va) -# -# # Auto-adjust subplot parameters for better layout if needed -# if auto_adjust and scientific_convention: -# # Only pass non-zero angles for adjustment -# x_angle = x if rotate_x else 0 -# y_angle = y if rotate_y else 0 -# _adjust_subplot_params(ax, x_angle, y_angle) -# -# # Apply tight_layout if requested to prevent overlapping -# if tight_layout: -# fig = ax.get_figure() -# try: -# fig.tight_layout() -# except Exception: -# # Fallback to manual adjustment if tight_layout fails -# x_angle = x if rotate_x else 0 -# y_angle = y if rotate_y else 0 -# _adjust_subplot_params(ax, x_angle, y_angle) -# -# return ax -# -# -# def _get_optimal_alignment(axis, angle, ha, va, scientific_convention): -# """ -# Determine optimal alignment based on rotation angle and scientific conventions. -# -# Parameters -# ---------- -# axis : str -# 'x' or 'y' axis -# angle : float -# Rotation angle in degrees -# ha : str or None -# Current horizontal alignment -# va : str or None -# Current vertical alignment -# scientific_convention : bool -# Whether to follow scientific conventions -# -# Returns -# ------- -# tuple -# (horizontal_alignment, vertical_alignment) -# """ -# # Normalize angle to 0-360 range -# angle = angle % 360 -# -# if axis == "x": -# if scientific_convention: -# # Scientific convention for x-axis labels -# if 0 <= angle <= 30: -# ha = ha or "center" -# va = va or "top" -# elif 30 < angle <= 60: -# ha = ha or "right" -# va = va or "top" -# elif 60 < angle < 90: -# ha = ha or "right" -# va = va or "top" -# elif angle == 90: -# # Special case for exact 90 degrees -# ha = ha or "right" -# va = va or "top" -# elif 90 < angle <= 120: -# ha = ha or "right" -# va = va or "center" -# elif 120 < angle <= 150: -# ha = ha or "right" -# va = va or "bottom" -# elif 150 < angle <= 210: -# ha = ha or "center" -# va = va or "bottom" -# elif 210 < angle <= 240: -# ha = ha or "left" -# va = va or "bottom" -# elif 240 < angle <= 300: -# ha = ha or "left" -# va = va or "center" -# else: # 300-360 -# ha = ha or "left" -# va = va or "top" -# else: -# ha = ha or "center" -# va = va or "top" -# -# else: # y-axis -# if scientific_convention: -# # Scientific convention for y-axis labels -# if 0 <= angle <= 30: -# ha = ha or "right" -# va = va or "center" -# elif 30 < angle <= 60: -# ha = ha or "right" -# va = va or "bottom" -# elif 60 < angle <= 120: -# ha = ha or "center" -# va = va or "bottom" -# elif 120 < angle <= 150: -# ha = ha or "left" -# va = va or "bottom" -# elif 150 < angle <= 210: -# ha = ha or "left" -# va = va or "center" -# elif 210 < angle <= 240: -# ha = ha or "left" -# va = va or "top" -# elif 240 < angle <= 300: -# ha = ha or "center" -# va = va or "top" -# else: # 300-360 -# ha = ha or "right" -# va = va or "top" -# else: -# ha = ha or "center" -# va = va or "center" -# -# return ha, va -# -# -# def _adjust_subplot_params(ax, x_angle, y_angle): -# """ -# Automatically adjust subplot parameters to accommodate rotated labels. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The axes object -# x_angle : float -# X-axis rotation angle -# y_angle : float -# Y-axis rotation angle -# """ -# fig = ax.get_figure() -# -# # Check if figure is using a layout engine that is incompatible with subplots_adjust -# try: -# # For matplotlib >= 3.6 -# if hasattr(fig, "get_layout_engine"): -# layout_engine = fig.get_layout_engine() -# if layout_engine is not None: -# # If using constrained_layout or tight_layout, don't adjust -# return -# except AttributeError: -# pass -# -# # Check for constrained_layout (older matplotlib versions) -# try: -# if hasattr(fig, "get_constrained_layout"): -# if fig.get_constrained_layout(): -# # Constrained layout is active, don't adjust -# return -# except AttributeError: -# pass -# -# # Calculate required margins based on rotation angles -# # Special handling for 90-degree rotation -# if x_angle == 90: -# x_margin_factor = 0.3 # Maximum margin for 90 degrees -# else: -# # Increase margin more significantly for rotated x-axis labels to prevent xlabel overlap -# x_margin_factor = abs(np.sin(np.radians(x_angle))) * 0.25 # Increased from 0.2 -# -# y_margin_factor = abs(np.sin(np.radians(y_angle))) * 0.15 -# -# # Get current subplot parameters -# try: -# subplotpars = fig.subplotpars -# current_bottom = subplotpars.bottom -# current_left = subplotpars.left -# -# # Adjust margins if they need to be increased -# # Ensure more space for rotated x-labels and xlabel -# new_bottom = max( -# current_bottom, 0.2 + x_margin_factor -# ) # Increased base from 0.15 -# new_left = max(current_left, 0.1 + y_margin_factor) -# -# # Only adjust if we're increasing the margins significantly -# if ( -# new_bottom > current_bottom + 0.02 or new_left > current_left + 0.02 -# ): # Reduced threshold -# # Suppress warning and try to adjust -# import warnings -# -# with warnings.catch_warnings(): -# warnings.simplefilter("ignore") -# fig.subplots_adjust(bottom=new_bottom, left=new_left) -# except Exception: -# # Skip adjustment if there are issues -# pass -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_rotate_labels.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__rotate_labels_v01.py b/tests/scitex/plt/ax/_style/test__rotate_labels_v01.py deleted file mode 100644 index 983892b4c..000000000 --- a/tests/scitex/plt/ax/_style/test__rotate_labels_v01.py +++ /dev/null @@ -1,274 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_rotate_labels_v01.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2024-10-27 13:24:32 (ywatanabe)" -# # /home/ywatanabe/proj/_scitex_repo_openhands/src/scitex/plt/ax/_rotate_labels.py -# -# """This script does XYZ.""" -# -# """Imports""" -# import numpy as np -# -# -# def rotate_labels( -# ax, -# x=45, -# y=45, -# x_ha=None, -# y_ha=None, -# x_va=None, -# y_va=None, -# auto_adjust=True, -# scientific_convention=True, -# ): -# """ -# Rotate x and y axis labels of a matplotlib Axes object with automatic positioning. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The Axes object to modify. -# x : float, optional -# Rotation angle for x-axis labels in degrees. Default is 45. -# y : float, optional -# Rotation angle for y-axis labels in degrees. Default is 45. -# x_ha : str, optional -# Horizontal alignment for x-axis labels. If None, automatically determined. -# y_ha : str, optional -# Horizontal alignment for y-axis labels. If None, automatically determined. -# x_va : str, optional -# Vertical alignment for x-axis labels. If None, automatically determined. -# y_va : str, optional -# Vertical alignment for y-axis labels. If None, automatically determined. -# auto_adjust : bool, optional -# Whether to automatically adjust alignment based on rotation angle. Default is True. -# scientific_convention : bool, optional -# Whether to follow scientific plotting conventions. Default is True. -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified Axes object. -# -# Example -# ------- -# fig, ax = plt.subplots() -# ax.plot([1, 2, 3], [1, 2, 3]) -# rotate_labels(ax) -# plt.show() -# -# Notes -# ----- -# Scientific conventions for label rotation: -# - X-axis labels: For angles 0-90°, use 'right' alignment; for 90-180°, use 'left' -# - Y-axis labels: For angles 0-90°, use 'center' alignment; adjust vertical as needed -# - Optimal readability maintained through automatic positioning -# """ -# # Get current tick positions -# xticks = ax.get_xticks() -# yticks = ax.get_yticks() -# -# # Set ticks explicitly -# ax.set_xticks(xticks) -# ax.set_yticks(yticks) -# -# # Auto-adjust alignment based on rotation angle and scientific conventions -# if auto_adjust: -# x_ha, x_va = _get_optimal_alignment("x", x, x_ha, x_va, scientific_convention) -# y_ha, y_va = _get_optimal_alignment("y", y, y_ha, y_va, scientific_convention) -# -# # Apply defaults if not auto-adjusting -# if x_ha is None: -# x_ha = "center" -# if y_ha is None: -# y_ha = "center" -# if x_va is None: -# x_va = "center" -# if y_va is None: -# y_va = "center" -# -# # Check if this axis is part of a shared x-axis configuration -# # If labels are already visible (bottom subplot or not shared), keep them visible -# # This preserves matplotlib's default sharex behavior -# x_labels_visible = ax.xaxis.get_tick_params()["labelbottom"] -# y_labels_visible = ax.yaxis.get_tick_params()["labelleft"] -# -# # Set labels with rotation and proper alignment -# # Only set labels if they're currently visible (respects sharex/sharey) -# if x_labels_visible: -# ax.set_xticklabels(ax.get_xticklabels(), rotation=x, ha=x_ha, va=x_va) -# if y_labels_visible: -# ax.set_yticklabels(ax.get_yticklabels(), rotation=y, ha=y_ha, va=y_va) -# -# # Auto-adjust subplot parameters for better layout if needed -# if auto_adjust and scientific_convention: -# _adjust_subplot_params(ax, x, y) -# -# return ax -# -# -# def _get_optimal_alignment(axis, angle, ha, va, scientific_convention): -# """ -# Determine optimal alignment based on rotation angle and scientific conventions. -# -# Parameters -# ---------- -# axis : str -# 'x' or 'y' axis -# angle : float -# Rotation angle in degrees -# ha : str or None -# Current horizontal alignment -# va : str or None -# Current vertical alignment -# scientific_convention : bool -# Whether to follow scientific conventions -# -# Returns -# ------- -# tuple -# (horizontal_alignment, vertical_alignment) -# """ -# # Normalize angle to 0-360 range -# angle = angle % 360 -# -# if axis == "x": -# if scientific_convention: -# # Scientific convention for x-axis labels -# if 0 <= angle <= 30: -# ha = ha or "center" -# va = va or "top" -# elif 30 < angle <= 60: -# ha = ha or "right" -# va = va or "top" -# elif 60 < angle <= 120: -# ha = ha or "right" -# va = va or "center" -# elif 120 < angle <= 150: -# ha = ha or "right" -# va = va or "bottom" -# elif 150 < angle <= 210: -# ha = ha or "center" -# va = va or "bottom" -# elif 210 < angle <= 240: -# ha = ha or "left" -# va = va or "bottom" -# elif 240 < angle <= 300: -# ha = ha or "left" -# va = va or "center" -# else: # 300-360 -# ha = ha or "left" -# va = va or "top" -# else: -# ha = ha or "center" -# va = va or "top" -# -# else: # y-axis -# if scientific_convention: -# # Scientific convention for y-axis labels -# if 0 <= angle <= 30: -# ha = ha or "right" -# va = va or "center" -# elif 30 < angle <= 60: -# ha = ha or "right" -# va = va or "bottom" -# elif 60 < angle <= 120: -# ha = ha or "center" -# va = va or "bottom" -# elif 120 < angle <= 150: -# ha = ha or "left" -# va = va or "bottom" -# elif 150 < angle <= 210: -# ha = ha or "left" -# va = va or "center" -# elif 210 < angle <= 240: -# ha = ha or "left" -# va = va or "top" -# elif 240 < angle <= 300: -# ha = ha or "center" -# va = va or "top" -# else: # 300-360 -# ha = ha or "right" -# va = va or "top" -# else: -# ha = ha or "center" -# va = va or "center" -# -# return ha, va -# -# -# def _adjust_subplot_params(ax, x_angle, y_angle): -# """ -# Automatically adjust subplot parameters to accommodate rotated labels. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The axes object -# x_angle : float -# X-axis rotation angle -# y_angle : float -# Y-axis rotation angle -# """ -# fig = ax.get_figure() -# -# # Check if figure is using a layout engine that is incompatible with subplots_adjust -# try: -# # For matplotlib >= 3.6 -# if hasattr(fig, "get_layout_engine"): -# layout_engine = fig.get_layout_engine() -# if layout_engine is not None: -# # If using constrained_layout or tight_layout, don't adjust -# return -# except AttributeError: -# pass -# -# # Check for constrained_layout (older matplotlib versions) -# try: -# if hasattr(fig, "get_constrained_layout"): -# if fig.get_constrained_layout(): -# # Constrained layout is active, don't adjust -# return -# except AttributeError: -# pass -# -# # Calculate required margins based on rotation angles -# x_margin_factor = abs(np.sin(np.radians(x_angle))) * 0.1 -# y_margin_factor = abs(np.sin(np.radians(y_angle))) * 0.15 -# -# # Get current subplot parameters -# try: -# subplotpars = fig.subplotpars -# current_bottom = subplotpars.bottom -# current_left = subplotpars.left -# -# # Adjust margins if they need to be increased -# new_bottom = max(current_bottom, 0.1 + x_margin_factor) -# new_left = max(current_left, 0.1 + y_margin_factor) -# -# # Only adjust if we're increasing the margins significantly -# if new_bottom > current_bottom + 0.05 or new_left > current_left + 0.05: -# # Suppress warning and try to adjust -# import warnings -# -# with warnings.catch_warnings(): -# warnings.simplefilter("ignore") -# fig.subplots_adjust(bottom=new_bottom, left=new_left) -# except Exception: -# # Skip adjustment if there are issues -# pass - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_rotate_labels_v01.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__sci_note.py b/tests/scitex/plt/ax/_style/test__sci_note.py deleted file mode 100644 index b70fd317d..000000000 --- a/tests/scitex/plt/ax/_style/test__sci_note.py +++ /dev/null @@ -1,450 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:45 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__sci_note.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__sci_note.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib.pyplot as plt -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import OOMFormatter, sci_note - - -class TestSciNote: - - @pytest.fixture - def setup_axes(self): - fig, ax = plt.subplots(figsize=(8, 6)) - ax.plot([0, 1000], [0, 2000]) - return fig, ax - - def test_oom_formatter_creation(self): - formatter = OOMFormatter(order=3, fformat="%.2f") - assert formatter.order == 3 - assert formatter.fformat == "%.2f" - - # def test_oom_formatter_order_setting(self): - # # Only test explicit order setting since auto requires an axis - # formatter = OOMFormatter(order=4) - # # Mock the parent method to avoid the call to super() - # formatter.orderOfMagnitude = None - # formatter._set_order_of_magnitude() - # assert formatter.orderOfMagnitude == 4 - - def test_oom_formatter_format_setting(self): - formatter = OOMFormatter(fformat="%.3f", mathText=True) - formatter._set_format() - assert formatter.format == r"$\mathdefault{%.3f}$" - - formatter_no_math = OOMFormatter(fformat="%.3f", mathText=False) - formatter_no_math._set_format() - assert formatter_no_math.format == "%.3f" - - def test_sci_note_x_axis(self, setup_axes): - _, ax = setup_axes - ax = sci_note(ax, x=True) - - assert isinstance(ax.xaxis.get_major_formatter(), OOMFormatter) - assert ax.xaxis.labelpad == -22 - - def test_sci_note_y_axis(self, setup_axes): - _, ax = setup_axes - ax = sci_note(ax, y=True) - - assert isinstance(ax.yaxis.get_major_formatter(), OOMFormatter) - assert ax.yaxis.labelpad == -20 - - def test_sci_note_both_axes(self, setup_axes): - _, ax = setup_axes - ax = sci_note(ax, x=True, y=True) - - assert isinstance(ax.xaxis.get_major_formatter(), OOMFormatter) - assert isinstance(ax.yaxis.get_major_formatter(), OOMFormatter) - - def test_sci_note_custom_order(self, setup_axes): - _, ax = setup_axes - ax = sci_note(ax, x=True, y=True, order_x=5, order_y=6) - - # Get formatters from the axes - xformatter = ax.xaxis.get_major_formatter() - yformatter = ax.yaxis.get_major_formatter() - - # Check order was set correctly - assert xformatter.order == 5 - assert yformatter.order == 6 - - def test_sci_note_custom_format(self, setup_axes): - _, ax = setup_axes - custom_format = "%.4f" - ax = sci_note(ax, x=True, y=True, fformat=custom_format) - - xformatter = ax.xaxis.get_major_formatter() - yformatter = ax.yaxis.get_major_formatter() - - assert xformatter.fformat == custom_format - assert yformatter.fformat == custom_format - - def test_sci_note_custom_padding(self, setup_axes): - _, ax = setup_axes - ax = sci_note(ax, x=True, y=True, pad_x=-10, pad_y=-15) - - assert ax.xaxis.labelpad == -10 - assert ax.yaxis.labelpad == -15 - - def test_savefig(self, setup_axes): - from scitex.io import save - - # Main test functionality - fig, ax = setup_axes - ax = sci_note(ax, x=True, y=True, fformat="%1.2f") - - # Saving - spath = f"{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file existence - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - # @check_figures_equal(extensions=["png"]) - # def test_sci_note_visual_output(self, fig_test, fig_ref): - # # Create identical data for both figures - # data_x = np.array([0, 10000]) - # data_y = np.array([0, 20000]) - - # # Test figure with sci_note - # ax_test = fig_test.subplots() - # ax_test.plot(data_x, data_y) - # # Set both x and y limits exactly to ensure consistency - # ax_test.set_xlim(0, 10000) - # ax_test.set_ylim(0, 20000) - # sci_note(ax_test, x=True, y=True) - - # # Reference figure with manually configured similar settings - # ax_ref = fig_ref.subplots() - # ax_ref.plot(data_x, data_y) - # # Set identical limits - # ax_ref.set_xlim(0, 10000) - # ax_ref.set_ylim(0, 20000) - # # Calculate the same orders of magnitude - # order_x = int( - # np.floor(np.log10(np.max(np.abs(ax_ref.get_xlim())) + 1e-5)) - # ) - # order_y = int( - # np.floor(np.log10(np.max(np.abs(ax_ref.get_ylim())) + 1e-5)) - # ) - # # Apply them manually - # ax_ref.xaxis.set_major_formatter( - # matplotlib.ticker.ScalarFormatter(useMathText=True) - # ) - # ax_ref.yaxis.set_major_formatter( - # matplotlib.ticker.ScalarFormatter(useMathText=True) - # ) - # ax_ref.ticklabel_format(style="sci", scilimits=(-3, 3), axis="both") - - # # Ensure matching appearance - # fig_test.tight_layout() - # fig_ref.tight_layout() - # ax_ref.ticklabel_format(style="sci", scilimits=(-3, 3), axis="both") - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_sci_note.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-03 11:58:58 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_sci_note.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_style/_sci_note.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# import numpy as np -# -# -# class OOMFormatter(matplotlib.ticker.ScalarFormatter): -# """Custom formatter for scientific notation with fixed order of magnitude. -# -# A matplotlib formatter that allows you to specify a fixed exponent for -# scientific notation, rather than letting matplotlib choose it automatically. -# Useful when you want consistent notation across multiple plots or specific -# exponent values. -# -# Parameters -# ---------- -# order : int or None, optional -# Fixed order of magnitude (exponent) to use. If None, calculated -# automatically. Default is None. -# fformat : str, optional -# Format string for the mantissa. Default is "%1.1f". -# offset : bool, optional -# Whether to use offset notation. Default is True. -# mathText : bool, optional -# Whether to use mathtext rendering. Default is True. -# -# Attributes -# ---------- -# order : int or None -# The fixed order of magnitude to use. -# fformat : str -# Format string for displaying numbers. -# -# Examples -# -------- -# >>> # Force all labels to use 10^3 notation -# >>> formatter = OOMFormatter(order=3, fformat="%1.2f") -# >>> ax.xaxis.set_major_formatter(formatter) -# -# >>> # Use 10^-6 for microvolts -# >>> formatter = OOMFormatter(order=-6, fformat="%1.1f") -# >>> ax.yaxis.set_major_formatter(formatter) -# -# See Also -# -------- -# matplotlib.ticker.ScalarFormatter : Base formatter class -# sci_note : Convenience function using this formatter -# """ -# -# def __init__(self, order=None, fformat="%1.1f", offset=True, mathText=True): -# self.order = order -# self.fformat = fformat -# matplotlib.ticker.ScalarFormatter.__init__( -# self, useOffset=offset, useMathText=mathText -# ) -# -# def _set_order_of_magnitude(self): -# if self.order is not None: -# self.orderOfMagnitude = self.order -# else: -# super()._set_order_of_magnitude() -# -# def _set_format(self, vmin=None, vmax=None): -# self.format = self.fformat -# if self._useMathText: -# self.format = r"$\mathdefault{%s}$" % self.format -# -# -# def sci_note( -# ax, -# fformat="%1.1f", -# x=False, -# y=False, -# scilimits=(-3, 3), -# order_x=None, -# order_y=None, -# pad_x=-22, -# pad_y=-20, -# ): -# """ -# Apply scientific notation to axis with optional manual order of magnitude. -# -# Parameters: -# ----------- -# ax : matplotlib Axes -# The axes to apply scientific notation to -# fformat : str -# Format string for tick labels -# x, y : bool -# Whether to apply to x or y axis -# scilimits : tuple -# Scientific notation limits -# order_x, order_y : int or None -# Manual order of magnitude (exponent). If None, calculated automatically -# pad_x, pad_y : int -# Padding for the axis labels -# """ -# if x: -# # Calculate order if not specified -# if order_x is None: -# order_x = np.floor(np.log10(np.max(np.abs(ax.get_xlim())) + 1e-5)) -# -# ax.xaxis.set_major_formatter(OOMFormatter(order=int(order_x), fformat=fformat)) -# ax.ticklabel_format(axis="x", style="sci", scilimits=scilimits) -# ax.xaxis.labelpad = pad_x -# shift_x = (ax.get_xlim()[0] - ax.get_xlim()[1]) * 0.01 -# ax.xaxis.get_offset_text().set_position((shift_x, 0)) -# -# if y: -# # Calculate order if not specified -# if order_y is None: -# order_y = np.floor(np.log10(np.max(np.abs(ax.get_ylim())) + 1e-5)) -# -# ax.yaxis.set_major_formatter(OOMFormatter(order=int(order_y), fformat=fformat)) -# ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits) -# ax.yaxis.labelpad = pad_y -# shift_y = (ax.get_ylim()[0] - ax.get_ylim()[1]) * 0.01 -# ax.yaxis.get_offset_text().set_position((0, shift_y)) -# -# return ax -# -# -# # import matplotlib -# # import numpy as np -# -# -# # class OOMFormatter(matplotlib.ticker.ScalarFormatter): -# # def __init__(self, order=0, fformat="%1.1f", offset=True, mathText=True): -# # self.order = order -# # self.fformat = fformat -# # matplotlib.ticker.ScalarFormatter.__init__( -# # self, useOffset=offset, useMathText=mathText -# # ) -# -# # def _set_order_of_magnitude(self): -# # self.orderOfMagnitude = self.order -# -# # def _set_format(self, vmin=None, vmax=None): -# # self.format = self.fformat -# # if self._useMathText: -# # self.format = r"$\mathdefault{%s}$" % self.format -# -# -# # def sci_note(ax, fformat="%1.1f", x=False, y=False, scilimits=(-3, 3)): -# # order_x = 0 -# # order_y = 0 -# -# # if x: -# # order_x = np.floor(np.log10(np.max(np.abs(ax.get_xlim())) + 1e-5)) -# # ax.xaxis.set_major_formatter( -# # OOMFormatter(order=int(order_x), fformat=fformat) -# # ) -# # ax.ticklabel_format(axis="x", style="sci", scilimits=scilimits) -# # ax.xaxis.labelpad = -22 -# # shift_x = (ax.get_xlim()[0] - ax.get_xlim()[1]) * 0.01 -# # ax.xaxis.get_offset_text().set_position((shift_x, 0)) -# -# # if y: -# # order_y = np.floor(np.log10(np.max(np.abs(ax.get_ylim())) + 1e-5)) -# # ax.yaxis.set_major_formatter( -# # OOMFormatter(order=int(order_y), fformat=fformat) -# # ) -# # ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits) -# # ax.yaxis.labelpad = -20 -# # shift_y = (ax.get_ylim()[0] - ax.get_ylim()[1]) * 0.01 -# # ax.yaxis.get_offset_text().set_position((0, shift_y)) -# -# # return ax -# -# -# # # class OOMFormatter(matplotlib.ticker.ScalarFormatter): -# # # def __init__(self, order=0, fformat="%1.1f", offset=True, mathText=True): -# # # self.order = order -# # # self.fformat = fformat -# # # matplotlib.ticker.ScalarFormatter.__init__( -# # # self, useOffset=offset, useMathText=mathText -# # # ) -# -# # # def _set_order_of_magnitude(self): -# # # self.orderOfMagnitude = self.order -# -# # # def _set_format(self, vmin=None, vmax=None): -# # # self.format = self.fformat -# # # if self._useMathText: -# # # self.format = r"$\mathdefault{%s}$" % self.format -# -# -# # # def sci_note(ax, fformat="%1.1f", x=False, y=False, scilimits=(-3, 3)): -# # # order_x = 0 -# # # order_y = 0 -# -# # # if x: -# # # order_x = np.floor(np.log10(np.max(np.abs(ax.get_xlim())) + 1e-5)) -# # # ax.xaxis.set_major_formatter( -# # # OOMFormatter(order=int(order_x), fformat=fformat) -# # # ) -# # # ax.ticklabel_format(axis="x", style="sci", scilimits=scilimits) -# -# # # if y: -# # # order_y = np.floor(np.log10(np.max(np.abs(ax.get_ylim()) + 1e-5))) -# # # ax.yaxis.set_major_formatter( -# # # OOMFormatter(order=int(order_y), fformat=fformat) -# # # ) -# # # ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits) -# -# # # return ax -# -# -# # # #!/usr/bin/env python3 -# -# -# # # import matplotlib -# -# -# # # class OOMFormatter(matplotlib.ticker.ScalarFormatter): -# # # # https://stackoverflow.com/questions/42656139/set-scientific-notation-with-fixed-exponent-and-significant-digits-for-multiple -# # # # def __init__(self, order=0, fformat="%1.1f", offset=True, mathText=True): -# # # def __init__(self, order=0, fformat="%1.0d", offset=True, mathText=True): -# # # self.oom = order -# # # self.fformat = fformat -# # # matplotlib.ticker.ScalarFormatter.__init__( -# # # self, useOffset=offset, useMathText=mathText -# # # ) -# -# # # def _set_order_of_magnitude(self): -# # # self.orderOfMagnitude = self.oom -# -# # # def _set_format(self, vmin=None, vmax=None): -# # # self.format = self.fformat -# # # if self._useMathText: -# # # self.format = r"$\mathdefault{%s}$" % self.format -# -# -# # # def sci_note( -# # # ax, -# # # order, -# # # fformat="%1.0d", -# # # x=False, -# # # y=False, -# # # scilimits=(-3, 3), -# # # ): -# # # """ -# # # Change the expression of the x- or y-axis to the scientific notation like *10^3 -# # # , where 3 is the first argument, order. -# -# # # Example: -# # # order = 4 # 10^4 -# # # ax = sci_note( -# # # ax, -# # # order, -# # # fformat="%1.0d", -# # # x=True, -# # # y=False, -# # # scilimits=(-3, 3), -# # # """ -# -# # # if x == True: -# # # ax.xaxis.set_major_formatter( -# # # OOMFormatter(order=order, fformat=fformat) -# # # ) -# # # ax.ticklabel_format(axis="x", style="sci", scilimits=scilimits) -# # # if y == True: -# # # ax.yaxis.set_major_formatter( -# # # OOMFormatter(order=order, fformat=fformat) -# # # ) -# # # ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits) -# -# # # return ax -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_sci_note.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__set_log_scale.py b/tests/scitex/plt/ax/_style/test__set_log_scale.py deleted file mode 100644 index 47fae56f3..000000000 --- a/tests/scitex/plt/ax/_style/test__set_log_scale.py +++ /dev/null @@ -1,873 +0,0 @@ -#!/usr/bin/env python3 -""" -Comprehensive tests for scitex.plt.ax._style._set_log_scale module. - -This module tests logarithmic scale configuration utilities including: -- set_log_scale: Configure log scales with advanced formatting -- smart_log_limits: Automatically determine optimal log scale limits -- add_log_scale_indicator: Add visual indicators for log scales -""" - -import warnings -from unittest.mock import MagicMock, Mock, patch - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -import scitex - - -# Test fixtures -@pytest.fixture -def fig_ax(): - """Create a figure and axis for testing.""" - fig, ax = plt.subplots(figsize=(8, 6)) - yield fig, ax - plt.close(fig) - - -@pytest.fixture -def sample_data(): - """Generate sample data spanning multiple orders of magnitude.""" - return { - "linear": np.linspace(1, 100, 50), - "exponential": np.logspace(0, 4, 50), # 1 to 10000 - "small_range": np.logspace(-2, 0, 20), # 0.01 to 1 - "large_range": np.logspace(0, 6, 50), # 1 to 1000000 - "negative": -np.logspace(0, 3, 30), # -1 to -1000 - "mixed_sign": np.concatenate([np.logspace(0, 2, 25), -np.logspace(0, 2, 25)]), - } - - -class TestSetLogScale: - """Test set_log_scale function - main log scale configuration.""" - - def test_set_log_scale_x_axis(self, fig_ax): - """Test setting log scale on x-axis.""" - fig, ax = fig_ax - - result = scitex.plt.ax.set_log_scale(ax, axis="x") - - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "linear" # Should remain linear - assert result == ax # Should return the axis - - def test_set_log_scale_y_axis(self, fig_ax): - """Test setting log scale on y-axis.""" - fig, ax = fig_ax - - result = scitex.plt.ax.set_log_scale(ax, axis="y") - - assert ax.get_xscale() == "linear" # Should remain linear - assert ax.get_yscale() == "log" - assert result == ax - - def test_set_log_scale_both_axes(self, fig_ax): - """Test setting log scale on both axes.""" - fig, ax = fig_ax - - result = scitex.plt.ax.set_log_scale(ax, axis="both") - - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "log" - assert result == ax - - def test_set_log_scale_with_base(self, fig_ax): - """Test setting log scale with custom base.""" - fig, ax = fig_ax - - # Test base 2 - scitex.plt.ax.set_log_scale(ax, axis="x", base=2) - assert ax.get_xscale() == "log" - - # Test base e (natural log) - scitex.plt.ax.set_log_scale(ax, axis="y", base=np.e) - assert ax.get_yscale() == "log" - - def test_set_log_scale_with_custom_limits(self, fig_ax): - """Test setting log scale and then setting custom limits.""" - fig, ax = fig_ax - - # set_log_scale doesn't have limits parameter, so set them separately - limits = (1, 1000) - scitex.plt.ax.set_log_scale(ax, axis="x") - ax.set_xlim(limits) - - assert ax.get_xscale() == "log" - x_limits = ax.get_xlim() - assert x_limits[0] == pytest.approx(limits[0], rel=1e-2) - assert x_limits[1] == pytest.approx(limits[1], rel=1e-2) - - def test_set_log_scale_with_minor_ticks(self, fig_ax): - """Test setting log scale with minor tick configuration.""" - fig, ax = fig_ax - - scitex.plt.ax.set_log_scale(ax, axis="x", show_minor_ticks=True) - - assert ax.get_xscale() == "log" - # Check that minor ticks are enabled - assert ax.xaxis.get_minor_locator() is not None - - def test_set_log_scale_without_minor_ticks(self, fig_ax): - """Test setting log scale without minor ticks.""" - fig, ax = fig_ax - - scitex.plt.ax.set_log_scale(ax, axis="y", show_minor_ticks=False) - - assert ax.get_yscale() == "log" - # Minor ticks should be minimal or disabled - - def test_set_log_scale_with_grid(self, fig_ax): - """Test setting log scale with grid configuration.""" - fig, ax = fig_ax - - scitex.plt.ax.set_log_scale(ax, axis="both", grid=True) - - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "log" - # Grid should be enabled (matplotlib's grid state can be complex to check) - - def test_set_log_scale_with_scientific_notation(self, fig_ax): - """Test setting log scale with scientific notation formatting.""" - fig, ax = fig_ax - - scitex.plt.ax.set_log_scale(ax, axis="x", scientific_notation=True) - - assert ax.get_xscale() == "log" - # Check that formatter is set for scientific notation - formatter = ax.xaxis.get_major_formatter() - assert formatter is not None - - def test_set_log_scale_invalid_axis(self, fig_ax): - """Test error handling for invalid axis specification.""" - fig, ax = fig_ax - - # Test with invalid axis - may not raise exception, just ignore invalid axis - try: - result = scitex.plt.ax.set_log_scale(ax, axis="invalid") - # If it doesn't raise, it should return the axis unchanged - assert result == ax - except (ValueError, KeyError): - # If it does raise, that's also acceptable - pass - - def test_set_log_scale_with_negative_limits(self, fig_ax): - """Test handling of negative limits (logarithmic scale doesn't support negatives).""" - fig, ax = fig_ax - - # Set log scale first, then try negative limits - scitex.plt.ax.set_log_scale(ax, axis="x") - - # Setting negative limits on log scale should raise warning - with pytest.warns(UserWarning): - ax.set_xlim(-10, 100) - - def test_set_log_scale_with_zero_limits(self, fig_ax): - """Test handling of zero in limits (logarithmic scale doesn't support zero).""" - fig, ax = fig_ax - - # Set log scale first, then try zero limits - scitex.plt.ax.set_log_scale(ax, axis="y") - - # Setting limits with zero on log scale should raise warning - with pytest.warns(UserWarning): - ax.set_ylim(0, 100) - - -class TestSmartLogLimits: - """Test smart_log_limits function - automatic limit determination.""" - - def test_smart_log_limits_exponential_data(self, sample_data): - """Test smart limits calculation for exponential data.""" - data = sample_data["exponential"] # 1 to 10000 - - limits = scitex.plt.ax.smart_log_limits(data) - - assert len(limits) == 2 - assert limits[0] > 0 # Lower limit should be positive - assert limits[1] > limits[0] # Upper > lower - assert limits[0] <= data.min() - assert limits[1] >= data.max() - - def test_smart_log_limits_small_range_data(self, sample_data): - """Test smart limits for small range data.""" - data = sample_data["small_range"] # 0.01 to 1 - - limits = scitex.plt.ax.smart_log_limits(data) - - assert limits[0] > 0 - assert limits[1] > limits[0] - assert limits[0] <= data.min() - assert limits[1] >= data.max() - - def test_smart_log_limits_large_range_data(self, sample_data): - """Test smart limits for large range data.""" - data = sample_data["large_range"] # 1 to 1000000 - - limits = scitex.plt.ax.smart_log_limits(data) - - assert limits[0] > 0 - assert limits[1] > limits[0] - # Should span multiple orders of magnitude - assert np.log10(limits[1] / limits[0]) >= 5 # At least 5 orders - - def test_smart_log_limits_with_padding(self): - """Test smart limits with custom padding.""" - data = np.logspace(1, 3, 50) # 10 to 1000 - - # Test with larger padding - limits_padded = scitex.plt.ax.smart_log_limits(data, padding_factor=2.0) - limits_normal = scitex.plt.ax.smart_log_limits(data, padding_factor=1.1) - - # Padded limits should be wider - assert limits_padded[0] < limits_normal[0] - assert limits_padded[1] > limits_normal[1] - - def test_smart_log_limits_single_value(self): - """Test smart limits with single value (edge case).""" - data = np.array([100.0]) - - limits = scitex.plt.ax.smart_log_limits(data) - - assert limits[0] > 0 - assert limits[1] > limits[0] - # Should create reasonable range around single value - assert limits[0] < 100 < limits[1] - - def test_smart_log_limits_with_zeros(self): - """Test handling of data containing zeros. - - The function filters out non-positive values, so zeros are ignored. - """ - data = np.array([0, 1, 10, 100]) - - # Function filters out zeros and returns limits for positive data - limits = scitex.plt.ax.smart_log_limits(data) - - assert len(limits) == 2 - assert limits[0] > 0 - assert limits[1] > limits[0] - - def test_smart_log_limits_with_negatives(self, sample_data): - """Test handling of negative data. - - The function filters out non-positive values. With only negative data, - no positive values remain, so it returns default limits (1, base). - """ - data = sample_data["negative"] - - # All negative data means no positive values, returns default limits - limits = scitex.plt.ax.smart_log_limits(data) - - assert len(limits) == 2 - assert limits[0] > 0 - assert limits[1] > limits[0] - - def test_smart_log_limits_empty_data(self): - """Test handling of empty data array. - - Empty array has no positive values, so returns default limits (1, base). - """ - data = np.array([]) - - # Empty data returns default limits - limits = scitex.plt.ax.smart_log_limits(data) - - assert len(limits) == 2 - assert limits[0] > 0 - assert limits[1] > limits[0] - - def test_smart_log_limits_with_axis_specification(self, sample_data): - """Test smart limits with axis specification.""" - data = sample_data["exponential"] - - # Test with different axis names (for reference/documentation) - x_limits = scitex.plt.ax.smart_log_limits(data, axis="x") - y_limits = scitex.plt.ax.smart_log_limits(data, axis="y") - - assert len(x_limits) == 2 - assert len(y_limits) == 2 - assert all(lim > 0 for lim in x_limits + y_limits) - # Both should be the same since they're based on the same data - assert x_limits == y_limits - - -class TestAddLogScaleIndicator: - """Test add_log_scale_indicator function - visual log scale indicators.""" - - def test_add_log_scale_indicator_basic(self, fig_ax): - """Test adding basic log scale indicator.""" - fig, ax = fig_ax - - # Set log scale first - ax.set_xscale("log") - - # add_log_scale_indicator returns None (adds text to axis) - result = scitex.plt.ax.add_log_scale_indicator(ax, axis="x") - - # Returns None but adds text annotation to axes - # Check that text has been added - assert len(ax.texts) > 0 or result is None - - def test_add_log_scale_indicator_custom_styling(self, fig_ax): - """Test adding log scale indicator with custom styling.""" - fig, ax = fig_ax - - ax.set_yscale("log") - - # Test with custom styling parameters that exist in the function - scitex.plt.ax.add_log_scale_indicator( - ax, axis="y", fontsize=14, color="blue", alpha=0.8 - ) - - # Function should complete without error - - def test_add_log_scale_indicator_position(self, fig_ax): - """Test log scale indicator positioning.""" - fig, ax = fig_ax - - ax.set_xscale("log") - - # Test different positions from the function signature - positions = ["auto", "top-left", "top-right", "bottom-left", "bottom-right"] - for position in positions: - try: - scitex.plt.ax.add_log_scale_indicator(ax, axis="x", position=position) - except (ValueError, KeyError): - # Some positions might not be valid - pass - - def test_add_log_scale_indicator_base_display(self, fig_ax): - """Test log scale indicator with different bases.""" - fig, ax = fig_ax - - ax.set_yscale("log") - - # Test with different bases - for base in [2, np.e, 10]: - scitex.plt.ax.add_log_scale_indicator(ax, axis="y", base=base) - - # Indicator should be added with custom styling - - def test_add_log_scale_indicator_both_axes(self, fig_ax): - """Test adding indicators for both axes.""" - fig, ax = fig_ax - - ax.set_xscale("log") - ax.set_yscale("log") - - scitex.plt.ax.add_log_scale_indicator(ax, axis="both") - - # Should add indicators for both axes - - def test_add_log_scale_indicator_linear_axis_warning(self, fig_ax): - """Test adding indicator to linear axis. - - The implementation doesn't warn for linear axes, it just adds the indicator. - """ - fig, ax = fig_ax - - # Keep axis linear (default) - # Function adds text even if axis is linear - scitex.plt.ax.add_log_scale_indicator(ax, axis="x") - - # Check that text was added - assert len(ax.texts) > 0 - - def test_add_log_scale_indicator_with_base(self, fig_ax): - """Test log scale indicator showing custom base.""" - fig, ax = fig_ax - - ax.set_xscale("log", base=2) - - scitex.plt.ax.add_log_scale_indicator(ax, axis="x", base=2) - - # Should indicate base 2 logarithm - - -class TestLogScaleIntegration: - """Test integration between log scale functions.""" - - def test_complete_log_scale_workflow(self, fig_ax, sample_data): - """Test complete workflow: smart limits + set scale + indicator.""" - fig, ax = fig_ax - data = sample_data["exponential"] - - # 1. Calculate smart limits - limits = scitex.plt.ax.smart_log_limits(data) - - # 2. Set log scale (limits must be set separately) - scitex.plt.ax.set_log_scale(ax, axis="both") - ax.set_xlim(limits) - ax.set_ylim(limits) - - # 3. Add visual indicator - scitex.plt.ax.add_log_scale_indicator(ax, axis="both") - - # Verify everything worked together - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "log" - - x_limits = ax.get_xlim() - assert x_limits[0] == pytest.approx(limits[0], rel=1e-1) - assert x_limits[1] == pytest.approx(limits[1], rel=1e-1) - - def test_log_scale_with_plot_data(self, fig_ax, sample_data): - """Test log scale functions with actual plotted data.""" - fig, ax = fig_ax - x_data = sample_data["exponential"] - y_data = sample_data["large_range"] - - # Plot data - ax.plot(x_data, y_data, "o-") - - # Apply log scaling - scitex.plt.ax.set_log_scale(ax, axis="both", show_minor_ticks=True, grid=True) - - # Add indicators - scitex.plt.ax.add_log_scale_indicator(ax, axis="both") - - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "log" - - def test_mixed_scale_configuration(self, fig_ax, sample_data): - """Test mixed linear/log scale configuration.""" - fig, ax = fig_ax - - # Linear x, log y - scitex.plt.ax.set_log_scale(ax, axis="y") - scitex.plt.ax.add_log_scale_indicator(ax, axis="y") - - assert ax.get_xscale() == "linear" - assert ax.get_yscale() == "log" - - -class TestLogScaleEdgeCases: - """Test edge cases and error conditions.""" - - def test_log_scale_with_invalid_data_types(self, fig_ax): - """Test error handling with invalid data types.""" - fig, ax = fig_ax - - with pytest.raises((TypeError, ValueError)): - scitex.plt.ax.smart_log_limits("invalid_data") - - with pytest.raises((TypeError, AttributeError)): - scitex.plt.ax.set_log_scale("not_an_axis", axis="x") - - def test_log_scale_with_none_values(self, fig_ax): - """Test handling of None values in parameters.""" - fig, ax = fig_ax - - # These should either work with defaults or raise appropriate errors - try: - scitex.plt.ax.set_log_scale(ax, axis=None) - except (ValueError, TypeError): - pass # Expected for invalid axis - - try: - scitex.plt.ax.add_log_scale_indicator(ax, axis=None) - except (ValueError, TypeError): - pass # Expected for invalid axis - - def test_log_scale_very_small_numbers(self): - """Test log scale with very small positive numbers.""" - data = np.array([1e-10, 1e-8, 1e-6, 1e-4]) - - limits = scitex.plt.ax.smart_log_limits(data) - - assert limits[0] > 0 - assert limits[1] > limits[0] - assert limits[0] <= data.min() - assert limits[1] >= data.max() - - def test_log_scale_very_large_numbers(self): - """Test log scale with very large numbers.""" - data = np.array([1e6, 1e8, 1e10, 1e12]) - - limits = scitex.plt.ax.smart_log_limits(data) - - assert limits[0] > 0 - assert limits[1] > limits[0] - assert limits[0] <= data.min() - assert limits[1] >= data.max() - - @pytest.mark.parametrize("base", [2, np.e, 10, 5]) - def test_log_scale_different_bases(self, fig_ax, base): - """Test log scale with different logarithmic bases.""" - fig, ax = fig_ax - - scitex.plt.ax.set_log_scale(ax, axis="x", base=base) - - assert ax.get_xscale() == "log" - - def test_log_scale_persistence_after_operations(self, fig_ax, sample_data): - """Test that log scale persists after various plot operations.""" - fig, ax = fig_ax - data = sample_data["exponential"] - - # Set log scale - scitex.plt.ax.set_log_scale(ax, axis="both") - - # Perform various operations - ax.plot(data, data) - ax.set_title("Test Plot") - ax.grid(True) - - # Log scale should persist - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "log" - - -# Run specific test classes for debugging - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_log_scale.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2025-06-04 11:10:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_set_log_scale.py -# -# """ -# Functionality: -# Set logarithmic scale with proper minor ticks for scientific plots -# Input: -# Matplotlib axes object and scale parameters -# Output: -# Axes with properly configured logarithmic scale -# Prerequisites: -# matplotlib, numpy -# """ -# -# import numpy as np -# import matplotlib.pyplot as plt -# from matplotlib.ticker import LogLocator, LogFormatter, NullFormatter -# from typing import Union, Optional, List -# -# -# def set_log_scale( -# ax, -# axis: str = "both", -# base: Union[int, float] = 10, -# show_minor_ticks: bool = True, -# minor_tick_length: float = 2.0, -# major_tick_length: float = 4.0, -# minor_tick_width: float = 0.5, -# major_tick_width: float = 0.8, -# grid: bool = False, -# minor_grid: bool = False, -# grid_alpha: float = 0.3, -# minor_grid_alpha: float = 0.15, -# format_minor_labels: bool = False, -# scientific_notation: bool = True, -# ) -> object: -# """ -# Set logarithmic scale with comprehensive minor tick support. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The axes object to modify -# axis : str, optional -# Which axis to set: 'x', 'y', or 'both', by default 'both' -# base : Union[int, float], optional -# Logarithmic base, by default 10 -# show_minor_ticks : bool, optional -# Whether to show minor ticks, by default True -# minor_tick_length : float, optional -# Length of minor ticks in points, by default 2.0 -# major_tick_length : float, optional -# Length of major ticks in points, by default 4.0 -# minor_tick_width : float, optional -# Width of minor ticks in points, by default 0.5 -# major_tick_width : float, optional -# Width of major ticks in points, by default 0.8 -# grid : bool, optional -# Whether to show major grid lines, by default False -# minor_grid : bool, optional -# Whether to show minor grid lines, by default False -# grid_alpha : float, optional -# Alpha for major grid lines, by default 0.3 -# minor_grid_alpha : float, optional -# Alpha for minor grid lines, by default 0.15 -# format_minor_labels : bool, optional -# Whether to show labels on minor ticks, by default False -# scientific_notation : bool, optional -# Whether to use scientific notation for labels, by default True -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified axes object -# -# Examples -# -------- -# >>> fig, ax = plt.subplots() -# >>> ax.semilogy([1, 10, 100, 1000], [1, 2, 3, 4]) -# >>> set_log_scale(ax, axis='y', show_minor_ticks=True, grid=True) -# """ -# -# if axis in ["x", "both"]: -# _configure_log_axis( -# ax, -# "x", -# base, -# show_minor_ticks, -# minor_tick_length, -# major_tick_length, -# minor_tick_width, -# major_tick_width, -# grid, -# minor_grid, -# grid_alpha, -# minor_grid_alpha, -# format_minor_labels, -# scientific_notation, -# ) -# -# if axis in ["y", "both"]: -# _configure_log_axis( -# ax, -# "y", -# base, -# show_minor_ticks, -# minor_tick_length, -# major_tick_length, -# minor_tick_width, -# major_tick_width, -# grid, -# minor_grid, -# grid_alpha, -# minor_grid_alpha, -# format_minor_labels, -# scientific_notation, -# ) -# -# return ax -# -# -# def _configure_log_axis( -# ax, -# axis_name: str, -# base: Union[int, float], -# show_minor_ticks: bool, -# minor_tick_length: float, -# major_tick_length: float, -# minor_tick_width: float, -# major_tick_width: float, -# grid: bool, -# minor_grid: bool, -# grid_alpha: float, -# minor_grid_alpha: float, -# format_minor_labels: bool, -# scientific_notation: bool, -# ) -> None: -# """Configure a single axis for logarithmic scale.""" -# -# # Set the logarithmic scale -# if axis_name == "x": -# ax.set_xscale("log", base=base) -# axis_obj = ax.xaxis -# tick_params_kwargs = {"axis": "x"} -# else: # y-axis -# ax.set_yscale("log", base=base) -# axis_obj = ax.yaxis -# tick_params_kwargs = {"axis": "y"} -# -# # Configure major ticks -# major_locator = LogLocator(base=base, numticks=12) -# axis_obj.set_major_locator(major_locator) -# -# # Configure major tick formatting -# if scientific_notation: -# major_formatter = LogFormatter(base=base, labelOnlyBase=False) -# else: -# major_formatter = LogFormatter(base=base, labelOnlyBase=True) -# axis_obj.set_major_formatter(major_formatter) -# -# # Configure minor ticks -# if show_minor_ticks: -# # Create minor tick positions -# minor_locator = LogLocator(base=base, subs="all", numticks=100) -# axis_obj.set_minor_locator(minor_locator) -# -# # Format minor tick labels -# if format_minor_labels: -# minor_formatter = LogFormatter(base=base, labelOnlyBase=False) -# else: -# minor_formatter = NullFormatter() # No labels on minor ticks -# axis_obj.set_minor_formatter(minor_formatter) -# -# # Set minor tick appearance -# ax.tick_params( -# which="minor", -# length=minor_tick_length, -# width=minor_tick_width, -# **tick_params_kwargs, -# ) -# -# # Set major tick appearance -# ax.tick_params( -# which="major", -# length=major_tick_length, -# width=major_tick_width, -# **tick_params_kwargs, -# ) -# -# # Configure grid -# if grid or minor_grid: -# ax.grid(True, which="major", alpha=grid_alpha if grid else 0) -# if minor_grid and show_minor_ticks: -# ax.grid(True, which="minor", alpha=minor_grid_alpha) -# -# -# def smart_log_limits( -# data: Union[List, np.ndarray], -# axis: str = "y", -# base: Union[int, float] = 10, -# padding_factor: float = 0.1, -# min_decades: int = 1, -# ) -> tuple: -# """ -# Calculate smart logarithmic axis limits based on data. -# -# Parameters -# ---------- -# data : Union[List, np.ndarray] -# Data values to calculate limits from -# axis : str, optional -# Axis name for reference, by default 'y' -# base : Union[int, float], optional -# Logarithmic base, by default 10 -# padding_factor : float, optional -# Padding as fraction of data range, by default 0.1 -# min_decades : int, optional -# Minimum number of decades to show, by default 1 -# -# Returns -# ------- -# tuple -# (lower_limit, upper_limit) -# -# Examples -# -------- -# >>> smart_log_limits([1, 10, 100, 1000]) -# (0.1, 10000.0) -# """ -# data_array = np.array(data) -# positive_data = data_array[data_array > 0] -# -# if len(positive_data) == 0: -# return 1, base**min_decades -# -# data_min = np.min(positive_data) -# data_max = np.max(positive_data) -# -# # Calculate log range -# log_min = np.log(data_min) / np.log(base) -# log_max = np.log(data_max) / np.log(base) -# log_range = log_max - log_min -# -# # Ensure minimum range -# if log_range < min_decades: -# log_center = (log_min + log_max) / 2 -# log_min = log_center - min_decades / 2 -# log_max = log_center + min_decades / 2 -# log_range = min_decades -# -# # Add padding -# padding = log_range * padding_factor -# log_min_padded = log_min - padding -# log_max_padded = log_max + padding -# -# # Convert back to linear scale -# lower_limit = base**log_min_padded -# upper_limit = base**log_max_padded -# -# return lower_limit, upper_limit -# -# -# def add_log_scale_indicator( -# ax, -# axis: str = "y", -# base: Union[int, float] = 10, -# position: str = "auto", -# fontsize: Union[str, int] = "small", -# color: str = "gray", -# alpha: float = 0.7, -# ) -> None: -# """ -# Add a log scale indicator to the plot. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The axes object -# axis : str, optional -# Which axis has log scale, by default 'y' -# base : Union[int, float], optional -# Logarithmic base, by default 10 -# position : str, optional -# Position of indicator: 'auto', 'top-left', 'top-right', 'bottom-left', 'bottom-right', by default 'auto' -# fontsize : Union[str, int], optional -# Font size for indicator, by default 'small' -# color : str, optional -# Color of indicator text, by default 'gray' -# alpha : float, optional -# Alpha transparency, by default 0.7 -# -# Examples -# -------- -# >>> add_log_scale_indicator(ax, axis='y', base=10) -# """ -# # Determine position -# if position == "auto": -# if axis == "y": -# position = "top-left" -# else: -# position = "bottom-right" -# -# # Position mapping -# positions = { -# "top-left": (0.05, 0.95), -# "top-right": (0.95, 0.95), -# "bottom-left": (0.05, 0.05), -# "bottom-right": (0.95, 0.05), -# } -# -# x_pos, y_pos = positions.get(position, (0.05, 0.95)) -# -# # Create indicator text -# if base == 10: -# indicator_text = f"Log₁₀ scale ({axis}-axis)" -# else: -# indicator_text = f"Log_{{{base}}} scale ({axis}-axis)" -# -# # Add text -# ax.text( -# x_pos, -# y_pos, -# indicator_text, -# transform=ax.transAxes, -# fontsize=fontsize, -# color=color, -# alpha=alpha, -# bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), -# ) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_log_scale.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__set_meta.py b/tests/scitex/plt/ax/_style/test__set_meta.py deleted file mode 100644 index 6b20c97b4..000000000 --- a/tests/scitex/plt/ax/_style/test__set_meta.py +++ /dev/null @@ -1,776 +0,0 @@ -#!/usr/bin/env python3 -"""Tests for scitex.plt.ax._style._set_meta module. - -This module provides comprehensive tests for scientific metadata management -for figures with YAML export functionality. -""" - -import datetime -import os -import tempfile - -import pytest -import yaml - -pytest.importorskip("zarr") -from unittest.mock import MagicMock, patch - -import matplotlib.pyplot as plt - -from scitex.plt.ax._style import export_metadata_yaml, set_figure_meta, set_meta - - -class TestSetMeta: - """Test set_meta function.""" - - @pytest.fixture - def setup_figure(self): - """Create a test figure and axis.""" - fig, ax = plt.subplots() - yield fig, ax - plt.close(fig) - - def test_set_meta_basic_caption(self, setup_figure): - """Test setting basic caption metadata.""" - fig, ax = setup_figure - caption = "Test figure showing example data." - - result = set_meta(ax, caption=caption) - - assert result == ax - assert hasattr(fig, "_scitex_metadata") - assert ax in fig._scitex_metadata - assert fig._scitex_metadata[ax]["caption"] == caption - - def test_set_meta_all_parameters(self, setup_figure): - """Test setting all metadata parameters.""" - fig, ax = setup_figure - - metadata_params = { - "caption": "Comprehensive test figure.", - "methods": "Test methodology using synthetic data.", - "stats": "Statistical analysis with p < 0.05.", - "keywords": ["test", "synthetic", "example"], - "experimental_details": { - "n_samples": 100, - "temperature": 25, - "duration": 300, - }, - "journal_style": "nature", - "significance": "Demonstrates metadata functionality.", - } - - result = set_meta(ax, **metadata_params) - - stored_meta = fig._scitex_metadata[ax] - for key, value in metadata_params.items(): - if key == "stats": - assert stored_meta["statistical_analysis"] == value - else: - assert stored_meta[key] == value - - def test_set_meta_keywords_conversion(self, setup_figure): - """Test that single keyword is converted to list.""" - fig, ax = setup_figure - - # Single keyword as string - set_meta(ax, keywords="electrophysiology") - assert fig._scitex_metadata[ax]["keywords"] == ["electrophysiology"] - - # Multiple keywords as list - set_meta(ax, keywords=["neural", "recording"]) - assert fig._scitex_metadata[ax]["keywords"] == ["neural", "recording"] - - def test_set_meta_automatic_metadata(self, setup_figure): - """Test automatic metadata addition.""" - fig, ax = setup_figure - - set_meta(ax, caption="Test") - - stored_meta = fig._scitex_metadata[ax] - assert "created_timestamp" in stored_meta - assert "scitex_version" in stored_meta - - # Verify timestamp format - timestamp = stored_meta["created_timestamp"] - datetime.datetime.fromisoformat(timestamp) # Should not raise - - def test_set_meta_additional_kwargs(self, setup_figure): - """Test additional metadata through kwargs.""" - fig, ax = setup_figure - - set_meta(ax, caption="Test", custom_field="custom_value", another_field=42) - - stored_meta = fig._scitex_metadata[ax] - assert stored_meta["custom_field"] == "custom_value" - assert stored_meta["another_field"] == 42 - - def test_set_meta_yaml_structure(self, setup_figure): - """Test YAML metadata structure.""" - fig, ax = setup_figure - - set_meta(ax, caption="Test", methods="Test method") - - assert hasattr(fig, "_scitex_yaml_metadata") - assert ax in fig._scitex_yaml_metadata - assert fig._scitex_yaml_metadata[ax] == fig._scitex_metadata[ax] - - def test_set_meta_backward_compatibility(self, setup_figure): - """Test backward compatibility with caption storage.""" - fig, ax = setup_figure - caption = "Backward compatible caption" - - set_meta(ax, caption=caption) - - assert hasattr(fig, "_scitex_captions") - assert fig._scitex_captions[ax] == caption - - def test_set_meta_none_values_ignored(self, setup_figure): - """Test that None values are not stored.""" - fig, ax = setup_figure - - set_meta(ax, caption="Test", methods=None, stats=None) - - stored_meta = fig._scitex_metadata[ax] - assert "caption" in stored_meta - assert "methods" not in stored_meta - assert "statistical_analysis" not in stored_meta - - def test_set_meta_multiple_axes(self, setup_figure): - """Test metadata on multiple axes.""" - fig, (ax1, ax2) = plt.subplots(1, 2) - - set_meta(ax1, caption="First panel") - set_meta(ax2, caption="Second panel") - - assert len(fig._scitex_metadata) == 2 - assert fig._scitex_metadata[ax1]["caption"] == "First panel" - assert fig._scitex_metadata[ax2]["caption"] == "Second panel" - - plt.close(fig) - - -class TestSetFigureMeta: - """Test set_figure_meta function.""" - - @pytest.fixture - def setup_figure(self): - """Create a test figure and axis.""" - fig, ax = plt.subplots() - yield fig, ax - plt.close(fig) - - def test_set_figure_meta_basic(self, setup_figure): - """Test basic figure-level metadata.""" - fig, ax = setup_figure - - result = set_figure_meta( - ax, caption="Main figure caption", significance="Important findings" - ) - - assert result == ax - assert hasattr(fig, "_scitex_figure_metadata") - assert fig._scitex_figure_metadata["main_caption"] == "Main figure caption" - assert fig._scitex_figure_metadata["significance"] == "Important findings" - - def test_set_figure_meta_all_parameters(self, setup_figure): - """Test all figure metadata parameters.""" - fig, ax = setup_figure - - metadata = { - "caption": "Comprehensive analysis", - "methods": "Overall methodology", - "stats": "Statistical approach", - "significance": "Key findings", - "funding": "NIH grant R01-12345", - "conflicts": "No conflicts", - "data_availability": "Data at doi:10.5061/example", - } - - set_figure_meta(ax, **metadata) - - fig_meta = fig._scitex_figure_metadata - assert fig_meta["main_caption"] == metadata["caption"] - assert fig_meta["overall_methods"] == metadata["methods"] - assert fig_meta["overall_statistics"] == metadata["stats"] - assert fig_meta["significance"] == metadata["significance"] - assert fig_meta["funding"] == metadata["funding"] - assert fig_meta["conflicts_of_interest"] == metadata["conflicts"] - assert fig_meta["data_availability"] == metadata["data_availability"] - - def test_set_figure_meta_timestamp(self, setup_figure): - """Test automatic timestamp addition.""" - fig, ax = setup_figure - - set_figure_meta(ax, caption="Test") - - assert "created_timestamp" in fig._scitex_figure_metadata - timestamp = fig._scitex_figure_metadata["created_timestamp"] - datetime.datetime.fromisoformat(timestamp) # Should not raise - - def test_set_figure_meta_backward_compatibility(self, setup_figure): - """Test backward compatibility for main caption.""" - fig, ax = setup_figure - caption = "Main caption for compatibility" - - set_figure_meta(ax, caption=caption) - - assert hasattr(fig, "_scitex_main_caption") - assert fig._scitex_main_caption == caption - - def test_set_figure_meta_additional_fields(self, setup_figure): - """Test additional metadata fields through kwargs.""" - fig, ax = setup_figure - - set_figure_meta(ax, caption="Test", custom_field="value", numeric_field=42) - - fig_meta = fig._scitex_figure_metadata - assert fig_meta["custom_field"] == "value" - assert fig_meta["numeric_field"] == 42 - - def test_set_figure_meta_from_different_axes(self): - """Test that figure metadata can be set from any axis.""" - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) - - # Set from different axes - set_figure_meta(ax3, caption="Set from ax3") - - assert fig._scitex_figure_metadata["main_caption"] == "Set from ax3" - - plt.close(fig) - - -class TestExportMetadataYaml: - """Test export_metadata_yaml function.""" - - @pytest.fixture - def setup_figure_with_metadata(self): - """Create figure with both panel and figure metadata.""" - fig, (ax1, ax2) = plt.subplots(1, 2) - - # Set panel metadata - set_meta( - ax1, - caption="Panel 1 caption", - methods="Panel 1 methods", - keywords=["panel1", "test"], - ) - - set_meta(ax2, caption="Panel 2 caption", experimental_details={"samples": 50}) - - # Set figure metadata - set_figure_meta( - ax1, - caption="Main figure caption", - significance="Important results", - funding="Test grant", - ) - - yield fig - plt.close(fig) - - def test_export_metadata_yaml_basic(self, setup_figure_with_metadata): - """Test basic YAML export.""" - fig = setup_figure_with_metadata - - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f: - filepath = f.name - - try: - export_metadata_yaml(fig, filepath) - - # Load and verify YAML - with open(filepath, "r") as f: - data = yaml.safe_load(f) - - assert "figure_metadata" in data - assert "panel_metadata" in data - assert "export_info" in data - - # Check figure metadata - assert data["figure_metadata"]["main_caption"] == "Main figure caption" - assert data["figure_metadata"]["significance"] == "Important results" - - # Check panel metadata - assert "panel_1" in data["panel_metadata"] - assert "panel_2" in data["panel_metadata"] - assert data["panel_metadata"]["panel_1"]["caption"] == "Panel 1 caption" - assert data["panel_metadata"]["panel_2"]["caption"] == "Panel 2 caption" - - finally: - if os.path.exists(filepath): - os.unlink(filepath) - - def test_export_metadata_yaml_empty_figure(self): - """Test export with figure having no metadata.""" - fig, ax = plt.subplots() - - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f: - filepath = f.name - - try: - export_metadata_yaml(fig, filepath) - - with open(filepath, "r") as f: - data = yaml.safe_load(f) - - assert data["figure_metadata"] == {} - assert data["panel_metadata"] == {} - assert "export_info" in data - - finally: - plt.close(fig) - if os.path.exists(filepath): - os.unlink(filepath) - - def test_export_metadata_yaml_structure(self, setup_figure_with_metadata): - """Test exported YAML structure.""" - fig = setup_figure_with_metadata - - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f: - filepath = f.name - - try: - export_metadata_yaml(fig, filepath) - - # Read raw YAML to check formatting - with open(filepath, "r") as f: - yaml_content = f.read() - - # Should be properly formatted - assert "figure_metadata:" in yaml_content - assert "panel_metadata:" in yaml_content - assert "export_info:" in yaml_content - assert " timestamp:" in yaml_content # Proper indentation - - finally: - if os.path.exists(filepath): - os.unlink(filepath) - - def test_export_metadata_yaml_export_info(self, setup_figure_with_metadata): - """Test export info in YAML.""" - fig = setup_figure_with_metadata - - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f: - filepath = f.name - - try: - export_metadata_yaml(fig, filepath) - - with open(filepath, "r") as f: - data = yaml.safe_load(f) - - export_info = data["export_info"] - assert "timestamp" in export_info - assert "scitex_version" in export_info - - # Verify timestamp format - datetime.datetime.fromisoformat(export_info["timestamp"]) - - finally: - if os.path.exists(filepath): - os.unlink(filepath) - - -class TestIntegration: - """Integration tests for metadata system.""" - - def test_complete_workflow(self): - """Test complete metadata workflow.""" - # Create multi-panel figure - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) - - # Set panel metadata - for i, ax in enumerate([ax1, ax2, ax3, ax4], 1): - set_meta( - ax, - caption=f"Panel {i} showing data", - methods=f"Method for panel {i}", - keywords=[f"panel{i}", "test"], - experimental_details={"panel_number": i}, - ) - - # Set figure metadata - set_figure_meta( - ax1, - caption="Complete multi-panel analysis", - significance="Demonstrates metadata system", - data_availability="Test data available", - ) - - # Export to YAML - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f: - filepath = f.name - - try: - export_metadata_yaml(fig, filepath) - - # Verify complete export - with open(filepath, "r") as f: - data = yaml.safe_load(f) - - # Should have all panels - assert len(data["panel_metadata"]) == 4 - - # Each panel should have complete metadata - for i in range(1, 5): - panel_data = data["panel_metadata"][f"panel_{i}"] - assert panel_data["caption"] == f"Panel {i} showing data" - assert panel_data["experimental_details"]["panel_number"] == i - - # Figure metadata should be complete - assert ( - data["figure_metadata"]["main_caption"] - == "Complete multi-panel analysis" - ) - - finally: - plt.close(fig) - if os.path.exists(filepath): - os.unlink(filepath) - - def test_metadata_persistence(self): - """Test that metadata persists across operations.""" - fig, ax = plt.subplots() - - # Set metadata - set_meta(ax, caption="Test caption", methods="Test methods") - - # Perform some plot operations - ax.plot([1, 2, 3], [1, 4, 9]) - ax.set_xlabel("X") - ax.set_ylabel("Y") - - # Metadata should still be there - assert fig._scitex_metadata[ax]["caption"] == "Test caption" - assert fig._scitex_metadata[ax]["methods"] == "Test methods" - - plt.close(fig) - - def test_timestamp_consistency(self): - """Test timestamp consistency across metadata.""" - fig, ax = plt.subplots() - - # Set metadata close together - set_meta(ax, caption="Test") - panel_time = fig._scitex_metadata[ax]["created_timestamp"] - - set_figure_meta(ax, caption="Figure test") - figure_time = fig._scitex_figure_metadata["created_timestamp"] - - # Parse timestamps - panel_dt = datetime.datetime.fromisoformat(panel_time) - figure_dt = datetime.datetime.fromisoformat(figure_time) - - # They should be very close (within 1 second) - time_diff = abs((figure_dt - panel_dt).total_seconds()) - assert time_diff < 1.0, f"Timestamps differ by {time_diff} seconds" - - plt.close(fig) - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_meta.py -# -------------------------------------------------------------------------------- -# #!./env/bin/python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2025-06-04 11:35:00 (ywatanabe)" -# # Author: Yusuke Watanabe (ywatanabe@scitex.ai) -# -# """ -# Scientific metadata management for figures with YAML export. -# """ -# -# # Imports -# import yaml -# from typing import Optional, List, Dict, Any -# -# -# # Functions -# def set_meta( -# ax, -# caption=None, -# methods=None, -# stats=None, -# keywords=None, -# experimental_details=None, -# journal_style=None, -# significance=None, -# **kwargs, -# ): -# """Set comprehensive scientific metadata for figures with YAML export -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or scitex AxisWrapper -# The axes to modify -# caption : str, optional -# Figure caption text -# methods : str, optional -# Experimental methods description -# stats : str, optional -# Statistical analysis details -# keywords : List[str], optional -# Keywords for categorization and search -# experimental_details : Dict[str, Any], optional -# Structured experimental parameters (n_samples, temperature, etc.) -# journal_style : str, optional -# Target journal style ('nature', 'science', 'ieee', 'cell', etc.) -# significance : str, optional -# Significance statement or implications -# **kwargs : additional metadata -# Any additional metadata fields -# -# Returns -# ------- -# ax : matplotlib.axes.Axes or scitex AxisWrapper -# The modified axes -# -# Examples -# -------- -# >>> fig, ax = scitex.plt.subplots() -# >>> ax.plot(x, y, id='neural_data') -# >>> ax.set_xyt(x='Time (ms)', y='Voltage (mV)', t='Neural Recording') -# >>> ax.set_meta( -# ... caption='Intracellular recording showing action potentials.', -# ... methods='Whole-cell patch-clamp in acute brain slices.', -# ... stats='Statistical analysis using paired t-test (p<0.05).', -# ... keywords=['electrophysiology', 'neural_recording', 'patch_clamp'], -# ... experimental_details={ -# ... 'n_samples': 15, -# ... 'temperature': 32, -# ... 'recording_duration': 600, -# ... 'electrode_resistance': '3-5 MΩ' -# ... }, -# ... journal_style='nature', -# ... significance='Demonstrates novel neural dynamics in layer 2/3 pyramidal cells.' -# ... ) -# >>> scitex.io.save(fig, 'neural_recording.png') # YAML metadata auto-saved -# """ -# -# # Build comprehensive metadata dictionary -# metadata = {} -# -# if caption is not None: -# metadata["caption"] = caption -# if methods is not None: -# metadata["methods"] = methods -# if stats is not None: -# metadata["statistical_analysis"] = stats -# if keywords is not None: -# metadata["keywords"] = keywords if isinstance(keywords, list) else [keywords] -# if experimental_details is not None: -# metadata["experimental_details"] = experimental_details -# if journal_style is not None: -# metadata["journal_style"] = journal_style -# if significance is not None: -# metadata["significance"] = significance -# -# # Add any additional metadata -# for key, value in kwargs.items(): -# if value is not None: -# metadata[key] = value -# -# # Add automatic metadata -# import datetime -# -# metadata["created_timestamp"] = datetime.datetime.now().isoformat() -# -# # Get version dynamically -# try: -# import scitex -# -# metadata["scitex_version"] = getattr(scitex, "__version__", "unknown") -# except ImportError: -# metadata["scitex_version"] = "unknown" -# -# # Store metadata in figure for automatic saving -# fig = ax.get_figure() -# if not hasattr(fig, "_scitex_metadata"): -# fig._scitex_metadata = {} -# -# # Use axis as key for panel-specific metadata -# fig._scitex_metadata[ax] = metadata -# -# # Also store as YAML-ready structure -# if not hasattr(fig, "_scitex_yaml_metadata"): -# fig._scitex_yaml_metadata = {} -# fig._scitex_yaml_metadata[ax] = metadata -# -# # Backward compatibility - store simple caption -# if caption is not None: -# if not hasattr(fig, "_scitex_captions"): -# fig._scitex_captions = {} -# fig._scitex_captions[ax] = caption -# -# return ax -# -# -# def set_figure_meta( -# ax, -# caption=None, -# methods=None, -# stats=None, -# significance=None, -# funding=None, -# conflicts=None, -# data_availability=None, -# **kwargs, -# ): -# """Set figure-level metadata for multi-panel figures -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or scitex AxisWrapper -# Any axis in the figure (figure accessed via ax.get_figure()) -# caption : str, optional -# Figure-level caption -# methods : str, optional -# Overall experimental methods -# stats : str, optional -# Overall statistical approach -# significance : str, optional -# Significance and implications -# funding : str, optional -# Funding acknowledgments -# conflicts : str, optional -# Conflict of interest statement -# data_availability : str, optional -# Data availability statement -# **kwargs : additional metadata -# Any additional figure-level metadata -# -# Returns -# ------- -# ax : matplotlib.axes.Axes or scitex AxisWrapper -# The modified axes -# -# Examples -# -------- -# >>> fig, ((ax1, ax2), (ax3, ax4)) = scitex.plt.subplots(2, 2) -# >>> # Set individual panel metadata... -# >>> ax1.set_meta(caption='Panel A analysis...') -# >>> ax2.set_meta(caption='Panel B comparison...') -# >>> -# >>> # Set figure-level metadata -# >>> ax1.set_figure_meta( -# ... caption='Comprehensive analysis of neural dynamics...', -# ... significance='This work demonstrates novel therapeutic targets.', -# ... funding='Supported by NIH grant R01-NS123456.', -# ... data_availability='Data available at doi:10.5061/dryad.example' -# ... ) -# """ -# -# # Build figure-level metadata -# figure_metadata = {} -# -# if caption is not None: -# figure_metadata["main_caption"] = caption -# if methods is not None: -# figure_metadata["overall_methods"] = methods -# if stats is not None: -# figure_metadata["overall_statistics"] = stats -# if significance is not None: -# figure_metadata["significance"] = significance -# if funding is not None: -# figure_metadata["funding"] = funding -# if conflicts is not None: -# figure_metadata["conflicts_of_interest"] = conflicts -# if data_availability is not None: -# figure_metadata["data_availability"] = data_availability -# -# # Add any additional metadata -# for key, value in kwargs.items(): -# if value is not None: -# figure_metadata[key] = value -# -# # Add automatic metadata -# import datetime -# -# figure_metadata["created_timestamp"] = datetime.datetime.now().isoformat() -# -# # Store in figure -# fig = ax.get_figure() -# fig._scitex_figure_metadata = figure_metadata -# -# # Backward compatibility -# if caption is not None: -# fig._scitex_main_caption = caption -# -# return ax -# -# -# def export_metadata_yaml(fig, filepath): -# """Export all figure metadata to YAML file -# -# Parameters -# ---------- -# fig : matplotlib.figure.Figure -# Figure with metadata -# filepath : str -# Output YAML file path -# """ -# import datetime -# -# # Collect all metadata -# export_data = { -# "figure_metadata": {}, -# "panel_metadata": {}, -# "export_info": { -# "timestamp": datetime.datetime.now().isoformat(), -# "scitex_version": "1.11.0", -# }, -# } -# -# # Figure-level metadata -# if hasattr(fig, "_scitex_figure_metadata"): -# export_data["figure_metadata"] = fig._scitex_figure_metadata -# -# # Panel-level metadata -# if hasattr(fig, "_scitex_yaml_metadata"): -# for i, (ax, metadata) in enumerate(fig._scitex_yaml_metadata.items()): -# panel_key = f"panel_{i + 1}" -# export_data["panel_metadata"][panel_key] = metadata -# -# # Write YAML file -# with open(filepath, "w") as f: -# yaml.dump(export_data, f, default_flow_style=False, sort_keys=False, indent=2) -# -# -# if __name__ == "__main__": -# # Start -# import sys -# import matplotlib.pyplot as plt -# import scitex -# -# CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) -# -# # Example usage -# fig, ax = plt.subplots() -# ax.plot([1, 2, 3], [1, 4, 2]) -# -# set_meta( -# ax, -# caption="Example figure showing data trends.", -# methods="Synthetic data generated for demonstration.", -# keywords=["example", "demo", "synthetic"], -# experimental_details={"n_samples": 3, "data_type": "synthetic"}, -# ) -# -# export_metadata_yaml(fig, "example_metadata.yaml") -# -# # Close -# scitex.session.close(CONFIG) -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_meta.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__set_n_ticks.py b/tests/scitex/plt/ax/_style/test__set_n_ticks.py deleted file mode 100644 index 8fa5f0037..000000000 --- a/tests/scitex/plt/ax/_style/test__set_n_ticks.py +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:31 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__set_n_ticks.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__set_n_ticks.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np - -matplotlib.use("Agg") - -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import set_n_ticks - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure() - self.ax = self.fig.add_subplot(111) - - # Create a basic plot with many potential tick locations - xx = np.linspace(0, 100, 1000) - yy = np.sin(xx * 0.1) - self.ax.plot(xx, yy) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_savefig(self): - from scitex.io import save - - # Main - ax = set_n_ticks(self.ax) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - # def test_basic_functionality(self): - # # Test with default parameters (4 ticks) - # ax = set_n_ticks(self.ax) - - # # Force draw to ensure ticks are updated - # self.fig.canvas.draw() - - # # Only count visible ticks - # visible_xticks = len( - # [t for t in ax.xaxis.get_major_ticks() if t.get_visible()] - # ) - # visible_yticks = len( - # [t for t in ax.yaxis.get_major_ticks() if t.get_visible()] - # ) - - # # Should be approximately 4 ticks - # assert 3 <= visible_xticks <= 5 - # assert 3 <= visible_yticks <= 5 - - # def test_custom_tick_counts(self): - # # Test with custom number of ticks - # ax = set_n_ticks(self.ax, n_xticks=6, n_yticks=3) - - # # Force draw to ensure ticks are updated - # self.fig.canvas.draw() - - # # Only count visible ticks - # visible_xticks = len( - # [t for t in ax.xaxis.get_major_ticks() if t.get_visible()] - # ) - # visible_yticks = len( - # [t for t in ax.yaxis.get_major_ticks() if t.get_visible()] - # ) - - # # Should be approximately the requested number of ticks - # assert 5 <= visible_xticks <= 7 - # assert 2 <= visible_yticks <= 4 - - def test_x_ticks_only(self): - # Test setting only x ticks - ax = set_n_ticks(self.ax, n_xticks=7, n_yticks=None) - - # Force draw to ensure ticks are updated - self.fig.canvas.draw() - - # Check x ticks change but y ticks remain default - visible_xticks = len([t for t in ax.xaxis.get_major_ticks() if t.get_visible()]) - assert 6 <= visible_xticks <= 8 - - def test_y_ticks_only(self): - # Test setting only y ticks - ax = set_n_ticks(self.ax, n_xticks=None, n_yticks=7) - - # Force draw to ensure ticks are updated - self.fig.canvas.draw() - - # Check y ticks change but x ticks remain default - visible_yticks = len([t for t in ax.yaxis.get_major_ticks() if t.get_visible()]) - assert 6 <= visible_yticks <= 8 - - # def test_error_handling(self): - # # Test with invalid input types - # with pytest.raises(Exception): - # set_n_ticks(self.ax, n_xticks="invalid") - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_n_ticks.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-04-29 12:02:14 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_set_n_ticks.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_set_n_ticks.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib -# -# -# def set_n_ticks( -# ax, -# n_xticks=4, -# n_yticks=4, -# ): -# """ -# Example: -# ax = set_n_ticks(ax) -# """ -# -# if n_xticks is not None: -# ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(n_xticks)) -# -# if n_yticks is not None: -# ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(n_yticks)) -# -# # Force the figure to redraw to reflect changes -# ax.figure.canvas.draw() -# -# return ax -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_n_ticks.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__set_size.py b/tests/scitex/plt/ax/_style/test__set_size.py deleted file mode 100644 index 6e3f5480a..000000000 --- a/tests/scitex/plt/ax/_style/test__set_size.py +++ /dev/null @@ -1,154 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:41 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__set_size.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__set_size.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import set_size - -matplotlib.use("Agg") - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure() - self.ax = self.fig.add_subplot(111) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_basic_functionality(self): - # Test setting specific dimensions - target_width = 5.0 # inches - target_height = 3.0 # inches - - # Set the figure to have specific subplotpars for testing - self.fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1) - - # Use set_size to adjust figure size - ax = set_size(self.ax, target_width, target_height) - - # Get the figure dimensions and subplot parameters - figsize = ax.figure.get_size_inches() - l = ax.figure.subplotpars.left - r = ax.figure.subplotpars.right - t = ax.figure.subplotpars.top - b = ax.figure.subplotpars.bottom - - # Calculate the expected figure dimensions - expected_width = target_width / (r - l) - expected_height = target_height / (t - b) - - # Check that figure dimensions match expected values - assert np.isclose(figsize[0], expected_width) - assert np.isclose(figsize[1], expected_height) - - def test_aspect_ratio(self): - # Test maintaining a specific aspect ratio - target_width = 4.0 # inches - target_height = 4.0 # inches (square) - - # Use set_size to adjust figure size - ax = set_size(self.ax, target_width, target_height) - - # Get the figure dimensions - figsize = ax.figure.get_size_inches() - - # Check that aspect ratio is maintained - aspect_ratio = figsize[0] / figsize[1] - expected_aspect_ratio = 1.0 # Square - - assert np.isclose(aspect_ratio, expected_aspect_ratio, rtol=0.1) - - def test_edge_cases(self): - # Test with very small dimensions - ax = set_size(self.ax, 0.1, 0.1) - figsize = ax.figure.get_size_inches() - assert figsize[0] > 0 - assert figsize[1] > 0 - - # Test with very large dimensions - ax = set_size(self.ax, 100, 100) - figsize = ax.figure.get_size_inches() - assert figsize[0] > 0 - assert figsize[1] > 0 - - def test_wide_figure(self): - # Test with wide aspect ratio - target_width = 8.0 # inches - target_height = 2.0 # inches - - # Use set_size to adjust figure size - ax = set_size(self.ax, target_width, target_height) - - # Get the figure dimensions - figsize = ax.figure.get_size_inches() - - # Check that aspect ratio is correct - aspect_ratio = figsize[0] / figsize[1] - expected_aspect_ratio = 4.0 # Wide rectangle - - assert np.isclose(aspect_ratio, expected_aspect_ratio, rtol=0.1) - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - target_width = 4.0 - target_height = 3.0 - self.ax.plot([1, 2, 3], [1, 2, 3]) - set_size(self.ax, target_width, target_height) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_size.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2022-12-09 13:38:11 (ywatanabe)" -# -# -# def set_size(ax, w, h): -# """w, h: width, height in inches""" -# # if not ax: ax=plt.gca() -# l = ax.figure.subplotpars.left -# r = ax.figure.subplotpars.right -# t = ax.figure.subplotpars.top -# b = ax.figure.subplotpars.bottom -# figw = float(w) / (r - l) -# figh = float(h) / (t - b) -# ax.figure.set_size_inches(figw, figh) -# return ax - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_size.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__set_supxyt.py b/tests/scitex/plt/ax/_style/test__set_supxyt.py deleted file mode 100644 index 215b4eb13..000000000 --- a/tests/scitex/plt/ax/_style/test__set_supxyt.py +++ /dev/null @@ -1,237 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:22 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__set_supxyt.py -# ---------------------------------------- -import os - -import pytest - -pytest.importorskip("zarr") - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__set_supxyt.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from scitex.plt.ax._style import format_label - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - pass - - def teardown_method(self): - # Clean up after tests - pass - - def test_basic_functionality(self): - # Test with different input types - # Current implementation just returns the label unchanged - - # String input - assert format_label("test_label") == "test_label" - - # Numeric input - assert format_label(123) == 123 - - # Empty string - assert format_label("") == "" - - def test_edge_cases(self): - # Test with None - assert format_label(None) == None - - # Test with special characters - assert format_label("特殊字符@#$%") == "特殊字符@#$%" - - def test_commented_functionality(self): - # This tests the currently commented out functionality - # It's useful to keep these tests for if/when this functionality is uncommented - - # This function should return the input label unchanged with current implementation - assert format_label("test_label") == "test_label" - - # If uncommented, it would capitalize and replace underscores: - # assert format_label("test_label") == "Test Label" - - # If uncommented, it would handle uppercase: - # assert format_label("TEST") == "TEST" - - def test_savefig(self): - import matplotlib.pyplot as plt - - from scitex.io import save - - # Setup - fig = plt.figure() - ax = fig.add_subplot(111) - ax.plot([1, 2, 3], [1, 2, 3]) - - # Main test functionality - from scitex.plt.ax._style import set_supxyt - - set_supxyt( - ax, - xlabel="Super X Label", - ylabel="Super Y Label", - title="Super Title", - ) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_supxyt.py -# -------------------------------------------------------------------------------- -# #!./env/bin/python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2024-07-13 07:56:46 (ywatanabe)" -# # Author: Yusuke Watanabe (ywatanabe@scitex.ai) -# -# """ -# This script does XYZ. -# """ -# -# # Imports -# import matplotlib.pyplot as plt -# -# from ._format_label import format_label -# -# -# # Functions -# def set_supxyt(ax, xlabel=False, ylabel=False, title=False, format_labels=True): -# """Sets xlabel, ylabel and title""" -# fig = ax.get_figure() -# -# # if xlabel is not False: -# # fig.supxlabel(xlabel) -# -# # if ylabel is not False: -# # fig.supylabel(ylabel) -# -# # if title is not False: -# # fig.suptitle(title) -# if xlabel is not False: -# xlabel = format_label(xlabel) if format_labels else xlabel -# fig.supxlabel(xlabel) -# -# if ylabel is not False: -# ylabel = format_label(ylabel) if format_labels else ylabel -# fig.supylabel(ylabel) -# -# if title is not False: -# title = format_label(title) if format_labels else title -# fig.suptitle(title) -# -# return ax -# -# -# def set_supxytc( -# ax, -# xlabel=False, -# ylabel=False, -# title=False, -# caption=False, -# methods=False, -# stats=False, -# significance=False, -# format_labels=True, -# ): -# """Sets figure-level xlabel, ylabel, title, and caption with SciTeX-Paper integration -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or scitex AxisWrapper -# The axes to modify (figure accessed via ax.get_figure()) -# xlabel : str or False, optional -# Figure-level X-axis label, by default False -# ylabel : str or False, optional -# Figure-level Y-axis label, by default False -# title : str or False, optional -# Figure-level title (suptitle), by default False -# caption : str or False, optional -# Figure-level caption to store for later use with scitex.io.save(), by default False -# methods : str or False, optional -# Overall methods description for SciTeX-Paper integration, by default False -# stats : str or False, optional -# Overall statistical analysis details for SciTeX-Paper integration, by default False -# significance : str or False, optional -# Significance statement for SciTeX-Paper integration, by default False -# format_labels : bool, optional -# Whether to apply automatic formatting, by default True -# -# Returns -# ------- -# ax : matplotlib.axes.Axes or scitex AxisWrapper -# The modified axes -# -# Examples -# -------- -# >>> fig, ((ax1, ax2), (ax3, ax4)) = scitex.plt.subplots(2, 2) -# >>> # Add plots to each panel... -# >>> ax1.set_supxytc(xlabel='Time (s)', ylabel='Signal Amplitude', -# ... title='Multi-Panel Analysis', -# ... caption='Comprehensive analysis showing (A) raw data, (B) filtered signal, (C) power spectrum, and (D) phase analysis.', -# ... methods='All experiments performed using standardized protocols.', -# ... significance='This work demonstrates novel therapeutic targets.') -# >>> scitex.io.save(fig, 'multi_panel.png') # Caption automatically saved -# """ -# # Set labels and title using existing function -# set_supxyt( -# ax, xlabel=xlabel, ylabel=ylabel, title=title, format_labels=format_labels -# ) -# -# # Store figure-level caption and extended metadata -# if ( -# caption is not False -# or methods is not False -# or stats is not False -# or significance is not False -# ): -# fig = ax.get_figure() -# # Store comprehensive figure-level metadata -# fig_metadata = { -# "main_caption": caption if caption is not False else None, -# "methods": methods if methods is not False else None, -# "stats": stats if stats is not False else None, -# "significance": significance if significance is not False else None, -# } -# -# fig._scitex_figure_metadata = fig_metadata -# -# # Backward compatibility - also store simple caption -# if caption is not False: -# fig._scitex_main_caption = caption -# -# return ax -# -# -# if __name__ == "__main__": -# # Start -# CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) -# -# # (YOUR AWESOME CODE) -# -# # Close -# scitex.session.close(CONFIG) -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_supxyt.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__set_ticks.py b/tests/scitex/plt/ax/_style/test__set_ticks.py deleted file mode 100644 index b57df6b46..000000000 --- a/tests/scitex/plt/ax/_style/test__set_ticks.py +++ /dev/null @@ -1,422 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:39 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__set_ticks.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__set_ticks.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import set_ticks -from scitex.plt.ax._style._set_ticks import set_x_ticks, set_y_ticks - -matplotlib.use("Agg") - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure(figsize=(6, 4)) - self.ax = self.fig.add_subplot(111) - # Create a basic plot - xx = np.linspace(0, 10, 100) - yy = np.sin(xx) - self.ax.plot(xx, yy) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_set_x_ticks_basic(self): - # Test with custom x tick locations - x_ticks = [0, 2.5, 5, 7.5, 10] - ax = set_x_ticks(self.ax, x_ticks=x_ticks) - - # Check that the ticks were set correctly - assert len(ax.get_xticks()) == len(x_ticks) - assert np.allclose(ax.get_xticks(), x_ticks) - - def test_set_y_ticks_basic(self): - # Test with custom y tick locations - y_ticks = [-1, -0.5, 0, 0.5, 1] - ax = set_y_ticks(self.ax, y_ticks=y_ticks) - - # Check that the ticks were set correctly - assert len(ax.get_yticks()) == len(y_ticks) - assert np.allclose(ax.get_yticks(), y_ticks) - - def test_set_ticks_combined(self): - # Test setting both x and y ticks - x_ticks = [0, 5, 10] - y_ticks = [-1, 0, 1] - - ax = set_ticks(self.ax, xticks=x_ticks, yticks=y_ticks) - - # Check that the ticks were set correctly - assert np.allclose(ax.get_xticks(), x_ticks) - assert np.allclose(ax.get_yticks(), y_ticks) - - def test_set_ticks_with_vals_and_ticks(self): - # Test with both vals and ticks - x_vals = np.linspace(0, 100, 11) # [0, 10, 20, ..., 100] - x_ticks = [0, 50, 100] - - # This should create a new mapping of the axis - ax = set_x_ticks(self.ax, x_vals=x_vals, x_ticks=x_ticks) - - # Check the ticks after mapping - assert len(ax.get_xticks()) > 0 - assert len(ax.get_xticklabels()) > 0 - - # def test_string_ticks(self): - # # Test with string tick labels - # x_ticks = ["A", "B", "C", "D"] - # ax = set_x_ticks(self.ax, x_ticks=x_ticks) - - # # Get the tick labels (need to render the figure for this) - # fig = ax.get_figure() - # fig.canvas.draw() - # tick_labels = [label.get_text() for label in ax.get_xticklabels()] - - # # Check that the tick labels match - # assert tick_labels == x_ticks - - # def test_dict_ticks(self): - # # This test depends on the implementation of to_str from scitex.dict - # # Testing with a mock instead - # from unittest.mock import patch - - # with patch("scitex.plt.ax._set_ticks.is_listed_X", return_value=True): - # with patch( - # "scitex.plt.ax._set_ticks.to_str", - # side_effect=lambda x, delimiter: f"formatted_{x}", - # ): - # x_ticks = [{"a": 1}, {"b": 2}] - # ax = set_x_ticks(self.ax, x_ticks=x_ticks) - - # # Get the tick labels - # fig = ax.get_figure() - # fig.canvas.draw() - # tick_labels = [ - # label.get_text() for label in ax.get_xticklabels() - # ] - - # # Check formatted labels (matching the mock's return value) - # assert "formatted_" in tick_labels[0] - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - x_ticks = [0, 2.5, 5, 7.5, 10] - y_ticks = [-1, -0.5, 0, 0.5, 1] - set_ticks(self.ax, xticks=x_ticks, yticks=y_ticks) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_ticks.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-04-27 20:04:55 (ywatanabe)" -# # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_set_ticks.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_style/_set_ticks.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib.pyplot as plt -# import numpy as np -# -# from scitex.dict._to_str import to_str -# from scitex.types import is_listed_X -# -# -# def set_ticks(ax, xvals=None, xticks=None, yvals=None, yticks=None): -# """Set custom tick labels on both x and y axes. -# -# Convenience function to set tick positions and labels for both axes -# at once. Automatically handles canvas updates for interactive backends. -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes -# The axes object to modify. -# xvals : array-like, optional -# Values corresponding to x-axis data points. -# xticks : list, optional -# Desired tick labels for the x-axis. -# yvals : array-like, optional -# Values corresponding to y-axis data points. -# yticks : list, optional -# Desired tick labels for the y-axis. -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified axes object. -# -# Examples -# -------- -# >>> fig, ax = plt.subplots() -# >>> x = np.linspace(0, 10, 100) -# >>> ax.plot(x, np.sin(x)) -# >>> ax = set_ticks(ax, xvals=x, xticks=[0, 5, 10], -# ... yvals=[-1, 0, 1], yticks=['-1', '0', '1']) -# -# See Also -# -------- -# set_x_ticks : Set ticks for x-axis only -# set_y_ticks : Set ticks for y-axis only -# """ -# ax = set_x_ticks(ax, x_vals=xvals, x_ticks=xticks) -# ax = set_y_ticks(ax, y_vals=yvals, y_ticks=yticks) -# canvas_type = type(ax.figure.canvas).__name__ -# if "TkAgg" in canvas_type: -# ax.get_figure().canvas.draw() # Redraw the canvas once after making all updates -# return ax -# -# -# def set_x_ticks(ax, x_vals=None, x_ticks=None): -# """ -# Set custom tick labels on the x and y axes based on specified values and desired ticks. -# -# Parameters: -# - ax: The axis object to modify. -# - x_vals: Array of x-axis values. -# - x_ticks: List of desired tick labels on the x-axis. -# - y_vals: Array of y-axis values. -# - y_ticks: List of desired tick labels on the y-axis. -# -# Example: -# import matplotlib.pyplot as plt -# import numpy as np -# -# fig, axes = plt.subplots(nrows=4) -# x = np.linspace(0, 10, 100) -# y = np.sin(x) -# for ax in axes: -# ax.plot(x, y) # Plot a sine wave -# -# set_ticks(axes[0]) # Do nothing # OK -# set_ticks(axes[1], x_vals=x+3) # OK -# set_ticks(axes[2], x_ticks=[1,2]) # OK -# set_ticks(axes[3], x_vals=x+3, x_ticks=[4,5]) # Auto-generate ticks across the range -# fig.tight_layout() -# plt.show() -# """ -# -# def _avoid_overlaps(values): -# values = np.array(values) -# if ("int" in str(values.dtype)) or ("float" in str(values.dtype)): -# values = values.astype(float) + np.arange(len(values)) * 1e-5 -# return values -# -# def _set_x_vals(ax, x_vals): -# x_vals = _avoid_overlaps(x_vals) -# new_x_axis = np.linspace(*ax.get_xlim(), len(x_vals)) -# ax.set_xticks(new_x_axis) -# ax.set_xticklabels([f"{xv}" for xv in x_vals]) -# return ax -# -# def _set_x_ticks(ax, x_ticks): -# x_ticks = np.array(x_ticks) -# if x_ticks.dtype.kind in ["U", "S", "O"]: # If x_ticks are strings -# ax.set_xticks(range(len(x_ticks))) -# ax.set_xticklabels(x_ticks) -# else: -# x_vals = np.array( -# [label.get_text().replace("−", "-") for label in ax.get_xticklabels()] -# ) -# x_vals = x_vals.astype(float) -# x_indi = np.argmin( -# np.array(np.abs(x_vals[:, np.newaxis] - x_ticks[np.newaxis, :])), -# axis=0, -# ) -# ax.set_xticks(ax.get_xticks()[x_indi]) -# ax.set_xticklabels([f"{xt}" for xt in x_ticks]) -# return ax -# -# x_vals_passed = x_vals is not None -# x_ticks_passed = x_ticks is not None -# -# if is_listed_X(x_ticks, dict): -# x_ticks = [to_str(xt, delimiter="\n") for xt in x_ticks] -# -# if (not x_vals_passed) and (not x_ticks_passed): -# # Do nothing -# pass -# -# elif x_vals_passed and (not x_ticks_passed): -# # Replaces the x axis to x_vals -# x_ticks = np.linspace(x_vals[0], x_vals[-1], 4) -# ax = _set_x_vals(ax, x_ticks) -# -# elif (not x_vals_passed) and x_ticks_passed: -# # Locates 'x_ticks' on the original x axis -# ax.set_xticks(x_ticks) -# -# elif x_vals_passed and x_ticks_passed: -# if isinstance(x_vals, str): -# if x_vals == "auto": -# x_vals = np.arange(len(x_ticks)) -# -# # Replaces the original x axis to 'x_vals' and locates the 'x_ticks' on the new axis -# ax = _set_x_vals(ax, x_vals) -# ax = _set_x_ticks(ax, x_ticks) -# -# return ax -# -# -# def set_y_ticks(ax, y_vals=None, y_ticks=None): -# """ -# Set custom tick labels on the y-axis based on specified values and desired ticks. -# -# Parameters: -# - ax: The axis object to modify. -# - y_vals: Array of y-axis values where ticks should be placed. -# - y_ticks: List of labels for ticks on the y-axis. -# -# Example: -# import matplotlib.pyplot as plt -# import numpy as np -# -# fig, ax = plt.subplots() -# x = np.linspace(0, 10, 100) -# y = np.sin(x) -# ax.plot(x, y) # Plot a sine wave -# -# set_y_ticks(ax, y_vals=y, y_ticks=['Low', 'High']) # Set custom y-axis ticks -# plt.show() -# """ -# -# def _avoid_overlaps(values): -# values = np.array(values) -# if ("int" in str(values.dtype)) or ("float" in str(values.dtype)): -# values = values.astype(float) + np.arange(len(values)) * 1e-5 -# return values -# -# def _set_y_vals(ax, y_vals): -# y_vals = _avoid_overlaps(y_vals) -# new_y_axis = np.linspace(*ax.get_ylim(), len(y_vals)) -# ax.set_yticks(new_y_axis) -# ax.set_yticklabels([f"{yv:.2f}" for yv in y_vals]) -# return ax -# -# # def _set_y_ticks(ax, y_ticks): -# # y_ticks = np.array(y_ticks) -# # y_vals = np.array( -# # [ -# # label.get_text().replace("−", "-") -# # for label in ax.get_yticklabels() -# # ] -# # ) -# # y_vals = y_vals.astype(float) -# # y_indi = np.argmin( -# # np.array(np.abs(y_vals[:, np.newaxis] - y_ticks[np.newaxis, :])), -# # axis=0, -# # ) -# -# # # y_indi = [np.argmin(np.abs(y_vals - yt)) for yt in y_ticks] -# # ax.set_yticks(ax.get_yticks()[y_indi]) -# # ax.set_yticklabels([f"{yt}" for yt in y_ticks]) -# # return ax -# def _set_y_ticks(ax, y_ticks): -# y_ticks = np.array(y_ticks) -# if y_ticks.dtype.kind in ["U", "S", "O"]: # If y_ticks are strings -# ax.set_yticks(range(len(y_ticks))) -# ax.set_yticklabels(y_ticks) -# else: -# y_vals = np.array( -# [label.get_text().replace("−", "-") for label in ax.get_yticklabels()] -# ) -# y_vals = y_vals.astype(float) -# y_indi = np.argmin( -# np.array(np.abs(y_vals[:, np.newaxis] - y_ticks[np.newaxis, :])), -# axis=0, -# ) -# ax.set_yticks(ax.get_yticks()[y_indi]) -# ax.set_yticklabels([f"{yt}" for yt in y_ticks]) -# return ax -# -# y_vals_passed = y_vals is not None -# y_ticks_passed = y_ticks is not None -# -# if is_listed_X(y_ticks, dict): -# y_ticks = [to_str(yt, delimiter="\n") for yt in y_ticks] -# -# if (not y_vals_passed) and (not y_ticks_passed): -# # Do nothing -# pass -# -# elif y_vals_passed and (not y_ticks_passed): -# # Replaces the y axis to y_vals -# ax = _set_y_vals(ax, y_vals) -# -# elif (not y_vals_passed) and y_ticks_passed: -# # Locates 'y_ticks' on the original y axis -# ax.set_yticks(y_ticks) -# -# elif y_vals_passed and y_ticks_passed: -# # Replaces the original y axis to 'y_vals' and locates the 'y_ticks' on the new axis -# if y_vals == "auto": -# y_vals = np.arange(len(y_ticks)) -# -# ax = _set_y_vals(ax, y_vals) -# ax = _set_y_ticks(ax, y_ticks) -# return ax -# -# -# if __name__ == "__main__": -# import scitex -# -# xx, tt, fs = scitex.dsp.demo_sig() -# pha, amp, freqs = scitex.dsp.wavelet(xx, fs) -# -# i_batch, i_ch = 0, 0 -# ff = freqs[i_batch, i_ch] -# fig, ax = scitex.plt.subplots() -# -# ax.image2d(amp[i_batch, i_ch]) -# -# ax = set_ticks( -# ax, -# x_vals=tt, -# x_ticks=[0, 1, 2, 3, 4], -# y_vals=ff, -# y_ticks=[0, 128, 256], -# ) -# -# plt.show() -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_ticks.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__set_xyt.py b/tests/scitex/plt/ax/_style/test__set_xyt.py deleted file mode 100644 index ebb8c3eac..000000000 --- a/tests/scitex/plt/ax/_style/test__set_xyt.py +++ /dev/null @@ -1,252 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:35 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__set_xyt.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__set_xyt.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import set_xyt - -matplotlib.use("Agg") - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure() - self.ax = self.fig.add_subplot(111) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_basic_functionality(self): - # Test setting all three labels - ax = set_xyt(self.ax, x="Test X", y="Test Y", t="Test Title") - - assert ax.get_xlabel() == "Test X" - assert ax.get_ylabel() == "Test Y" - assert ax.get_title() == "Test Title" - - def test_partial_labels(self): - # Test setting only some labels - ax1 = set_xyt(self.ax, x="Only X") - assert ax1.get_xlabel() == "Only X" - assert ax1.get_ylabel() == "" - assert ax1.get_title() == "" - - ax2 = set_xyt(self.ax, y="Only Y") - assert ax2.get_xlabel() == "Only X" # Still has previous value - assert ax2.get_ylabel() == "Only Y" - assert ax2.get_title() == "" - - ax3 = set_xyt(self.ax, t="Only Title") - assert ax3.get_xlabel() == "Only X" # Still has previous value - assert ax3.get_ylabel() == "Only Y" # Still has previous value - assert ax3.get_title() == "Only Title" - - # def test_format_labels_option(self): - # # Test with format_labels=False - # with patch( - # "scitex.plt.ax._format_label.format_label", - # side_effect=lambda x: x.upper(), - # ): - # # When format_labels=True, it should call format_label - # ax1 = set_xyt( - # self.ax, x="test", y="test", t="test", format_labels=True - # ) - # assert ax1.get_xlabel() == "TEST" - # assert ax1.get_ylabel() == "TEST" - # assert ax1.get_title() == "TEST" - - # # When format_labels=False, it should not call format_label - # ax2 = set_xyt( - # self.ax, x="test", y="test", t="test", format_labels=False - # ) - # assert ax2.get_xlabel() == "test" - # assert ax2.get_ylabel() == "test" - # assert ax2.get_title() == "test" - - def test_edge_cases(self): - # Test with False values (which should skip setting those labels) - ax = set_xyt(self.ax, x=False, y=False, t=False) - assert ax.get_xlabel() == "" - assert ax.get_ylabel() == "" - assert ax.get_title() == "" - - # Test with empty strings - ax = set_xyt(self.ax, x="", y="", t="") - assert ax.get_xlabel() == "" - assert ax.get_ylabel() == "" - assert ax.get_title() == "" - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - self.ax.plot([1, 2, 3], [1, 2, 3]) - set_xyt(self.ax, x="X Label", y="Y Label", t="Test Title") - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_xyt.py -# -------------------------------------------------------------------------------- -# #!./env/bin/python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2024-07-13 08:14:19 (ywatanabe)" -# # Author: Yusuke Watanabe (ywatanabe@scitex.ai) -# -# """ -# This script does XYZ. -# """ -# -# # Imports -# import matplotlib.pyplot as plt -# -# from ._format_label import format_label -# -# -# # Functions -# def set_xyt(ax, x=False, y=False, t=False, format_labels=True): -# """Sets xlabel, ylabel and title""" -# -# if x is not False: -# x = format_label(x) if format_labels else x -# ax.set_xlabel(x) -# -# if y is not False: -# y = format_label(y) if format_labels else y -# ax.set_ylabel(y) -# -# if t is not False: -# t = format_label(t) if format_labels else t -# ax.set_title(t) -# -# return ax -# -# -# def set_xytc( -# ax, -# x=False, -# y=False, -# t=False, -# c=False, -# methods=False, -# stats=False, -# format_labels=True, -# ): -# """Sets xlabel, ylabel, title, and caption with SciTeX-Paper integration -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or scitex AxisWrapper -# The axes to modify -# x : str or False, optional -# X-axis label, by default False -# y : str or False, optional -# Y-axis label, by default False -# t : str or False, optional -# Title, by default False -# c : str or False, optional -# Caption to store for later use with scitex.io.save(), by default False -# methods : str or False, optional -# Methods description for SciTeX-Paper integration, by default False -# stats : str or False, optional -# Statistical analysis details for SciTeX-Paper integration, by default False -# format_labels : bool, optional -# Whether to apply automatic formatting, by default True -# -# Returns -# ------- -# ax : matplotlib.axes.Axes or scitex AxisWrapper -# The modified axes -# -# Examples -# -------- -# >>> fig, ax = scitex.plt.subplots() -# >>> ax.plot(x, y) -# >>> ax.set_xytc(x='Time (s)', y='Voltage (mV)', -# ... t='Neural Signal', -# ... c='Example neural recording showing action potentials.', -# ... methods='Intracellular recordings performed using patch-clamp technique.', -# ... stats='Data analyzed using t-test with p<0.05 significance.') -# >>> scitex.io.save(fig, 'neural_signal.png') # Caption automatically saved -# """ -# # Set labels and title using existing function -# set_xyt(ax, x=x, y=y, t=t, format_labels=format_labels) -# -# # Store caption and extended metadata for later use by scitex.io.save() -# if c is not False or methods is not False or stats is not False: -# # Store comprehensive metadata as axis attribute for retrieval by save function -# metadata = { -# "caption": c if c is not False else None, -# "methods": methods if methods is not False else None, -# "stats": stats if stats is not False else None, -# } -# -# if hasattr(ax, "_scitex_metadata"): -# ax._scitex_metadata.update(metadata) -# else: -# # For matplotlib axes, store in figure metadata -# fig = ax.get_figure() -# if not hasattr(fig, "_scitex_metadata"): -# fig._scitex_metadata = {} -# # Use axis position as identifier -# fig._scitex_metadata[ax] = metadata -# -# # Backward compatibility - also store simple caption -# if c is not False: -# if hasattr(ax, "_scitex_caption"): -# ax._scitex_caption = c -# else: -# fig = ax.get_figure() -# if not hasattr(fig, "_scitex_captions"): -# fig._scitex_captions = {} -# fig._scitex_captions[ax] = c -# -# return ax -# -# -# if __name__ == "__main__": -# # Start -# CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) -# -# # (YOUR AWESOME CODE) -# -# # Close -# scitex.session.close(CONFIG) -# -# # EOF -# -# """ -# /ssh:ywatanabe@444:/home/ywatanabe/proj/entrance/scitex/plt/ax/_set_lt.py -# """ - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_set_xyt.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__share_axes.py b/tests/scitex/plt/ax/_style/test__share_axes.py deleted file mode 100755 index 6aa6b4c3c..000000000 --- a/tests/scitex/plt/ax/_style/test__share_axes.py +++ /dev/null @@ -1,461 +0,0 @@ -#!/usr/bin/env python3 -# Timestamp: "2025-05-02 09:02:41 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__share_axes.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__share_axes.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import set_xlims, set_ylims, sharexy # noqa: E402 - -matplotlib.use("Agg") # Use non-GUI backend for testing - -pytestmark = pytest.mark.xfail( - reason="Pre-existing: scitex.plt._subplots not accessible via __getattr__" -) - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure(figsize=(10, 8)) - self.ax1 = self.fig.add_subplot(221) - self.ax2 = self.fig.add_subplot(222) - self.ax3 = self.fig.add_subplot(223) - self.ax4 = self.fig.add_subplot(224) - - # Set different limits for each axis - self.ax1.set_xlim(0, 10) - self.ax1.set_ylim(0, 5) - self.ax2.set_xlim(-5, 5) - self.ax2.set_ylim(-2, 8) - self.ax3.set_xlim(2, 12) - self.ax3.set_ylim(-3, 3) - self.ax4.set_xlim(-10, 0) - self.ax4.set_ylim(2, 7) - - # Create array of axes for testing - self.axes_array = np.array([self.ax1, self.ax2, self.ax3, self.ax4]) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - # def test_get_global_xlim(self): - # # Test getting global xlim - # # NOTE: There appears to be a bug in the original function using ylim instead of xlim - # # Mocking the get_xlim method to bypass the bug for testing purposes - # with patch("matplotlib.axes.Axes.get_xlim", return_value=(-10, 12)): - # xlim = get_global_xlim(self.ax1, self.ax2) - # assert xlim == (-10, 12) - - # # Test with array of axes - # with patch("matplotlib.axes.Axes.get_xlim", return_value=(-10, 12)): - # xlim = get_global_xlim(self.axes_array) - # assert xlim == (-10, 12) - - # def test_get_global_ylim(self): - # # Test getting global ylim - # ylim = get_global_ylim(self.ax1, self.ax2, self.ax3, self.ax4) - # assert ylim[0] <= -3 # Min ylim should be <= -3 - # assert ylim[1] >= 8 # Max ylim should be >= 8 - - # # Test with array of axes - # ylim = get_global_ylim(self.axes_array) - # assert ylim[0] <= -3 - # assert ylim[1] >= 8 - - def test_set_xlims(self): - # Test setting xlim for multiple axes - test_xlim = (-20, 20) - axes, xlim = set_xlims(self.ax1, self.ax2, self.ax3, self.ax4, xlim=test_xlim) - - # Check that all axes have the new xlim - for ax in [self.ax1, self.ax2, self.ax3, self.ax4]: - assert ax.get_xlim() == test_xlim - - # Test with array of axes - axes, xlim = set_xlims(self.axes_array, xlim=test_xlim) - for ax in self.axes_array: - assert ax.get_xlim() == test_xlim - - def test_set_ylims(self): - # Test setting ylim for multiple axes - test_ylim = (-10, 10) - axes, ylim = set_ylims(self.ax1, self.ax2, self.ax3, self.ax4, ylim=test_ylim) - - # Check that all axes have the new ylim - for ax in [self.ax1, self.ax2, self.ax3, self.ax4]: - assert ax.get_ylim() == test_ylim - - # Test with array of axes - axes, ylim = set_ylims(self.axes_array, ylim=test_ylim) - for ax in self.axes_array: - assert ax.get_ylim() == test_ylim - - # def test_sharex(self): - # # Test sharing x axis - # with patch( - # "scitex.plt.ax._share_axes.get_global_xlim", return_value=(-50, 50) - # ): - # axes, xlim = sharex(self.ax1, self.ax2, self.ax3, self.ax4) - - # # Check that all axes have the same xlim - # for ax in [self.ax1, self.ax2, self.ax3, self.ax4]: - # assert ax.get_xlim() == (-50, 50) - - # # Test with array of axes - # axes, xlim = sharex(self.axes_array) - # for ax in self.axes_array: - # assert ax.get_xlim() == (-50, 50) - - # def test_sharey(self): - # # Test sharing y axis - # with patch( - # "scitex.plt.ax._share_axes.get_global_ylim", return_value=(-25, 25) - # ): - # axes, ylim = sharey(self.ax1, self.ax2, self.ax3, self.ax4) - - # # Check that all axes have the same ylim - # for ax in [self.ax1, self.ax2, self.ax3, self.ax4]: - # assert ax.get_ylim() == (-25, 25) - - # # Test with array of axes - # axes, ylim = sharey(self.axes_array) - # for ax in self.axes_array: - # assert ax.get_ylim() == (-25, 25) - - # def test_sharexy(self): - # # Test sharing both x and y axes - # with patch( - # "scitex.plt.ax._share_axes.get_global_xlim", return_value=(-30, 30) - # ): - # with patch( - # "scitex.plt.ax._share_axes.get_global_ylim", - # return_value=(-15, 15), - # ): - # sharexy(self.ax1, self.ax2, self.ax3, self.ax4) - - # # Check that all axes have the same xlim and ylim - # for ax in [self.ax1, self.ax2, self.ax3, self.ax4]: - # assert ax.get_xlim() == (-30, 30) - # assert ax.get_ylim() == (-15, 15) - - def test_error_handling(self): - # Test with missing xlim parameter - with pytest.raises(ValueError, match="Please set xlim"): - set_xlims(self.ax1, self.ax2) - - # Test with missing ylim parameter - with pytest.raises(ValueError, match="Please set ylim"): - set_ylims(self.ax1, self.ax2) - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - sharexy(self.ax1, self.ax2, self.ax3, self.ax4) - self.ax1.set_title("Shared Axes") - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_share_axes.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-01 08:47:27 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_share_axes.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_style/_share_axes.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib.pyplot as plt -# import scitex -# import numpy as np -# -# -# def sharexy(*multiple_axes): -# """Share both x and y axis limits across multiple axes. -# -# Synchronizes both x and y axis limits across all provided axes objects, -# ensuring they all display the same data range. Useful for comparing -# multiple plots on the same scale. -# -# Parameters -# ---------- -# *multiple_axes : matplotlib.axes.Axes or array of Axes -# Variable number of axes objects to synchronize. -# -# Examples -# -------- -# >>> fig, (ax1, ax2, ax3) = plt.subplots(1, 3) -# >>> ax1.plot([1, 2, 3], [1, 4, 9]) -# >>> ax2.plot([1, 2, 3], [2, 5, 8]) -# >>> ax3.plot([1, 2, 3], [3, 6, 10]) -# >>> sharexy(ax1, ax2, ax3) # All axes now show same range -# -# See Also -# -------- -# sharex : Share only x-axis limits -# sharey : Share only y-axis limits -# """ -# sharex(*multiple_axes) -# sharey(*multiple_axes) -# -# -# def sharex(*multiple_axes): -# """Share x-axis limits across multiple axes. -# -# Finds the global x-axis limits across all axes and applies them -# to each axis, ensuring horizontal alignment of data. -# -# Parameters -# ---------- -# *multiple_axes : matplotlib.axes.Axes or array of Axes -# Variable number of axes objects to synchronize. -# -# Returns -# ------- -# axes : axes object(s) -# The modified axes with shared x-limits. -# xlim : tuple -# The (xmin, xmax) limits applied. -# -# Examples -# -------- -# >>> fig, axes = plt.subplots(2, 1) -# >>> axes[0].plot([1, 5], [1, 2]) -# >>> axes[1].plot([2, 4], [3, 4]) -# >>> sharex(axes[0], axes[1]) # Both show x-range [1, 5] -# """ -# xlim = get_global_xlim(*multiple_axes) -# return set_xlims(*multiple_axes, xlim=xlim) -# -# -# def sharey(*multiple_axes): -# """Share y-axis limits across multiple axes. -# -# Finds the global y-axis limits across all axes and applies them -# to each axis, ensuring vertical alignment of data. -# -# Parameters -# ---------- -# *multiple_axes : matplotlib.axes.Axes or array of Axes -# Variable number of axes objects to synchronize. -# -# Returns -# ------- -# axes : axes object(s) -# The modified axes with shared y-limits. -# ylim : tuple -# The (ymin, ymax) limits applied. -# -# Examples -# -------- -# >>> fig, axes = plt.subplots(1, 2) -# >>> axes[0].plot([1, 2], [1, 5]) -# >>> axes[1].plot([1, 2], [2, 4]) -# >>> sharey(axes[0], axes[1]) # Both show y-range [1, 5] -# """ -# ylim = get_global_ylim(*multiple_axes) -# return set_ylims(*multiple_axes, ylim=ylim) -# -# -# def get_global_xlim(*multiple_axes): -# """Get the global x-axis limits across multiple axes. -# -# Scans all provided axes to find the minimum and maximum x-values -# across all of them. Handles both single axes and arrays of axes. -# -# Parameters -# ---------- -# *multiple_axes : matplotlib.axes.Axes or array of Axes -# Variable number of axes objects to scan. -# -# Returns -# ------- -# tuple -# (xmin, xmax) representing the global x-axis limits. -# -# Examples -# -------- -# >>> fig, (ax1, ax2) = plt.subplots(1, 2) -# >>> ax1.plot([1, 3], [1, 2]) # x-range: [1, 3] -# >>> ax2.plot([2, 5], [1, 2]) # x-range: [2, 5] -# >>> xlim = get_global_xlim(ax1, ax2) -# >>> print(xlim) # (1, 5) -# -# Notes -# ----- -# There appears to be a bug in the current implementation where -# get_ylim() is called instead of get_xlim(). This should be fixed. -# """ -# xmin, xmax = np.inf, -np.inf -# for axes in multiple_axes: -# # axes -# if isinstance( -# axes, (np.ndarray, scitex.plt._subplots.AxesWrapper) -# ): -# for ax in axes.flat: -# _xmin, _xmax = ax.get_xlim() # Fixed: was get_ylim() -# xmin = min(xmin, _xmin) -# xmax = max(xmax, _xmax) -# # axis -# else: -# ax = axes -# _xmin, _xmax = ax.get_xlim() # Fixed: was get_ylim() -# xmin = min(xmin, _xmin) -# xmax = max(xmax, _xmax) -# -# return (xmin, xmax) -# -# -# # def get_global_xlim(*multiple_axes): -# # xmin, xmax = np.inf, -np.inf -# # for axes in multiple_axes: -# # for ax in axes.flat: -# # _xmin, _xmax = ax.get_xlim() -# # xmin = min(xmin, _xmin) -# # xmax = max(xmax, _xmax) -# # return (xmin, xmax) -# -# -# def get_global_ylim(*multiple_axes): -# """Get the global y-axis limits across multiple axes. -# -# Scans all provided axes to find the minimum and maximum y-values -# across all of them. Handles both single axes and arrays of axes. -# -# Parameters -# ---------- -# *multiple_axes : matplotlib.axes.Axes or array of Axes -# Variable number of axes objects to scan. -# -# Returns -# ------- -# tuple -# (ymin, ymax) representing the global y-axis limits. -# -# Examples -# -------- -# >>> fig, (ax1, ax2) = plt.subplots(1, 2) -# >>> ax1.plot([1, 2], [1, 3]) # y-range: [1, 3] -# >>> ax2.plot([1, 2], [2, 5]) # y-range: [2, 5] -# >>> ylim = get_global_ylim(ax1, ax2) -# >>> print(ylim) # (1, 5) -# """ -# ymin, ymax = np.inf, -np.inf -# for axes in multiple_axes: -# # axes -# if isinstance( -# axes, (np.ndarray, scitex.plt._subplots.AxesWrapper) -# ): -# for ax in axes.flat: -# _ymin, _ymax = ax.get_ylim() -# ymin = min(ymin, _ymin) -# ymax = max(ymax, _ymax) -# # axis -# else: -# ax = axes -# _ymin, _ymax = ax.get_ylim() -# ymin = min(ymin, _ymin) -# ymax = max(ymax, _ymax) -# -# return (ymin, ymax) -# -# -# def set_xlims(*multiple_axes, xlim=None): -# if xlim is None: -# raise ValueError("Please set xlim. get_global_xlim() might be useful.") -# -# for axes in multiple_axes: -# # axes -# if isinstance( -# axes, (np.ndarray, scitex.plt._subplots.AxesWrapper) -# ): -# for ax in axes.flat: -# ax.set_xlim(xlim) -# # axis -# else: -# ax = axes -# ax.set_xlim(xlim) -# -# # Return -# if len(multiple_axes) == 1: -# return multiple_axes[0], xlim -# else: -# return multiple_axes, xlim -# -# -# def set_ylims(*multiple_axes, ylim=None): -# if ylim is None: -# raise ValueError("Please set ylim. get_global_xlim() might be useful.") -# -# for axes in multiple_axes: -# # axes -# if isinstance( -# axes, (np.ndarray, scitex.plt._subplots.AxesWrapper) -# ): -# for ax in axes.flat: -# ax.set_ylim(ylim) -# -# # axis -# else: -# ax = axes -# ax.set_ylim(ylim) -# -# # Return -# if len(multiple_axes) == 1: -# return multiple_axes[0], ylim -# else: -# return multiple_axes, ylim -# -# -# def main(): -# pass -# -# -# if __name__ == "__main__": -# # # Argument Parser -# # import argparse -# import sys -# -# # parser = argparse.ArgumentParser(description='') -# # parser.add_argument('--var', '-v', type=int, default=1, help='') -# # parser.add_argument('--flag', '-f', action='store_true', default=False, help='') -# # args = parser.parse_args() -# # Main -# CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( -# sys, plt, verbose=False -# ) -# main() -# scitex.session.close(CONFIG, verbose=False, notify=False) -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_share_axes.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__shift.py b/tests/scitex/plt/ax/_style/test__shift.py deleted file mode 100644 index f566447e7..000000000 --- a/tests/scitex/plt/ax/_style/test__shift.py +++ /dev/null @@ -1,251 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 09:02:43 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/_adjust/test__shift.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/ax/_adjust/test__shift.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pytest - -pytest.importorskip("zarr") -from scitex.plt.ax._style import shift - -matplotlib.use("Agg") - - -class TestMainFunctionality: - def setup_method(self): - # Setup test fixtures - self.fig = plt.figure(figsize=(6, 4)) - self.ax = self.fig.add_subplot(111) - - def teardown_method(self): - # Clean up after tests - plt.close(self.fig) - - def test_basic_functionality(self): - # Basic test case - original_bbox = self.ax.get_position() - - # Shift 1 inch (2.54 cm) right and up - shifted_ax = shift(self.ax, dx=2.54, dy=2.54) - - # Get new position - new_bbox = shifted_ax.get_position() - - # Calculate expected change in position - fig_width_in, fig_height_in = self.fig.get_size_inches() - expected_dx_ratio = (2.54 / 2.54) / fig_width_in - expected_dy_ratio = (2.54 / 2.54) / fig_height_in - - # Check that the ax was shifted correctly - assert np.isclose(new_bbox.x0, original_bbox.x0 + expected_dx_ratio) - assert np.isclose(new_bbox.y0, original_bbox.y0 + expected_dy_ratio) - - # Check that width and height are unchanged - assert np.isclose(new_bbox.width, original_bbox.width) - assert np.isclose(new_bbox.height, original_bbox.height) - - def test_edge_cases(self): - # Test with zero shift - original_bbox = self.ax.get_position() - shifted_ax = shift(self.ax, dx=0, dy=0) - new_bbox = shifted_ax.get_position() - - assert np.isclose(new_bbox.x0, original_bbox.x0) - assert np.isclose(new_bbox.y0, original_bbox.y0) - - # Test with negative shift - original_bbox = self.ax.get_position() - shifted_ax = shift(self.ax, dx=-1.27, dy=-1.27) - new_bbox = shifted_ax.get_position() - - fig_width_in, fig_height_in = self.fig.get_size_inches() - expected_dx_ratio = (-1.27 / 2.54) / fig_width_in - expected_dy_ratio = (-1.27 / 2.54) / fig_height_in - - assert np.isclose(new_bbox.x0, original_bbox.x0 + expected_dx_ratio) - assert np.isclose(new_bbox.y0, original_bbox.y0 + expected_dy_ratio) - - def test_error_handling(self): - # Test with invalid input types - with pytest.raises(TypeError): - shift(self.ax, dx="invalid", dy=0) - - def test_savefig(self): - from scitex.io import save - - # Main test functionality - original_bbox = self.ax.get_position() - shifted_ax = shift(self.ax, dx=1.27, dy=1.27) - - # Saving - spath = f"./{os.path.basename(__file__)}.jpg" - save(self.fig, spath) - - # Check saved file - ACTUAL_SAVE_DIR = __file__.replace(".py", "_out") - actual_spath = os.path.join(ACTUAL_SAVE_DIR, spath) - assert os.path.exists(actual_spath), f"Failed to save figure to {spath}" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_shift.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-02 09:00:54 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/ax/_style/_shift.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/ax/_style/_shift.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# -# def shift(ax, dx=0, dy=0): -# """ -# Adjusts the position of an Axes object within a Figure by specified offsets in centimeters. -# -# This function modifies the position of a given matplotlib.axes.Axes object by shifting it horizontally and vertically within its parent figure. The shift amounts are specified in centimeters, and the function converts these values into the figure's coordinate system to perform the adjustment. -# -# Parameters: -# - ax (matplotlib.axes.Axes): The Axes object to modify. This must be an instance of a Matplotlib Axes. -# - dx (float): The horizontal offset in centimeters. Positive values shift the Axes to the right, while negative values shift it to the left. -# - dy (float): The vertical offset in centimeters. Positive values shift the Axes up, while negative values shift it down. -# -# Returns: -# - matplotlib.axes.Axes: The modified Axes object with the adjusted position. -# """ -# -# bbox = ax.get_position() -# -# # Convert centimeters to inches for consistency with matplotlib dimensions -# dx_in, dy_in = dx / 2.54, dy / 2.54 -# -# # Calculate delta ratios relative to the figure size -# fig = ax.get_figure() -# fig_dx_in, fig_dy_in = fig.get_size_inches() -# dx_ratio, dy_ratio = dx_in / fig_dx_in, dy_in / fig_dy_in -# -# # Determine updated bbox position and optionally adjust dimensions -# left = bbox.x0 + dx_ratio -# bottom = bbox.y0 + dy_ratio -# width = bbox.width -# height = bbox.height -# -# # Main -# ax.set_position([left, bottom, width, height]) -# -# return ax -# -# -# # def adjust_axes_position_and_dimension( -# # ax, dx, dy, adjust_width_for_dx=False, adjust_height_for_dy=False -# # ): -# -# # def set_pos(ax, x_cm, y_cm, extend_x=False, extend_y=False): -# # """ -# # Adjusts the position of an Axes object within a Figure by a specified offset in centimeters. -# -# # Parameters: -# # - ax (matplotlib.axes.Axes): The Axes object to modify. -# # - x_cm (float): The horizontal offset in centimeters to adjust the Axes position. -# # - y_cm (float): The vertical offset in centimeters to adjust the Axes position. -# # - extend_x (bool): If True, reduces the width of the Axes by the horizontal offset. -# # - extend_y (bool): If True, reduces the height of the Axes by the vertical offset. -# -# # Returns: -# # - ax (matplotlib.axes.Axes): The modified Axes object with the adjusted position. -# # """ -# -# # bbox = ax.get_position() -# -# # # Inches -# # x_in, y_in = x_cm / 2.54, y_cm / 2.54 -# -# # # Calculates delta ratios -# # fig = ax.get_figure() -# # fig_x_in, fig_y_in = fig.get_size_inches() -# # x_ratio, y_ratio = x_in / fig_x_in, y_in / fig_y_in -# -# # # Determines updated bbox position -# # left = bbox.x0 + x_ratio -# # bottom = bbox.y0 + y_ratio -# # width = bbox.width -# # height = bbox.height -# -# # if extend_x: -# # width -= x_ratio -# -# # if extend_y: -# # height -= y_ratio -# -# # ax.set_position([left, bottom, width, height]) -# -# # return ax -# -# -# # def set_pos( -# # fig, -# # ax, -# # x_cm, -# # y_cm, -# # dragh=False, -# # dragv=False, -# # ): -# -# # bbox = ax.get_position() -# -# # ## Calculates delta ratios -# # fig_x_in, fig_y_in = fig.get_size_inches() -# -# # x_in = float(x_cm) / 2.54 -# # y_in = float(y_cm) / 2.54 -# -# # x_ratio = x_in / fig_x_in -# # y_ratio = y_in / fig_x_in -# -# # ## Determines updated bbox position -# # left = bbox.x0 + x_ratio -# # bottom = bbox.y0 + y_ratio -# # width = bbox.x1 - bbox.x0 -# # height = bbox.y1 - bbox.y0 -# -# # if dragh: -# # width -= x_ratio -# -# # if dragv: -# # height -= y_ratio -# -# # ax.set_pos( -# # [ -# # left, -# # bottom, -# # width, -# # height, -# # ] -# # ) -# -# # return ax -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_shift.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__show_spines.py b/tests/scitex/plt/ax/_style/test__show_spines.py deleted file mode 100644 index 3413c827c..000000000 --- a/tests/scitex/plt/ax/_style/test__show_spines.py +++ /dev/null @@ -1,912 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2025-06-05 07:45:00 (ywatanabe)" -# File: ./tests/scitex/plt/ax/_style/test__show_spines.py - -""" -Functionality: - Comprehensive tests for _show_spines module -Input: - Various matplotlib axes configurations and spine parameters -Output: - Test results validating spine visibility control functionality -Prerequisites: - pytest, matplotlib, scitex -""" - -from unittest.mock import Mock, patch - -import matplotlib.axes -import matplotlib.pyplot as plt -import numpy as np -import pytest - -# Import the functions to test -import scitex - - -class TestShowSpines: - """Test the main show_spines function.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fig, self.ax = plt.subplots() - # Start with all spines hidden (common scitex default) - for spine in self.ax.spines.values(): - spine.set_visible(False) - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_show_all_spines_default(self): - """Test showing all spines with default parameters.""" - result = scitex.plt.ax.show_spines(self.ax) - - assert result is self.ax - assert self.ax.spines["top"].get_visible() - assert self.ax.spines["bottom"].get_visible() - assert self.ax.spines["left"].get_visible() - assert self.ax.spines["right"].get_visible() - - def test_show_selective_spines(self): - """Test showing only specific spines.""" - scitex.plt.ax.show_spines( - self.ax, top=False, right=False, bottom=True, left=True - ) - - assert not self.ax.spines["top"].get_visible() - assert self.ax.spines["bottom"].get_visible() - assert self.ax.spines["left"].get_visible() - assert not self.ax.spines["right"].get_visible() - - def test_spine_width_setting(self): - """Test setting custom spine width.""" - width = 2.5 - scitex.plt.ax.show_spines(self.ax, spine_width=width) - - for spine in self.ax.spines.values(): - assert spine.get_linewidth() == width - - def test_spine_color_setting(self): - """Test setting custom spine color.""" - color = "red" - scitex.plt.ax.show_spines(self.ax, spine_color=color) - - # matplotlib converts colors to RGBA tuples - expected_rgba = (1.0, 0.0, 0.0, 1.0) # red in RGBA - for spine in self.ax.spines.values(): - assert spine.get_edgecolor() == expected_rgba - - def test_combined_styling(self): - """Test combining width and color settings.""" - width, color = 1.8, "blue" - scitex.plt.ax.show_spines(self.ax, spine_width=width, spine_color=color) - - expected_rgba = (0.0, 0.0, 1.0, 1.0) # blue in RGBA - for spine in self.ax.spines.values(): - assert spine.get_linewidth() == width - assert spine.get_edgecolor() == expected_rgba - - def test_tick_positioning_bottom_only(self): - """Test tick positioning when only bottom spine is shown.""" - scitex.plt.ax.show_spines( - self.ax, top=False, bottom=True, left=False, right=False - ) - - # Should position ticks on bottom only - assert self.ax.xaxis.get_ticks_position() == "bottom" - - def test_tick_positioning_top_only(self): - """Test tick positioning when only top spine is shown.""" - scitex.plt.ax.show_spines( - self.ax, top=True, bottom=False, left=False, right=False - ) - - assert self.ax.xaxis.get_ticks_position() == "top" - - def test_tick_positioning_both_horizontal(self): - """Test tick positioning when both horizontal spines are shown.""" - scitex.plt.ax.show_spines( - self.ax, top=True, bottom=True, left=False, right=False - ) - - # When both spines are shown, matplotlib might use 'default' instead of 'both' - tick_pos = self.ax.xaxis.get_ticks_position() - assert tick_pos in ["both", "default"] - - def test_tick_positioning_left_only(self): - """Test tick positioning when only left spine is shown.""" - scitex.plt.ax.show_spines( - self.ax, top=False, bottom=False, left=True, right=False - ) - - assert self.ax.yaxis.get_ticks_position() == "left" - - def test_tick_positioning_right_only(self): - """Test tick positioning when only right spine is shown.""" - scitex.plt.ax.show_spines( - self.ax, top=False, bottom=False, left=False, right=True - ) - - assert self.ax.yaxis.get_ticks_position() == "right" - - def test_tick_positioning_both_vertical(self): - """Test tick positioning when both vertical spines are shown.""" - scitex.plt.ax.show_spines( - self.ax, top=False, bottom=False, left=True, right=True - ) - - # When both spines are shown, matplotlib might use 'default' instead of 'both' - tick_pos = self.ax.yaxis.get_ticks_position() - assert tick_pos in ["both", "default"] - - def test_ticks_disabled(self): - """Test behavior when ticks are disabled.""" - original_x_pos = self.ax.xaxis.get_ticks_position() - original_y_pos = self.ax.yaxis.get_ticks_position() - - scitex.plt.ax.show_spines(self.ax, ticks=False) - - # Tick positions should not be modified when ticks=False - assert self.ax.xaxis.get_ticks_position() == original_x_pos - assert self.ax.yaxis.get_ticks_position() == original_y_pos - - def test_restore_defaults_disabled(self): - """Test behavior when restore_defaults is disabled.""" - scitex.plt.ax.show_spines(self.ax, restore_defaults=False) - - # Should still show spines but not modify tick settings - assert all(spine.get_visible() for spine in self.ax.spines.values()) - - def test_labels_functionality(self): - """Test label restoration functionality.""" - # Set some data to generate ticks - self.ax.plot([1, 2, 3], [1, 4, 2]) - - scitex.plt.ax.show_spines(self.ax, labels=True) - - # Should have tick labels - xticks = self.ax.get_xticks() - yticks = self.ax.get_yticks() - assert len(xticks) > 0 - assert len(yticks) > 0 - - -class TestScitexAxisWrapperCompatibility: - """Test compatibility with scitex AxisWrapper objects.""" - - def setup_method(self): - """Set up test fixtures with mock AxisWrapper.""" - self.fig, self.ax = plt.subplots() - - # Create a mock AxisWrapper that has _axis_mpl attribute - self.mock_wrapper = Mock() - self.mock_wrapper._axis_mpl = self.ax - self.mock_wrapper.__class__.__name__ = "AxisWrapper" - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_axis_wrapper_handling(self): - """Test that function works with scitex AxisWrapper objects.""" - result = scitex.plt.ax.show_spines(self.mock_wrapper) - - # Should return the underlying matplotlib axis - assert result is self.ax - # All spines should be visible - assert all(spine.get_visible() for spine in self.ax.spines.values()) - - def test_invalid_axis_type(self): - """Test error handling for invalid axis types.""" - with pytest.raises( - AssertionError, match="First argument must be a matplotlib axis" - ): - scitex.plt.ax.show_spines("not_an_axis") - - def test_none_axis(self): - """Test error handling for None axis.""" - with pytest.raises( - AssertionError, match="First argument must be a matplotlib axis" - ): - scitex.plt.ax.show_spines(None) - - -class TestShowAllSpines: - """Test the show_all_spines convenience function.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fig, self.ax = plt.subplots() - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_show_all_spines_basic(self): - """Test basic show_all_spines functionality.""" - result = scitex.plt.ax.show_all_spines(self.ax) - - assert result is self.ax - assert all(spine.get_visible() for spine in self.ax.spines.values()) - - def test_show_all_spines_with_styling(self): - """Test show_all_spines with styling parameters.""" - width, color = 2.0, "green" - scitex.plt.ax.show_all_spines(self.ax, spine_width=width, spine_color=color) - - expected_rgba = (0.0, 0.5019607843137255, 0.0, 1.0) # green in RGBA - for spine in self.ax.spines.values(): - assert spine.get_visible() - assert spine.get_linewidth() == width - assert spine.get_edgecolor() == expected_rgba - - def test_show_all_spines_no_ticks(self): - """Test show_all_spines without ticks.""" - scitex.plt.ax.show_all_spines(self.ax, ticks=False) - - assert all(spine.get_visible() for spine in self.ax.spines.values()) - - def test_show_all_spines_no_labels(self): - """Test show_all_spines without labels.""" - scitex.plt.ax.show_all_spines(self.ax, labels=False) - - assert all(spine.get_visible() for spine in self.ax.spines.values()) - - -class TestShowClassicSpines: - """Test the show_classic_spines function (scientific plot style).""" - - def setup_method(self): - """Set up test fixtures.""" - self.fig, self.ax = plt.subplots() - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_classic_spines_pattern(self): - """Test that classic spines shows only bottom and left.""" - scitex.plt.ax.show_classic_spines(self.ax) - - assert not self.ax.spines["top"].get_visible() - assert self.ax.spines["bottom"].get_visible() - assert self.ax.spines["left"].get_visible() - assert not self.ax.spines["right"].get_visible() - - def test_classic_spines_with_styling(self): - """Test classic spines with custom styling.""" - width, color = 1.5, "black" - scitex.plt.ax.show_classic_spines(self.ax, spine_width=width, spine_color=color) - - expected_rgba = (0.0, 0.0, 0.0, 1.0) # black in RGBA - # Only bottom and left should be styled and visible - assert self.ax.spines["bottom"].get_visible() - assert self.ax.spines["left"].get_visible() - assert self.ax.spines["bottom"].get_linewidth() == width - assert self.ax.spines["left"].get_linewidth() == width - assert self.ax.spines["bottom"].get_edgecolor() == expected_rgba - assert self.ax.spines["left"].get_edgecolor() == expected_rgba - - def test_scientific_spines_alias(self): - """Test that scientific_spines is an alias for show_classic_spines.""" - scitex.plt.ax.scientific_spines(self.ax) - - assert not self.ax.spines["top"].get_visible() - assert self.ax.spines["bottom"].get_visible() - assert self.ax.spines["left"].get_visible() - assert not self.ax.spines["right"].get_visible() - - -class TestShowBoxSpines: - """Test the show_box_spines function.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fig, self.ax = plt.subplots() - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_box_spines_all_visible(self): - """Test that box spines shows all four spines.""" - scitex.plt.ax.show_box_spines(self.ax) - - assert all(spine.get_visible() for spine in self.ax.spines.values()) - - def test_box_spines_with_styling(self): - """Test box spines with styling.""" - width, color = 1.0, "purple" - scitex.plt.ax.show_box_spines(self.ax, spine_width=width, spine_color=color) - - expected_rgba = ( - 0.5019607843137255, - 0.0, - 0.5019607843137255, - 1.0, - ) # purple in RGBA - for spine in self.ax.spines.values(): - assert spine.get_visible() - assert spine.get_linewidth() == width - assert spine.get_edgecolor() == expected_rgba - - -class TestToggleSpines: - """Test the toggle_spines function.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fig, self.ax = plt.subplots() - # Set initial known state - self.ax.spines["top"].set_visible(True) - self.ax.spines["bottom"].set_visible(False) - self.ax.spines["left"].set_visible(True) - self.ax.spines["right"].set_visible(False) - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_toggle_all_spines(self): - """Test toggling all spines (None parameters).""" - initial_states = { - name: spine.get_visible() for name, spine in self.ax.spines.items() - } - - scitex.plt.ax.toggle_spines(self.ax) - - for name, spine in self.ax.spines.items(): - assert spine.get_visible() == (not initial_states[name]) - - def test_toggle_specific_spines(self): - """Test setting specific spine states.""" - scitex.plt.ax.toggle_spines(self.ax, top=False, bottom=True) - - assert not self.ax.spines["top"].get_visible() - assert self.ax.spines["bottom"].get_visible() - # Left and right should be toggled from initial state - assert not self.ax.spines["left"].get_visible() # was True, now False - assert self.ax.spines["right"].get_visible() # was False, now True - - def test_toggle_mixed_parameters(self): - """Test mixing explicit and toggle parameters.""" - scitex.plt.ax.toggle_spines(self.ax, top=True, right=False) - - assert self.ax.spines["top"].get_visible() # explicitly set to True - assert not self.ax.spines["right"].get_visible() # explicitly set to False - # Bottom and left should be toggled - assert self.ax.spines["bottom"].get_visible() # was False, now True - assert not self.ax.spines["left"].get_visible() # was True, now False - - -class TestCleanSpines: - """Test the clean_spines function (no spines shown).""" - - def setup_method(self): - """Set up test fixtures.""" - self.fig, self.ax = plt.subplots() - # Start with all spines visible - for spine in self.ax.spines.values(): - spine.set_visible(True) - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_clean_spines_hides_all(self): - """Test that clean_spines hides all spines.""" - scitex.plt.ax.clean_spines(self.ax) - - assert all(not spine.get_visible() for spine in self.ax.spines.values()) - - def test_clean_spines_with_ticks_labels(self): - """Test clean_spines with tick and label options.""" - scitex.plt.ax.clean_spines(self.ax, ticks=True, labels=True) - - # All spines should be hidden regardless of tick/label settings - assert all(not spine.get_visible() for spine in self.ax.spines.values()) - - -class TestEdgeCases: - """Test edge cases and error conditions.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fig, self.ax = plt.subplots() - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_empty_axis_data(self): - """Test behavior with axis that has no data.""" - # Should work without errors even with empty axis - result = scitex.plt.ax.show_spines(self.ax) - assert result is self.ax - - def test_axis_with_data(self): - """Test behavior with axis containing data.""" - x = np.linspace(0, 10, 100) - y = np.sin(x) - self.ax.plot(x, y) - - result = scitex.plt.ax.show_spines(self.ax) - assert result is self.ax - assert all(spine.get_visible() for spine in self.ax.spines.values()) - - def test_negative_spine_width(self): - """Test behavior with negative spine width.""" - # Matplotlib should handle this gracefully - scitex.plt.ax.show_spines(self.ax, spine_width=-1.0) - - for spine in self.ax.spines.values(): - assert spine.get_linewidth() == -1.0 # matplotlib allows negative widths - - def test_zero_spine_width(self): - """Test behavior with zero spine width.""" - scitex.plt.ax.show_spines(self.ax, spine_width=0.0) - - for spine in self.ax.spines.values(): - assert spine.get_linewidth() == 0.0 - - def test_invalid_color_format(self): - """Test behavior with invalid color format.""" - # This should raise a matplotlib error - with pytest.raises((ValueError, TypeError)): - scitex.plt.ax.show_spines(self.ax, spine_color="invalid_color_name") - - def test_none_width_and_color(self): - """Test that None values don't change existing properties.""" - # Set initial properties - initial_width = self.ax.spines["bottom"].get_linewidth() - initial_color = self.ax.spines["bottom"].get_edgecolor() - - scitex.plt.ax.show_spines(self.ax, spine_width=None, spine_color=None) - - # Properties should remain unchanged - assert self.ax.spines["bottom"].get_linewidth() == initial_width - assert self.ax.spines["bottom"].get_edgecolor() == initial_color - - -class TestIntegration: - """Integration tests with realistic usage patterns.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fig, self.ax = plt.subplots() - - def teardown_method(self): - """Clean up test fixtures.""" - plt.close(self.fig) - - def test_scientific_plot_workflow(self): - """Test typical scientific plotting workflow.""" - # Generate sample data - x = np.linspace(0, 2 * np.pi, 100) - y = np.sin(x) - self.ax.plot(x, y) - - # Apply scientific styling - scitex.plt.ax.show_classic_spines(self.ax, spine_width=1.2, spine_color="black") - - # Verify the result - assert not self.ax.spines["top"].get_visible() - assert self.ax.spines["bottom"].get_visible() - assert self.ax.spines["left"].get_visible() - assert not self.ax.spines["right"].get_visible() - - # Check styling - expected_rgba = (0.0, 0.0, 0.0, 1.0) # black in RGBA - assert self.ax.spines["bottom"].get_linewidth() == 1.2 - assert self.ax.spines["left"].get_linewidth() == 1.2 - assert self.ax.spines["bottom"].get_edgecolor() == expected_rgba - assert self.ax.spines["left"].get_edgecolor() == expected_rgba - - def test_overlay_plot_workflow(self): - """Test workflow for overlay plots with clean spines.""" - # Create base plot - x = np.linspace(0, 10, 50) - y = np.exp(-x / 3) - self.ax.plot(x, y) - - # Apply clean styling for overlay - scitex.plt.ax.clean_spines(self.ax, ticks=False, labels=False) - - # Verify clean appearance - assert all(not spine.get_visible() for spine in self.ax.spines.values()) - - def test_publication_ready_workflow(self): - """Test workflow for publication-ready figures.""" - # Create sample data - categories = ["A", "B", "C", "D"] - values = [23, 45, 56, 78] - self.ax.bar(categories, values) - - # Apply publication styling - scitex.plt.ax.show_box_spines( - self.ax, spine_width=0.8, spine_color="#333333", ticks=True, labels=True - ) - - # Verify box appearance - expected_rgba = (0.2, 0.2, 0.2, 1.0) # #333333 in RGBA - assert all(spine.get_visible() for spine in self.ax.spines.values()) - for spine in self.ax.spines.values(): - assert spine.get_linewidth() == 0.8 - assert spine.get_edgecolor() == expected_rgba - - def test_toggle_workflow(self): - """Test interactive toggle workflow.""" - # Start with default state - initial_states = { - name: spine.get_visible() for name, spine in self.ax.spines.items() - } - - # Toggle spines multiple times - scitex.plt.ax.toggle_spines(self.ax) - first_toggle = { - name: spine.get_visible() for name, spine in self.ax.spines.items() - } - - scitex.plt.ax.toggle_spines(self.ax) - second_toggle = { - name: spine.get_visible() for name, spine in self.ax.spines.items() - } - - # Should return to initial state after double toggle - assert initial_states == second_toggle - - # First toggle should be opposite of initial - for name in initial_states: - assert first_toggle[name] == (not initial_states[name]) - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_show_spines.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Time-stamp: "2025-06-04 11:15:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_show_spines.py -# -# """ -# Functionality: -# Show spines for matplotlib axes with intuitive API -# Input: -# Matplotlib axes object and spine visibility parameters -# Output: -# Axes with specified spines made visible -# Prerequisites: -# matplotlib -# """ -# -# import matplotlib -# from typing import Union, List -# -# -# def show_spines( -# axis, -# top: bool = True, -# bottom: bool = True, -# left: bool = True, -# right: bool = True, -# ticks: bool = True, -# labels: bool = True, -# restore_defaults: bool = True, -# spine_width: float = None, -# spine_color: str = None, -# ): -# """ -# Shows the specified spines of a matplotlib Axes object and optionally restores ticks and labels. -# -# This function provides the intuitive counterpart to hide_spines. It's especially useful when -# you have spines hidden by default (as in scitex configuration) and want to selectively show them -# for clearer scientific plots or specific visualization needs. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes -# The Axes object for which the spines will be shown. -# top : bool, optional -# If True, shows the top spine. Defaults to True. -# bottom : bool, optional -# If True, shows the bottom spine. Defaults to True. -# left : bool, optional -# If True, shows the left spine. Defaults to True. -# right : bool, optional -# If True, shows the right spine. Defaults to True. -# ticks : bool, optional -# If True, restores ticks on the shown spines' axes. Defaults to True. -# labels : bool, optional -# If True, restores labels on the shown spines' axes. Defaults to True. -# restore_defaults : bool, optional -# If True, restores default tick positions and labels. Defaults to True. -# spine_width : float, optional -# Width of the spines to show. If None, uses matplotlib default. -# spine_color : str, optional -# Color of the spines to show. If None, uses matplotlib default. -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified Axes object with the specified spines shown. -# -# Examples -# -------- -# >>> fig, ax = plt.subplots() -# >>> # Show only bottom and left spines (classic scientific plot style) -# >>> show_spines(ax, top=False, right=False) -# >>> plt.show() -# -# >>> # Show all spines with custom styling -# >>> show_spines(ax, spine_width=1.5, spine_color='black') -# >>> plt.show() -# -# >>> # Show spines but without ticks/labels (for clean overlay plots) -# >>> show_spines(ax, ticks=False, labels=False) -# >>> plt.show() -# -# Notes -# ----- -# This function is designed to work seamlessly with scitex plotting where spines are hidden -# by default. It provides an intuitive API for showing spines without needing to remember -# that hide_spines(top=False, right=False) shows top and right spines. -# """ -# # Handle both matplotlib axes and scitex AxisWrapper -# if hasattr(axis, "_axis_mpl"): -# # This is an scitex AxisWrapper, get the underlying matplotlib axis -# axis = axis._axis_mpl -# -# assert isinstance(axis, matplotlib.axes._axes.Axes), ( -# "First argument must be a matplotlib axis or scitex AxisWrapper" -# ) -# -# # Define which spines to show -# spine_settings = {"top": top, "bottom": bottom, "left": left, "right": right} -# -# for spine_name, should_show in spine_settings.items(): -# # Set spine visibility -# axis.spines[spine_name].set_visible(should_show) -# -# if should_show: -# # Set spine width if specified -# if spine_width is not None: -# axis.spines[spine_name].set_linewidth(spine_width) -# -# # Set spine color if specified -# if spine_color is not None: -# axis.spines[spine_name].set_color(spine_color) -# -# # Restore ticks if requested -# if ticks and restore_defaults: -# # Determine tick positions based on which spines are shown -# if bottom and not top: -# axis.xaxis.set_ticks_position("bottom") -# elif top and not bottom: -# axis.xaxis.set_ticks_position("top") -# elif bottom and top: -# axis.xaxis.set_ticks_position("both") -# -# if left and not right: -# axis.yaxis.set_ticks_position("left") -# elif right and not left: -# axis.yaxis.set_ticks_position("right") -# elif left and right: -# axis.yaxis.set_ticks_position("both") -# -# # Restore labels if requested and restore_defaults is True -# if labels and restore_defaults: -# # Only restore if we haven't explicitly hidden them -# # This preserves any custom tick labels that might have been set -# current_xticks = axis.get_xticks() -# current_yticks = axis.get_yticks() -# -# if len(current_xticks) > 0 and (bottom or top): -# # Generate default labels for x-axis -# if not hasattr(axis, "_original_xticklabels"): -# axis.set_xticks(current_xticks) -# -# if len(current_yticks) > 0 and (left or right): -# # Generate default labels for y-axis -# if not hasattr(axis, "_original_yticklabels"): -# axis.set_yticks(current_yticks) -# -# return axis -# -# -# def show_all_spines( -# axis, -# spine_width: float = None, -# spine_color: str = None, -# ticks: bool = True, -# labels: bool = True, -# ): -# """ -# Convenience function to show all spines with optional styling. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes -# The Axes object to modify. -# spine_width : float, optional -# Width of all spines. -# spine_color : str, optional -# Color of all spines. -# ticks : bool, optional -# Whether to show ticks. Defaults to True. -# labels : bool, optional -# Whether to show labels. Defaults to True. -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified Axes object. -# -# Examples -# -------- -# >>> show_all_spines(ax, spine_width=1.2, spine_color='gray') -# """ -# return show_spines( -# axis, -# top=True, -# bottom=True, -# left=True, -# right=True, -# ticks=ticks, -# labels=labels, -# spine_width=spine_width, -# spine_color=spine_color, -# ) -# -# -# def show_classic_spines( -# axis, -# spine_width: float = None, -# spine_color: str = None, -# ticks: bool = True, -# labels: bool = True, -# ): -# """ -# Show only bottom and left spines (classic scientific plot style). -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes -# The Axes object to modify. -# spine_width : float, optional -# Width of the spines. -# spine_color : str, optional -# Color of the spines. -# ticks : bool, optional -# Whether to show ticks. Defaults to True. -# labels : bool, optional -# Whether to show labels. Defaults to True. -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified Axes object. -# -# Examples -# -------- -# >>> show_classic_spines(ax) # Shows only bottom and left spines -# """ -# return show_spines( -# axis, -# top=False, -# bottom=True, -# left=True, -# right=False, -# ticks=ticks, -# labels=labels, -# spine_width=spine_width, -# spine_color=spine_color, -# ) -# -# -# def show_box_spines( -# axis, -# spine_width: float = None, -# spine_color: str = None, -# ticks: bool = True, -# labels: bool = True, -# ): -# """ -# Show all four spines to create a box around the plot. -# -# This is an alias for show_all_spines but with more descriptive naming -# for when you specifically want a boxed appearance. -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes -# The Axes object to modify. -# spine_width : float, optional -# Width of the box spines. -# spine_color : str, optional -# Color of the box spines. -# ticks : bool, optional -# Whether to show ticks. Defaults to True. -# labels : bool, optional -# Whether to show labels. Defaults to True. -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified Axes object. -# -# Examples -# -------- -# >>> show_box_spines(ax, spine_width=1.0, spine_color='black') -# """ -# return show_all_spines(axis, spine_width, spine_color, ticks, labels) -# -# -# def toggle_spines( -# axis, top: bool = None, bottom: bool = None, left: bool = None, right: bool = None -# ): -# """ -# Toggle the visibility of spines (show if hidden, hide if shown). -# -# Parameters -# ---------- -# axis : matplotlib.axes.Axes -# The Axes object to modify. -# top : bool, optional -# If specified, sets top spine visibility. If None, toggles current state. -# bottom : bool, optional -# If specified, sets bottom spine visibility. If None, toggles current state. -# left : bool, optional -# If specified, sets left spine visibility. If None, toggles current state. -# right : bool, optional -# If specified, sets right spine visibility. If None, toggles current state. -# -# Returns -# ------- -# matplotlib.axes.Axes -# The modified Axes object. -# -# Examples -# -------- -# >>> toggle_spines(ax) # Toggles all spines -# >>> toggle_spines(ax, top=True, right=True) # Shows top and right, toggles others -# """ -# spine_names = ["top", "bottom", "left", "right"] -# spine_params = [top, bottom, left, right] -# -# for spine_name, param in zip(spine_names, spine_params): -# if param is None: -# # Toggle current state -# current_state = axis.spines[spine_name].get_visible() -# axis.spines[spine_name].set_visible(not current_state) -# else: -# # Set specific state -# axis.spines[spine_name].set_visible(param) -# -# return axis -# -# -# # Convenient aliases for common use cases -# def scientific_spines(axis, **kwargs): -# """Alias for show_classic_spines - shows only bottom and left spines.""" -# return show_classic_spines(axis, **kwargs) -# -# -# def clean_spines(axis, **kwargs): -# """Alias for showing no spines - useful for overlay plots or clean visualizations.""" -# return show_spines(axis, top=False, bottom=False, left=False, right=False, **kwargs) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_show_spines.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__style_barplot.py b/tests/scitex/plt/ax/_style/test__style_barplot.py deleted file mode 100644 index 6f4190045..000000000 --- a/tests/scitex/plt/ax/_style/test__style_barplot.py +++ /dev/null @@ -1,85 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_barplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_style_barplot.py -# -# """ -# Style bar plot elements with millimeter-based control. -# -# Default values are loaded from SCITEX_STYLE.yaml via presets.py. -# """ -# -# from typing import Optional, Union, List -# -# from scitex.plt.styles.presets import SCITEX_STYLE -# -# # Get defaults from centralized config -# _DEFAULT_EDGE_THICKNESS_MM = SCITEX_STYLE.get("bar_edge_thickness_mm", 0.2) -# -# -# def style_barplot( -# bar_container, -# edge_thickness_mm: float = None, -# edgecolor: Optional[Union[str, List[str]]] = "black", -# ): -# """ -# Apply consistent styling to matplotlib bar plot elements. -# -# Parameters -# ---------- -# bar_container : BarContainer -# Container returned by ax.bar() or ax.barh() -# edge_thickness_mm : float, optional -# Edge line thickness in millimeters (default: 0.2mm) -# edgecolor : str or list of str, optional -# Edge color(s) for bars. If None, uses default matplotlib colors. -# -# Returns -# ------- -# bar_container : BarContainer -# The styled bar container -# -# Examples -# -------- -# >>> fig, ax = stx.plt.subplots(**stx.plt.presets.NATURE_STYLE) -# >>> bars = ax.bar(x, heights) -# >>> stx.plt.ax.style_barplot(bars, edge_thickness_mm=0.2, edgecolor='black') -# """ -# from scitex.plt.utils import mm_to_pt -# -# # Use centralized default if not specified -# if edge_thickness_mm is None: -# edge_thickness_mm = _DEFAULT_EDGE_THICKNESS_MM -# -# # Convert mm to points -# lw_pt = mm_to_pt(edge_thickness_mm) -# -# # Style each bar -# for i, bar in enumerate(bar_container): -# bar.set_linewidth(lw_pt) -# if edgecolor is not None: -# if isinstance(edgecolor, list): -# bar.set_edgecolor(edgecolor[i % len(edgecolor)]) -# else: -# bar.set_edgecolor(edgecolor) -# -# return bar_container -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_barplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__style_boxplot.py b/tests/scitex/plt/ax/_style/test__style_boxplot.py deleted file mode 100644 index 218b86eab..000000000 --- a/tests/scitex/plt/ax/_style/test__style_boxplot.py +++ /dev/null @@ -1,171 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_boxplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_style_boxplot.py -# -# """ -# Style boxplot elements with millimeter-based control. -# -# Default values are loaded from SCITEX_STYLE.yaml via presets.py. -# """ -# -# from typing import Dict, Optional -# import matplotlib.pyplot as plt -# -# from scitex.plt.styles.presets import SCITEX_STYLE -# -# # Get defaults from centralized config -# _DEFAULT_LINEWIDTH_MM = SCITEX_STYLE.get("trace_thickness_mm", 0.2) -# _DEFAULT_FLIER_SIZE_MM = SCITEX_STYLE.get("marker_size_mm", 0.8) -# -# -# def style_boxplot( -# boxplot_dict, -# linewidth_mm: float = None, -# flier_size_mm: float = None, -# median_color: str = "black", -# edge_color: str = "black", -# colors: Optional[list] = None, -# add_legend: bool = False, -# labels: Optional[list] = None, -# ): -# """Apply publication-quality styling to matplotlib boxplot elements. -# -# This function modifies boxplots to: -# - Set consistent line widths for all elements -# - Set median line to black for visibility -# - Set edge colors to black -# - Apply consistent outlier marker styling -# - Use scitex color palette by default for box fills -# -# Parameters -# ---------- -# boxplot_dict : dict -# Dictionary returned by ax.boxplot(). -# linewidth_mm : float, default 0.2 -# Line width in millimeters for all elements. -# flier_size_mm : float, default 0.8 -# Outlier (flier) marker size in millimeters. -# median_color : str, default "black" -# Color for the median line inside boxes. -# edge_color : str, default "black" -# Color for box edges, whiskers, and caps. -# colors : list, optional -# List of colors for each box fill. If None, uses scitex color palette. -# add_legend : bool, default False -# Whether to add a legend. -# labels : list, optional -# Labels for legend entries (required if add_legend=True). -# -# Returns -# ------- -# boxplot_dict : dict -# The styled boxplot dictionary. -# -# Examples -# -------- -# >>> import scitex as stx -# >>> import numpy as np -# >>> fig, ax = stx.plt.subplots() -# >>> box_data = [np.random.normal(0, 1, 100) for _ in range(4)] -# >>> bp = ax.boxplot(box_data, patch_artist=True) -# >>> stx.plt.ax.style_boxplot(bp, median_color="black") -# """ -# from scitex.plt.utils import mm_to_pt -# from scitex.plt.color._PARAMS import HEX -# -# # Use centralized defaults if not specified -# if linewidth_mm is None: -# linewidth_mm = _DEFAULT_LINEWIDTH_MM -# if flier_size_mm is None: -# flier_size_mm = _DEFAULT_FLIER_SIZE_MM -# -# # Convert mm to points -# lw_pt = mm_to_pt(linewidth_mm) -# flier_size_pt = mm_to_pt(flier_size_mm) -# -# # Use scitex color palette by default -# if colors is None: -# colors = [ -# HEX["blue"], -# HEX["red"], -# HEX["green"], -# HEX["yellow"], -# HEX["purple"], -# HEX["orange"], -# HEX["lightblue"], -# HEX["pink"], -# ] -# -# # Style box elements with line width -# for element_name in ["boxes", "whiskers", "caps"]: -# if element_name in boxplot_dict: -# for element in boxplot_dict[element_name]: -# element.set_linewidth(lw_pt) -# element.set_color(edge_color) -# -# # Style medians with specified color -# if "medians" in boxplot_dict: -# for median in boxplot_dict["medians"]: -# median.set_linewidth(lw_pt) -# median.set_color(median_color) -# -# # Style fliers (outliers) with marker size -# if "fliers" in boxplot_dict: -# for flier in boxplot_dict["fliers"]: -# flier.set_markersize(flier_size_pt) -# flier.set_markeredgewidth(lw_pt) -# flier.set_markeredgecolor(edge_color) -# flier.set_markerfacecolor("none") # Open circles -# -# # Apply fill colors to boxes -# for i, box in enumerate(boxplot_dict.get("boxes", [])): -# color = colors[i % len(colors)] -# if hasattr(box, "set_facecolor"): -# box.set_facecolor(color) -# box.set_edgecolor(edge_color) -# -# # Add legend if requested -# if add_legend and labels is not None: -# # Create proxy artists for legend -# import matplotlib.patches as mpatches -# -# if colors is not None: -# legend_elements = [ -# mpatches.Patch( -# facecolor="none", edgecolor=color, linewidth=lw_pt, label=label -# ) -# for color, label in zip(colors, labels) -# ] -# else: -# legend_elements = [ -# mpatches.Patch( -# facecolor="none", edgecolor="C0", linewidth=lw_pt, label=label -# ) -# for label in labels -# ] -# # Get the axes from one of the box elements -# if boxplot_dict.get("boxes"): -# ax = boxplot_dict["boxes"][0].axes -# ax.legend(handles=legend_elements) -# -# return boxplot_dict -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_boxplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__style_errorbar.py b/tests/scitex/plt/ax/_style/test__style_errorbar.py deleted file mode 100644 index b93175e66..000000000 --- a/tests/scitex/plt/ax/_style/test__style_errorbar.py +++ /dev/null @@ -1,98 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_errorbar.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_style_errorbar.py -# -# """ -# Style error bar elements with millimeter-based control. -# -# Default values are loaded from SCITEX_STYLE.yaml via presets.py. -# """ -# -# from typing import Optional -# -# from scitex.plt.styles.presets import SCITEX_STYLE -# -# # Get defaults from centralized config -# _DEFAULT_THICKNESS_MM = SCITEX_STYLE.get("trace_thickness_mm", 0.2) -# _DEFAULT_CAP_WIDTH_MM = SCITEX_STYLE.get("errorbar_cap_width_mm", 0.8) -# -# -# def style_errorbar( -# errorbar_container, -# thickness_mm: float = None, -# cap_width_mm: float = None, -# ): -# """ -# Apply consistent styling to matplotlib errorbar elements. -# -# Parameters -# ---------- -# errorbar_container : ErrorbarContainer -# Container returned by ax.errorbar() -# thickness_mm : float, optional -# Line thickness for error bars in millimeters (default: 0.2mm) -# cap_width_mm : float, optional -# Cap width in millimeters (default: 0.8mm) -# -# Returns -# ------- -# errorbar_container : ErrorbarContainer -# The styled errorbar container -# -# Examples -# -------- -# >>> fig, ax = stx.plt.subplots(**stx.plt.presets.NATURE_STYLE) -# >>> eb = ax.errorbar(x, y, yerr=yerr) -# >>> stx.plt.ax.style_errorbar(eb, thickness_mm=0.2, cap_width_mm=0.8) -# """ -# from scitex.plt.utils import mm_to_pt -# -# # Use centralized defaults if not specified -# if thickness_mm is None: -# thickness_mm = _DEFAULT_THICKNESS_MM -# if cap_width_mm is None: -# cap_width_mm = _DEFAULT_CAP_WIDTH_MM -# -# # Convert mm to points -# lw_pt = mm_to_pt(thickness_mm) -# cap_width_pt = mm_to_pt(cap_width_mm) -# -# # Style the data line -# if errorbar_container[0] is not None: -# errorbar_container[0].set_linewidth(lw_pt) -# -# # Style the error bar lines -# if len(errorbar_container) > 2 and errorbar_container[2] is not None: -# for line_collection in errorbar_container[2]: -# if line_collection is not None: -# line_collection.set_linewidth(lw_pt) -# -# # Style the caps -# if len(errorbar_container) > 1 and errorbar_container[1] is not None: -# for cap in errorbar_container[1]: -# if cap is not None: -# cap.set_linewidth(lw_pt) # Cap line thickness same as error bar -# # Set cap marker size (width) -# cap.set_markersize(cap_width_pt) -# -# return errorbar_container -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_errorbar.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__style_scatter.py b/tests/scitex/plt/ax/_style/test__style_scatter.py deleted file mode 100644 index 8b8d279a0..000000000 --- a/tests/scitex/plt/ax/_style/test__style_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_scatter.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_style_scatter.py -# -# """ -# Style scatter plot elements with millimeter-based control. -# -# Default values are loaded from SCITEX_STYLE.yaml via presets.py. -# """ -# -# from typing import Optional -# -# from scitex.plt.styles.presets import SCITEX_STYLE -# -# # Get defaults from centralized config -# _DEFAULT_SIZE_MM = SCITEX_STYLE.get("scatter_size_mm", 0.8) -# _DEFAULT_EDGE_THICKNESS_MM = SCITEX_STYLE.get("marker_edge_width_mm", 0.0) -# -# -# def style_scatter( -# path_collection, -# size_mm: float = None, -# edge_thickness_mm: float = None, -# ): -# """ -# Apply consistent styling to matplotlib scatter plot elements. -# -# Parameters -# ---------- -# path_collection : PathCollection -# Collection returned by ax.scatter() -# size_mm : float, optional -# Marker size in millimeters (default: 0.8mm) -# edge_thickness_mm : float, optional -# Edge line thickness in millimeters (default: 0.0mm = no border) -# -# Returns -# ------- -# path_collection : PathCollection -# The styled path collection -# -# Examples -# -------- -# >>> fig, ax = stx.plt.subplots(**stx.plt.presets.NATURE_STYLE) -# >>> scatter = ax.scatter(x, y) -# >>> stx.ax.style_scatter(scatter, size_mm=0.8) -# -# Notes -# ----- -# Matplotlib scatter uses marker size in points squared. -# We convert mm to points, then square for the area. -# By default, no border is applied (edge_thickness_mm=0). -# """ -# from scitex.plt.utils import mm_to_pt -# -# # Use centralized defaults if not specified -# if size_mm is None: -# size_mm = _DEFAULT_SIZE_MM -# if edge_thickness_mm is None: -# edge_thickness_mm = _DEFAULT_EDGE_THICKNESS_MM -# -# # Convert mm to points -# size_pt = mm_to_pt(size_mm) -# -# # Matplotlib scatter uses area (points^2) -# # For a marker of diameter d, area = (d/2)^2 * pi -# # But matplotlib's 's' parameter is already area-like -# # So we use size_pt^2 to get the right visual size -# marker_area = size_pt**2 -# -# # Set marker size -# path_collection.set_sizes([marker_area]) -# -# # Set edge thickness (0 by default = no border) -# edge_width_pt = mm_to_pt(edge_thickness_mm) -# path_collection.set_linewidths(edge_width_pt) -# -# return path_collection -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_scatter.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__style_suptitles.py b/tests/scitex/plt/ax/_style/test__style_suptitles.py deleted file mode 100644 index c7da5c994..000000000 --- a/tests/scitex/plt/ax/_style/test__style_suptitles.py +++ /dev/null @@ -1,92 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_suptitles.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-11-19 15:20:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_style_suptitles.py -# -# """ -# Style figure-level titles and labels with proper font sizes. -# """ -# -# from typing import Optional -# -# -# def style_suptitles( -# fig, -# suptitle_font_size_pt: float = 7, -# font_family: str = "DejaVu Sans", -# ): -# """ -# Apply consistent styling to figure-level titles and labels. -# -# Parameters -# ---------- -# fig : matplotlib.figure.Figure or FigWrapper -# The figure to style -# suptitle_font_size_pt : float, optional -# Font size in points for suptitle, supxlabel, supylabel (default: 7) -# font_family : str, optional -# Font family to use (default: "DejaVu Sans") -# -# Returns -# ------- -# fig : matplotlib.figure.Figure or FigWrapper -# The styled figure -# -# Examples -# -------- -# >>> fig, axes = stx.plt.subplots(2, 2, **stx.plt.presets.NATURE_STYLE) -# >>> fig.suptitle("Main Title") -# >>> fig.supxlabel("X Axis Label") -# >>> fig.supylabel("Y Axis Label") -# >>> stx.ax.style_suptitles(fig) -# -# Notes -# ----- -# This function applies font styling to: -# - fig.suptitle() - Main figure title -# - fig.supxlabel() - Figure-level X axis label -# - fig.supylabel() - Figure-level Y axis label -# -# All are set to the same font size (default 7pt for publication). -# """ -# # Unwrap FigWrapper if needed -# if hasattr(fig, "_fig_mpl"): -# fig_mpl = fig._fig_mpl -# else: -# fig_mpl = fig -# -# # Style suptitle -# if fig_mpl._suptitle is not None: -# fig_mpl._suptitle.set_fontsize(suptitle_font_size_pt) -# fig_mpl._suptitle.set_fontfamily(font_family) -# -# # Style supxlabel (if it exists) -# if hasattr(fig_mpl, "_supxlabel") and fig_mpl._supxlabel is not None: -# fig_mpl._supxlabel.set_fontsize(suptitle_font_size_pt) -# fig_mpl._supxlabel.set_fontfamily(font_family) -# -# # Style supylabel (if it exists) -# if hasattr(fig_mpl, "_supylabel") and fig_mpl._supylabel is not None: -# fig_mpl._supylabel.set_fontsize(suptitle_font_size_pt) -# fig_mpl._supylabel.set_fontfamily(font_family) -# -# return fig -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_suptitles.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/_style/test__style_violinplot.py b/tests/scitex/plt/ax/_style/test__style_violinplot.py deleted file mode 100644 index a6bfca16b..000000000 --- a/tests/scitex/plt/ax/_style/test__style_violinplot.py +++ /dev/null @@ -1,131 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_violinplot.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 20:00:00 (ywatanabe)" -# # File: ./src/scitex/plt/ax/_style/_style_violinplot.py -# -# """Style violin plot elements with millimeter-based control. -# -# Default values are loaded from SCITEX_STYLE.yaml via presets.py. -# """ -# -# from typing import Optional, Union -# -# from matplotlib.axes import Axes -# -# from scitex.plt.styles.presets import SCITEX_STYLE -# -# # Get defaults from centralized config -# _DEFAULT_LINEWIDTH_MM = SCITEX_STYLE.get("trace_thickness_mm", 0.2) -# -# -# def style_violinplot( -# ax: Union[Axes, "AxisWrapper"], -# linewidth_mm: float = None, -# edge_color: str = "black", -# median_color: str = "black", -# remove_caps: bool = True, -# ) -> Union[Axes, "AxisWrapper"]: -# """Apply publication-quality styling to seaborn violin plots. -# -# This function modifies violin plots created by seaborn.violinplot() to: -# - Add borders to the KDE (violin body) edges -# - Remove caps from the internal boxplot whiskers -# - Set median line to black for better visibility -# - Apply consistent line widths -# -# Parameters -# ---------- -# ax : matplotlib.axes.Axes or AxisWrapper -# The axes containing the violin plot. -# linewidth_mm : float, default 0.2 -# Line width in millimeters for violin edges and boxplot elements. -# edge_color : str, default "black" -# Color for the violin body edges. -# median_color : str, default "black" -# Color for the median line inside the boxplot. -# remove_caps : bool, default True -# Whether to remove the caps (horizontal lines) from boxplot whiskers. -# -# Returns -# ------- -# ax : matplotlib.axes.Axes or AxisWrapper -# The axes with styled violin plot. -# -# Examples -# -------- -# >>> import seaborn as sns -# >>> import scitex as stx -# >>> fig, ax = stx.plt.subplots() -# >>> sns.violinplot(data=df, x="group", y="value", ax=ax) -# >>> stx.plt.ax.style_violinplot(ax) -# """ -# from scitex.plt.utils import mm_to_pt -# -# # Use centralized default if not specified -# if linewidth_mm is None: -# linewidth_mm = _DEFAULT_LINEWIDTH_MM -# -# lw_pt = mm_to_pt(linewidth_mm) -# -# # Style violin bodies (PolyCollection) -# for collection in ax.collections: -# # Check if it's a violin body (PolyCollection with filled area) -# if hasattr(collection, "set_edgecolor"): -# collection.set_edgecolor(edge_color) -# collection.set_linewidth(lw_pt) -# -# # Style internal boxplot elements (Line2D objects) -# # Seaborn violin plot lines: whiskers (vertical), caps (horizontal), median (short horizontal) -# lines = list(ax.lines) -# n_violins = len( -# [ -# c -# for c in ax.collections -# if hasattr(c, "get_paths") and len(c.get_paths()) > 0 -# ] -# ) -# -# for line in lines: -# # Get line data to identify element type -# xdata = line.get_xdata() -# ydata = line.get_ydata() -# -# if len(ydata) != 2: -# continue -# -# # Caps are horizontal lines (same y-value for both points) with wider x-span -# is_horizontal = ydata[0] == ydata[1] -# x_span = abs(xdata[1] - xdata[0]) if len(xdata) == 2 else 0 -# -# if is_horizontal: -# if remove_caps and x_span > 0.05: -# # This is likely a cap (wider horizontal line at whisker ends) -# line.set_visible(False) -# else: -# # This is likely a median line (short horizontal line) -# line.set_color(median_color) -# line.set_linewidth(lw_pt) -# else: -# # Vertical lines (whiskers) -# line.set_linewidth(lw_pt) -# -# return ax -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/ax/_style/_style_violinplot.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/ax/conftest_enhanced.py b/tests/scitex/plt/ax/conftest_enhanced.py deleted file mode 100644 index c03f4d996..000000000 --- a/tests/scitex/plt/ax/conftest_enhanced.py +++ /dev/null @@ -1,516 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-06-09 21:00:00 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/ax/conftest_enhanced.py -# ---------------------------------------- -""" -Enhanced pytest fixtures for scitex.plt.ax module testing. - -This file provides comprehensive fixtures for testing plotting functions, -including sample data, figure management, performance monitoring, and -integration helpers. -""" - -import os -import shutil -import tempfile -import time -import tracemalloc -from contextlib import contextmanager -from pathlib import Path -from unittest.mock import MagicMock, patch - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import pytest -from hypothesis import strategies as st - -# ---------------------------------------- -# Configuration -# ---------------------------------------- - -# Set non-interactive backend for testing -matplotlib.use("Agg") - -# Ensure reproducibility -np.random.seed(42) - - -# ---------------------------------------- -# Data Generation Fixtures -# ---------------------------------------- - - -@pytest.fixture -def sample_1d_data(): - """Provide various 1D data arrays for testing.""" - return { - "simple": np.array([1, 2, 3, 4, 5]), - "large": np.random.randn(1000), - "periodic": np.sin(np.linspace(0, 4 * np.pi, 100)), - "noisy": np.random.randn(100) + np.linspace(0, 10, 100), - "categorical": np.array(["A", "B", "C", "A", "B", "C"]), - "with_nans": np.array([1, 2, np.nan, 4, 5]), - "with_infs": np.array([1, np.inf, 3, -np.inf, 5]), - "empty": np.array([]), - "single": np.array([42]), - "binary": np.array([0, 1, 0, 1, 1, 0]), - "sorted": np.arange(50), - "reversed": np.arange(50)[::-1], - } - - -@pytest.fixture -def sample_2d_data(): - """Provide various 2D data arrays for testing.""" - return { - "small": np.array([[1, 2], [3, 4]]), - "medium": np.random.randn(10, 10), - "large": np.random.randn(100, 100), - "correlation": np.corrcoef(np.random.randn(5, 20)), - "confusion_matrix": np.array([[85, 15], [10, 90]]), - "heatmap": np.random.rand(20, 30), - "image": np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8), - "sparse": np.zeros((50, 50)), - "diagonal": np.diag(np.arange(10)), - "symmetric": lambda: (lambda x: x + x.T)(np.random.randn(10, 10)), - "with_pattern": np.outer( - np.sin(np.linspace(0, np.pi, 50)), np.cos(np.linspace(0, np.pi, 50)) - ), - } - - -@pytest.fixture -def sample_3d_data(): - """Provide various 3D data arrays for testing.""" - return { - "simple": np.random.randn(5, 10, 15), - "image_stack": np.random.randint(0, 255, (10, 64, 64), dtype=np.uint8), - "time_series": np.random.randn(100, 5, 3), # time x channels x features - "volume": np.random.rand(20, 20, 20), - } - - -@pytest.fixture -def sample_time_series(): - """Provide time series data for testing.""" - n_points = 1000 - t = np.linspace(0, 10, n_points) - - return { - "time": t, - "sine": np.sin(2 * np.pi * t), - "cosine": np.cos(2 * np.pi * t), - "noisy_sine": np.sin(2 * np.pi * t) + 0.1 * np.random.randn(n_points), - "multi_freq": ( - np.sin(2 * np.pi * t) - + 0.5 * np.sin(10 * np.pi * t) - + 0.2 * np.sin(50 * np.pi * t) - ), - "trend": t + 0.5 * np.random.randn(n_points), - "seasonal": np.sin(2 * np.pi * t) + 0.1 * t, - "multiple": np.column_stack( - [np.sin(2 * np.pi * t), np.cos(2 * np.pi * t), np.sin(4 * np.pi * t)] - ), - } - - -@pytest.fixture -def sample_statistical_data(): - """Provide statistical data for testing.""" - n_samples = 100 - - return { - "normal": np.random.normal(0, 1, n_samples), - "uniform": np.random.uniform(-1, 1, n_samples), - "exponential": np.random.exponential(1, n_samples), - "bimodal": np.concatenate( - [ - np.random.normal(-2, 0.5, n_samples // 2), - np.random.normal(2, 0.5, n_samples // 2), - ] - ), - "outliers": np.concatenate( - [ - np.random.normal(0, 1, int(n_samples * 0.95)), - np.random.normal(0, 10, int(n_samples * 0.05)), - ] - ), - "groups": { - "A": np.random.normal(0, 1, n_samples), - "B": np.random.normal(1, 1.5, n_samples), - "C": np.random.normal(-0.5, 0.8, n_samples), - }, - "paired": np.column_stack( - [np.random.normal(0, 1, n_samples), np.random.normal(0, 1, n_samples) + 0.5] - ), - } - - -@pytest.fixture -def sample_dataframes(): - """Provide pandas DataFrames for testing.""" - n_rows = 50 - - return { - "simple": pd.DataFrame( - { - "x": np.arange(n_rows), - "y": np.random.randn(n_rows), - } - ), - "multivariate": pd.DataFrame( - { - "A": np.random.randn(n_rows), - "B": np.random.randn(n_rows), - "C": np.random.randn(n_rows), - "category": np.random.choice(["X", "Y", "Z"], n_rows), - } - ), - "time_series": pd.DataFrame( - { - "timestamp": pd.date_range("2024-01-01", periods=n_rows, freq="H"), - "value": np.cumsum(np.random.randn(n_rows)), - "volume": np.random.randint(0, 100, n_rows), - } - ), - "wide": pd.DataFrame( - np.random.randn(20, 100), columns=[f"feature_{i}" for i in range(100)] - ), - } - - -# ---------------------------------------- -# Figure Management Fixtures -# ---------------------------------------- - - -@pytest.fixture -def fig_ax(): - """Create a single figure and axes, cleaned up after test.""" - fig, ax = plt.subplots(figsize=(8, 6)) - yield fig, ax - scitex.plt.close(fig) - - -@pytest.fixture -def multi_axes(): - """Create multiple axes configurations.""" - configs = { - "2x2": plt.subplots(2, 2, figsize=(10, 8)), - "3x1": plt.subplots(3, 1, figsize=(8, 10)), - "1x3": plt.subplots(1, 3, figsize=(12, 4)), - "mixed": plt.subplots(2, 3, figsize=(12, 8)), - } - - yield configs - - # Cleanup - for fig, axes in configs.values(): - scitex.plt.close(fig) - - -@pytest.fixture -def fig_3d(): - """Create a 3D axes for testing.""" - fig = plt.figure(figsize=(8, 6)) - ax = fig.add_subplot(111, projection="3d") - yield fig, ax - scitex.plt.close(fig) - - -@pytest.fixture -def clean_figure(): - """Ensure all figures are closed before and after test.""" - plt.close("all") - yield - plt.close("all") - - -# ---------------------------------------- -# Style and Appearance Fixtures -# ---------------------------------------- - - -@pytest.fixture -def color_palettes(): - """Provide various color palettes for testing.""" - return { - "basic": ["red", "green", "blue"], - "mpl_cycle": plt.rcParams["axes.prop_cycle"].by_key()["color"], - "seaborn": ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"], - "grayscale": ["#000000", "#404040", "#808080", "#bfbfbf", "#ffffff"], - "rainbow": plt.cm.rainbow(np.linspace(0, 1, 7)), - "diverging": plt.cm.RdBu(np.linspace(0, 1, 11)), - "sequential": plt.cm.viridis(np.linspace(0, 1, 9)), - } - - -@pytest.fixture -def line_styles(): - """Provide various line styles for testing.""" - return { - "solid": "-", - "dashed": "--", - "dotted": ":", - "dashdot": "-.", - "custom": (0, (5, 2, 1, 2)), # Custom dash pattern - } - - -@pytest.fixture -def marker_styles(): - """Provide various marker styles for testing.""" - return ["o", "s", "^", "D", "v", "<", ">", "p", "*", "h", "H", "+", "x"] - - -# ---------------------------------------- -# Mock and Patch Fixtures -# ---------------------------------------- - - -@pytest.fixture -def mock_save(): - """Mock the scitex.io.save function.""" - with patch("scitex.io.save") as mock: - mock.return_value = None - yield mock - - -@pytest.fixture -def mock_plt_show(): - """Mock plt.show to prevent display during tests.""" - with patch("matplotlib.pyplot.show") as mock: - yield mock - - -@pytest.fixture -def mock_file_system(tmp_path): - """Create a mock file system with test files.""" - # Create directory structure - (tmp_path / "data").mkdir() - (tmp_path / "figures").mkdir() - (tmp_path / "results").mkdir() - - # Create some test files - (tmp_path / "data" / "test.csv").write_text("x,y\n1,2\n3,4\n") - (tmp_path / "data" / "test.npy").touch() - - yield tmp_path - - -# ---------------------------------------- -# Performance Monitoring Fixtures -# ---------------------------------------- - - -@pytest.fixture -def performance_monitor(): - """Monitor performance metrics during tests.""" - - class PerformanceMonitor: - def __init__(self): - self.metrics = {} - - @contextmanager - def measure(self, name): - start_time = time.time() - tracemalloc.start() - start_memory = tracemalloc.get_traced_memory()[0] - - yield - - current, peak = tracemalloc.get_traced_memory() - tracemalloc.stop() - - self.metrics[name] = { - "duration": time.time() - start_time, - "memory_used": current - start_memory, - "memory_peak": peak, - } - - def get_metrics(self): - return self.metrics - - def assert_performance(self, name, max_duration=None, max_memory=None): - """Assert performance constraints.""" - if name not in self.metrics: - raise ValueError(f"No metrics recorded for '{name}'") - - metrics = self.metrics[name] - if max_duration and metrics["duration"] > max_duration: - raise AssertionError( - f"{name} took {metrics['duration']:.3f}s, " - f"expected < {max_duration}s" - ) - if max_memory and metrics["memory_used"] > max_memory: - raise AssertionError( - f"{name} used {metrics['memory_used'] / 1e6:.1f}MB, " - f"expected < {max_memory / 1e6:.1f}MB" - ) - - return PerformanceMonitor() - - -# ---------------------------------------- -# Hypothesis Strategies -# ---------------------------------------- - - -@pytest.fixture -def hypothesis_strategies(): - """Provide common Hypothesis strategies for property testing.""" - return { - "colors": st.sampled_from(["red", "blue", "green", "black", "#FF0000"]), - "line_widths": st.floats(min_value=0.1, max_value=10.0), - "alpha_values": st.floats(min_value=0.0, max_value=1.0), - "fontsize": st.integers(min_value=6, max_value=24), - "figure_size": st.tuples( - st.integers(min_value=4, max_value=20), - st.integers(min_value=3, max_value=15), - ), - "data_size": st.integers(min_value=10, max_value=1000), - "labels": st.text(min_size=1, max_size=20), - } - - -# ---------------------------------------- -# Assertion Helpers -# ---------------------------------------- - - -@pytest.fixture -def plot_assertions(): - """Provide common assertion helpers for plots.""" - - class PlotAssertions: - @staticmethod - def assert_axes_limits(ax, xlim=None, ylim=None, tolerance=1e-6): - """Assert axes limits are as expected.""" - if xlim: - actual_xlim = ax.get_xlim() - assert abs(actual_xlim[0] - xlim[0]) < tolerance - assert abs(actual_xlim[1] - xlim[1]) < tolerance - if ylim: - actual_ylim = ax.get_ylim() - assert abs(actual_ylim[0] - ylim[0]) < tolerance - assert abs(actual_ylim[1] - ylim[1]) < tolerance - - @staticmethod - def assert_labels(ax, xlabel=None, ylabel=None, title=None): - """Assert axes labels are as expected.""" - if xlabel is not None: - assert ax.get_xlabel() == xlabel - if ylabel is not None: - assert ax.get_ylabel() == ylabel - if title is not None: - assert ax.get_title() == title - - @staticmethod - def assert_legend_exists(ax, n_entries=None): - """Assert legend exists and has expected entries.""" - legend = ax.get_legend() - assert legend is not None - if n_entries is not None: - assert len(legend.get_texts()) == n_entries - - @staticmethod - def assert_n_lines(ax, n_lines): - """Assert number of lines in plot.""" - lines = ax.get_lines() - assert len(lines) == n_lines - - @staticmethod - def assert_colorbar_exists(fig): - """Assert figure has a colorbar.""" - # Check if any axes is a colorbar - for ax in fig.get_axes(): - if hasattr(ax, "colorbar") or ax.__class__.__name__ == "Colorbar": - return True - raise AssertionError("No colorbar found in figure") - - return PlotAssertions() - - -# ---------------------------------------- -# Temporary Directory Management -# ---------------------------------------- - - -@pytest.fixture -def temp_output_dir(tmp_path): - """Create a temporary directory for test outputs.""" - output_dir = tmp_path / "test_output" - output_dir.mkdir() - - yield output_dir - - # Optional: Keep outputs for debugging by setting env var - if not os.environ.get("KEEP_TEST_OUTPUTS"): - shutil.rmtree(output_dir) - - -# ---------------------------------------- -# Integration Helpers -# ---------------------------------------- - - -@pytest.fixture -def scitex_modules(): - """Import and provide access to scitex modules if available.""" - modules = {} - - try: - import scitex - - modules["scitex"] = scitex - except ImportError: - pass - - try: - import scitex.plt - - modules["plt"] = scitex.plt - except ImportError: - pass - - try: - import scitex.io - - modules["io"] = scitex.io - except ImportError: - pass - - return modules - - -# ---------------------------------------- -# Cleanup and Safety -# ---------------------------------------- - - -@pytest.fixture(autouse=True) -def cleanup_matplotlib(): - """Automatically cleanup matplotlib state after each test.""" - yield - plt.close("all") - # Reset any modified rcParams - matplotlib.rcdefaults() - - -if __name__ == "__main__": - print("This is a pytest conftest file and should not be run directly.") - print("Fixtures provided:") - print("- sample_1d_data: Various 1D arrays") - print("- sample_2d_data: Various 2D arrays") - print("- sample_3d_data: Various 3D arrays") - print("- sample_time_series: Time series data") - print("- sample_statistical_data: Statistical distributions") - print("- sample_dataframes: Pandas DataFrames") - print("- fig_ax: Single figure/axes pair") - print("- multi_axes: Multiple axes configurations") - print("- performance_monitor: Performance measurement") - print("- plot_assertions: Common plot assertions") - print("... and many more!") diff --git a/tests/scitex/plt/color/test__colors.py b/tests/scitex/plt/color/test__colors.py deleted file mode 100644 index c274b261e..000000000 --- a/tests/scitex/plt/color/test__colors.py +++ /dev/null @@ -1,442 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-02 17:45:23 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/tests/scitex/plt/color/test__colors.py -# ---------------------------------------- -import os - -__FILE__ = "./tests/scitex/plt/color/test__colors.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -from unittest.mock import patch - -import pytest - -pytest.importorskip("zarr") - -from scitex.plt.color import ( - PARAMS, - bgr2bgra, - bgr2rgb, - bgra2bgr, - bgra2hex, - bgra2rgba, - cycle_color, - cycle_color_bgr, - cycle_color_rgb, - gradiate_color_bgr, - gradiate_color_bgra, - gradiate_color_rgb, - gradiate_color_rgba, - rgb2bgr, - rgb2rgba, - rgba2bgra, - rgba2hex, - rgba2rgb, - str2bgr, - str2bgra, - str2hex, - str2rgb, - str2rgba, - update_alpha, -) - - -def test_str2rgb(): - assert str2rgb("red") == PARAMS["RGB"]["red"] - assert str2rgb("blue") == PARAMS["RGB"]["blue"] - - -def test_str2rgba(): - red_rgb = PARAMS["RGB"]["red"] - expected = [val / 255 for val in red_rgb] - expected.append(1.0) - expected = [round(val, 2) for val in expected] - result = str2rgba("red") - assert result == expected - - -def test_rgb2rgba(): - rgb = [255, 0, 0] - expected = [1.0, 0.0, 0.0, 1.0] - assert rgb2rgba(rgb) == expected - - rgb = [255, 128, 0] - expected = [1.0, 0.5, 0.0, 0.7] - assert rgb2rgba(rgb, alpha=0.7) == expected - - -def test_rgba2rgb(): - rgba = [1.0, 0.5, 0.0, 0.7] - expected = [255.0, 127.5, 0.0] - assert rgba2rgb(rgba) == expected - - -def test_rgba2hex(): - rgba = [255, 128, 0, 0.5] - expected = "#ff8000" + hex(int(0.5 * 255))[2:].zfill(2) - assert rgba2hex(rgba) == expected - - -def test_cycle_color_rgb(): - mock_colors = ["red", "green", "blue"] - - with patch.dict( - PARAMS["RGB"], - {"red": [255, 0, 0], "green": [0, 255, 0], "blue": [0, 0, 255]}, - ): - assert cycle_color_rgb(0, colors=mock_colors) == "red" - assert cycle_color_rgb(1, colors=mock_colors) == "green" - assert cycle_color_rgb(2, colors=mock_colors) == "blue" - assert cycle_color_rgb(3, colors=mock_colors) == "red" - - -def test_gradiate_color_rgb(): - rgb = [255, 0, 0] - result = gradiate_color_rgb(rgb, n=3) - - assert len(result) == 3 - assert result[0][0] > result[1][0] > result[2][0] - - -def test_gradiate_color_rgba(): - rgba = [255, 0, 0, 0.8] - result = gradiate_color_rgba(rgba, n=3) - - assert len(result) == 3 - for color in result: - assert len(color) == 4 - assert color[3] == 0.8 - - -def test_str2bgr(): - rgb = PARAMS["RGB"]["red"] - expected = [rgb[2], rgb[1], rgb[0]] - assert str2bgr("red") == expected - - -def test_str2bgra(): - red_rgb = PARAMS["RGB"]["red"] - rgba = [val / 255 for val in red_rgb] - rgba.append(1.0) - rgba = [round(val, 2) for val in rgba] - expected = [rgba[2], rgba[1], rgba[0], rgba[3]] - result = str2bgra("red") - assert result == expected - - -def test_bgr2bgra(): - bgr = [0, 0, 255] - expected = [0.0, 0.0, 1.0, 1.0] - assert bgr2bgra(bgr) == expected - - -def test_bgra2bgr(): - bgra = [0, 0.5, 1.0, 0.7] - expected = [0.0, 127.5, 255.0] - assert bgra2bgr(bgra) == expected - - -def test_bgra2hex(): - bgra = [0, 128, 255, 0.5] - expected = "#ff8000" + hex(int(0.5 * 255))[2:].zfill(2) - assert bgra2hex(bgra) == expected - - -def test_cycle_color_bgr(): - mock_colors = ["red", "green", "blue"] - - with patch.dict( - PARAMS["RGB"], - {"red": [255, 0, 0], "green": [0, 255, 0], "blue": [0, 0, 255]}, - ): - assert cycle_color_bgr(0, colors=mock_colors) == [0, 0, 255] - assert cycle_color_bgr(1, colors=mock_colors) == [0, 255, 0] - assert cycle_color_bgr(2, colors=mock_colors) == [255, 0, 0] - - -def test_gradiate_color_bgr(): - bgr = [0, 0, 255] - result = gradiate_color_bgr(bgr, n=3) - - assert len(result) == 3 - assert result[0][2] > result[1][2] > result[2][2] - - -def test_gradiate_color_bgra(): - bgra = [0, 0, 255, 0.8] - result = gradiate_color_bgra(bgra, n=3) - - assert len(result) == 3 - for color in result: - assert len(color) == 4 - assert color[3] == 0.8 - - -def test_bgr2rgb(): - bgr = [0, 128, 255] - expected = [255, 128, 0] - assert bgr2rgb(bgr) == expected - - -def test_rgb2bgr(): - rgb = [255, 128, 0] - expected = [0, 128, 255] - assert rgb2bgr(rgb) == expected - - -def test_bgra2rgba(): - bgra = [0, 128, 255, 0.5] - expected = [255, 128, 0, 0.5] - assert bgra2rgba(bgra) == expected - - -def test_rgba2bgra(): - rgba = [255, 128, 0, 0.5] - expected = [0, 128, 255, 0.5] - assert rgba2bgra(rgba) == expected - - -def test_str2hex(): - assert str2hex("red") == PARAMS["HEX"]["red"] - assert str2hex("blue") == PARAMS["HEX"]["blue"] - - -def test_update_alpha(): - rgba = [1.0, 0.5, 0.0, 0.3] - expected = [1.0, 0.5, 0.0, 0.8] - assert update_alpha(rgba, 0.8) == expected - - -def test_cycle_color(): - mock_colors = ["red", "green", "blue"] - - with patch.dict( - PARAMS["RGB"], - {"red": [255, 0, 0], "green": [0, 255, 0], "blue": [0, 0, 255]}, - ): - assert cycle_color(0, colors=mock_colors) == "red" - assert cycle_color(1, colors=mock_colors) == "green" - assert cycle_color(2, colors=mock_colors) == "blue" - assert cycle_color(3, colors=mock_colors) == "red" - - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/color/_colors.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-05-02 12:19:50 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex_repo/src/scitex/plt/colors/_colors.py -# # ---------------------------------------- -# import os -# -# __FILE__ = "./src/scitex/plt/colors/_colors.py" -# __DIR__ = os.path.dirname(__FILE__) -# # ---------------------------------------- -# -# import matplotlib.colors as _colors -# import numpy as np -# -# from scitex.decorators._deprecated import deprecated -# from ._PARAMS import PARAMS -# -# # RGB -# # ------------------------------ -# -# -# def str2rgb(c): -# return PARAMS["RGB"][c] -# -# -# def str2rgba(c, alpha=1.0): -# rgba = rgb2rgba(PARAMS["RGB"][c]) -# rgba[-1] = alpha -# return rgba -# -# -# def rgb2rgba(rgb, alpha=1.0, round=2): -# rgb = np.array(rgb).astype(float) -# rgb /= 255 -# return [*rgb.round(round), alpha] -# -# -# def rgba2rgb(rgba): -# rgba = np.array(rgba).astype(float) -# rgb = (rgba[:3] * 255).clip(0, 255) -# return rgb.round(2).tolist() -# -# -# def rgba2hex(rgba): -# return "#{:02x}{:02x}{:02x}{:02x}".format( -# int(rgba[0]), int(rgba[1]), int(rgba[2]), int(rgba[3] * 255) -# ) -# -# -# def cycle_color_rgb(i_color, colors=None): -# if colors is None: -# colors = list(PARAMS["RGB"].keys()) -# n_colors = len(colors) -# return colors[i_color % n_colors] -# -# -# def gradiate_color_rgb(rgb_or_rgba, n=5): -# # Separate RGB and alpha if present -# if len(rgb_or_rgba) == 4: # RGBA format -# rgb = rgb_or_rgba[:3] -# alpha = rgb_or_rgba[3] -# has_alpha = True -# else: # RGB format -# rgb = rgb_or_rgba -# alpha = None -# has_alpha = False -# -# # Scale RGB values to 0-1 range if they're in 0-255 range -# if any(val > 1 for val in rgb): -# rgb = [val / 255 for val in rgb] -# -# rgb_hsv = _colors.rgb_to_hsv(np.array(rgb)) -# -# gradient = [] -# for step in range(n): -# color_hsv = [ -# rgb_hsv[0], -# rgb_hsv[1], -# rgb_hsv[2] * (1.0 - (step / n)), -# ] -# color_rgb = [int(v * 255) for v in _colors.hsv_to_rgb(color_hsv)] -# -# if has_alpha: -# gradient.append(rgb2rgba(color_rgb, alpha=alpha)) -# else: -# gradient.append(color_rgb) -# -# return gradient -# -# -# def gradiate_color_rgba(rgb_or_rgba, n=5): -# return gradiate_color_rgb(rgb_or_rgba, n) -# -# -# # BGRA -# # ------------------------------ -# def str2bgr(c): -# return rgb2bgr(str2rgb(c)) -# -# -# def str2bgra(c, alpha=1.0): -# return rgba2bgra(str2rgba(c)) -# -# -# def bgr2bgra(bgra, alpha=1.0, round=2): -# return rgb2rgba(bgra, alpha=alpha, round=round) -# -# -# def bgra2bgr(bgra): -# return rgba2rgb(bgra) -# -# -# def bgra2hex(bgra): -# """Convert BGRA color format to hex format.""" -# rgba = bgra2rgba(bgra) -# return rgba2hex(rgba) -# -# -# def cycle_color_bgr(i_color, colors=None): -# rgb_color = str2rgb(cycle_color(i_color, colors=colors)) -# return rgb2bgr(rgb_color) -# -# -# def gradiate_color_bgr(bgr_or_bgra, n=5): -# rgb_or_rgba = ( -# bgr2rgb(bgr_or_bgra) if len(bgr_or_bgra) == 3 else bgra2rgba(bgr_or_bgra) -# ) -# rgb_gradient = gradiate_color_rgb(rgb_or_rgba, n) -# return [ -# rgb2bgr(color) if len(color) == 3 else rgba2bgra(color) -# for color in rgb_gradient -# ] -# -# -# def gradiate_color_bgra(bgra, n=5): -# return gradiate_color_bgr(bgra, n) -# -# -# # Common -# # ------------------------------ -# def bgr2rgb(bgr): -# """Convert BGR color format to RGB format.""" -# return [bgr[2], bgr[1], bgr[0]] -# -# -# def rgb2bgr(rgb): -# """Convert RGB color format to BGR format.""" -# return [rgb[2], rgb[1], rgb[0]] -# -# -# def bgra2rgba(bgra): -# """Convert BGRA color format to RGBA format.""" -# return [bgra[2], bgra[1], bgra[0], bgra[3]] -# -# -# def rgba2bgra(rgba): -# """Convert RGBA color format to BGRA format.""" -# return [rgba[2], rgba[1], rgba[0], rgba[3]] -# -# -# def str2hex(c): -# return PARAMS["HEX"][c] -# -# -# def update_alpha(rgba, alpha): -# rgba_list = list(rgba) -# rgba_list[-1] = alpha -# return rgba_list -# -# -# def cycle_color(i_color, colors=None): -# return cycle_color_rgb(i_color, colors=colors) -# -# -# # Deprecated -# # ------------------------------ -# @deprecated("Use str2rgb instead") -# def to_rgb(c): -# return str2rgb(c) -# -# -# @deprecated("use str2rgba instewad") -# def to_rgba(c, alpha=1.0): -# return str2rgba(c, alpha=alpha) -# -# -# @deprecated("use str2hex instead") -# def to_hex(c): -# return PARAMS["HEX"][c] -# -# -# @deprecated("use gradiate_color_rgb/rgba/bgr/bgra instead") -# def gradiate_color(rgb_or_rgba, n=5): -# return gradiate_color_rgb(rgb_or_rgba, n) -# -# -# if __name__ == "__main__": -# c = "blue" -# print(to_rgb(c)) -# print(to_rgba(c)) -# print(to_hex(c)) -# print(cycle_color(1)) -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/color/_colors.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/styles/test__plot_defaults.py b/tests/scitex/plt/styles/test__plot_defaults.py deleted file mode 100644 index f060b8d75..000000000 --- a/tests/scitex/plt/styles/test__plot_defaults.py +++ /dev/null @@ -1,226 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_plot_defaults.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # -*- coding: utf-8 -*- -# # Timestamp: "2025-12-01 10:00:00 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_plot_defaults.py -# -# """Pre-processing default kwargs for plot methods. -# -# This module centralizes all default styling applied BEFORE matplotlib -# methods are called. Each function modifies kwargs in-place. -# -# Priority: direct kwarg → env var → YAML config → default -# -# Style values use the key format from YAML (e.g., 'lines.trace_mm'). -# Env vars: SCITEX_PLT_LINES_TRACE_MM (prefix + dots→underscores + uppercase) -# """ -# -# from scitex.plt.utils import mm_to_pt -# from scitex.plt.styles.presets import resolve_style_value -# -# # Default alpha for fill regions (0.3 = semi-transparent) -# DEFAULT_FILL_ALPHA = 0.3 -# -# -# # ============================================================================ -# # Style helper function -# # ============================================================================ -# def _get_style_value(key, default, style_dict=None): -# """Get style value with priority: style_dict → active_style → env → yaml → default. -# -# Args: -# key: YAML-style key (e.g., 'lines.trace_mm') -# default: Fallback default value -# style_dict: Optional user-provided style dict (overrides all) -# -# Returns: -# Resolved style value -# """ -# flat_key = _yaml_key_to_flat(key) -# -# # Priority 1: User passed explicit style dict -# if style_dict is not None and flat_key in style_dict: -# return style_dict[flat_key] -# -# # Priority 2: Check active style set via set_style() -# from scitex.plt.styles.presets import _active_style -# -# if _active_style is not None and flat_key in _active_style: -# return _active_style[flat_key] -# -# # Priority 3: Use resolve_style_value for: env → yaml → default -# return resolve_style_value(key, None, default) -# -# -# def _yaml_key_to_flat(key): -# """Convert YAML key to flat SCITEX_STYLE key. -# -# Examples: -# 'lines.trace_mm' -> 'trace_thickness_mm' -# 'markers.size_mm' -> 'marker_size_mm' -# """ -# # Mapping from YAML keys to flat keys used in SCITEX_STYLE -# mapping = { -# "lines.trace_mm": "trace_thickness_mm", -# "lines.errorbar_mm": "errorbar_thickness_mm", -# "lines.errorbar_cap_mm": "errorbar_cap_width_mm", -# "markers.size_mm": "marker_size_mm", -# } -# return mapping.get(key, key) -# -# -# # ============================================================================ -# # Pre-processing functions -# # ============================================================================ -# def apply_plot_defaults(method_name, kwargs, id_value=None, ax=None): -# """Apply default kwargs for a plot method before calling matplotlib. -# -# Args: -# method_name: Name of the matplotlib method being called -# kwargs: Keyword arguments dict (modified in-place) -# id_value: Optional id passed to the method -# ax: The matplotlib axes (for methods needing axis setup) -# -# Returns: -# Modified kwargs dict -# -# Note: -# Priority: direct kwarg → style dict → env var → yaml → default -# Users can pass `style=dict` kwarg to override env/yaml defaults. -# """ -# # Extract optional style dict (removes 'style' key from kwargs) -# style_dict = kwargs.pop("style", None) -# -# # Dispatch to method-specific defaults -# if method_name == "plot": -# _apply_plot_line_defaults(kwargs, id_value, style_dict) -# elif method_name in ("bar", "barh"): -# _apply_bar_defaults(kwargs, style_dict) -# elif method_name == "errorbar": -# _apply_errorbar_defaults(kwargs, style_dict) -# elif method_name in ("fill_between", "fill_betweenx"): -# _apply_fill_defaults(kwargs) -# elif method_name in ("quiver", "streamplot"): -# _apply_vector_field_defaults(method_name, kwargs, ax, style_dict) -# elif method_name == "boxplot": -# _apply_boxplot_defaults(kwargs) -# elif method_name == "violinplot": -# _apply_violinplot_defaults(kwargs) -# -# return kwargs -# -# -# def _apply_plot_line_defaults(kwargs, id_value=None, style_dict=None): -# """Apply defaults for ax.plot() method.""" -# line_width_mm = _get_style_value("lines.trace_mm", 0.2, style_dict) -# -# # Default line width -# if "linewidth" not in kwargs and "lw" not in kwargs: -# kwargs["linewidth"] = mm_to_pt(line_width_mm) -# -# # KDE-specific styling when id contains "kde" -# if id_value and "kde" in str(id_value).lower(): -# if "linestyle" not in kwargs and "ls" not in kwargs: -# kwargs["linestyle"] = "--" -# if "color" not in kwargs and "c" not in kwargs: -# kwargs["color"] = "black" -# -# -# def _apply_bar_defaults(kwargs, style_dict=None): -# """Apply defaults for ax.bar() and ax.barh() methods.""" -# line_width_mm = _get_style_value("lines.trace_mm", 0.2, style_dict) -# -# # Set error bar line thickness -# if "error_kw" not in kwargs: -# kwargs["error_kw"] = {} -# if "elinewidth" not in kwargs.get("error_kw", {}): -# kwargs["error_kw"]["elinewidth"] = mm_to_pt(line_width_mm) -# if "capthick" not in kwargs.get("error_kw", {}): -# kwargs["error_kw"]["capthick"] = mm_to_pt(line_width_mm) -# # Set a temporary capsize that will be adjusted in post-processing -# if "capsize" not in kwargs: -# kwargs["capsize"] = 5 # Placeholder, adjusted later to 33% of bar width -# -# -# def _apply_errorbar_defaults(kwargs, style_dict=None): -# """Apply defaults for ax.errorbar() method.""" -# line_width_mm = _get_style_value("lines.trace_mm", 0.2, style_dict) -# cap_size_mm = _get_style_value("lines.errorbar_cap_mm", 0.8, style_dict) -# -# if "capsize" not in kwargs: -# kwargs["capsize"] = mm_to_pt(cap_size_mm) -# if "capthick" not in kwargs: -# kwargs["capthick"] = mm_to_pt(line_width_mm) -# if "elinewidth" not in kwargs: -# kwargs["elinewidth"] = mm_to_pt(line_width_mm) -# -# -# def _apply_fill_defaults(kwargs): -# """Apply defaults for ax.fill_between() and ax.fill_betweenx() methods.""" -# if "alpha" not in kwargs: -# kwargs["alpha"] = DEFAULT_FILL_ALPHA # Transparent to see overlapping data -# -# -# def _apply_vector_field_defaults(method_name, kwargs, ax, style_dict=None): -# """Apply defaults for ax.quiver() and ax.streamplot() methods.""" -# line_width_mm = _get_style_value("lines.trace_mm", 0.2, style_dict) -# marker_size_mm = _get_style_value("markers.size_mm", 0.8, style_dict) -# -# # Set equal aspect ratio for proper vector display -# if ax is not None: -# ax.set_aspect("equal", adjustable="datalim") -# -# if method_name == "streamplot": -# if "arrowsize" not in kwargs: -# # arrowsize is a scaling factor; scale relative to default -# kwargs["arrowsize"] = mm_to_pt(marker_size_mm) / 3 -# if "linewidth" not in kwargs: -# kwargs["linewidth"] = mm_to_pt(line_width_mm) -# -# elif method_name == "quiver": -# if "width" not in kwargs: -# kwargs["width"] = 0.003 # Narrow arrow shaft (axes fraction) -# if "headwidth" not in kwargs: -# kwargs["headwidth"] = 3 # Head width relative to shaft -# if "headlength" not in kwargs: -# kwargs["headlength"] = 4 -# if "headaxislength" not in kwargs: -# kwargs["headaxislength"] = 3.5 -# -# -# def _apply_boxplot_defaults(kwargs): -# """Apply defaults for ax.boxplot() method.""" -# # Enable patch_artist for fillable boxes -# if "patch_artist" not in kwargs: -# kwargs["patch_artist"] = True -# -# -# def _apply_violinplot_defaults(kwargs): -# """Apply defaults for ax.violinplot() method.""" -# # Default to showing boxplot overlay (can be disabled with boxplot=False) -# # Store the boxplot setting for post-processing, then remove from kwargs -# # so it doesn't get passed to matplotlib's violinplot -# if "boxplot" not in kwargs: -# kwargs["boxplot"] = True # Default: add boxplot overlay -# -# # Default to hiding extrema (min/max bars) when boxplot is shown -# if "showextrema" not in kwargs: -# kwargs["showextrema"] = False -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_plot_defaults.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/styles/test__plot_postprocess.py b/tests/scitex/plt/styles/test__plot_postprocess.py deleted file mode 100644 index 0a815f3ce..000000000 --- a/tests/scitex/plt/styles/test__plot_postprocess.py +++ /dev/null @@ -1,503 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_plot_postprocess.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # Timestamp: "2026-01-13 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_plot_postprocess.py -# -# """Post-processing styling for plot methods. -# -# This module centralizes all styling applied AFTER matplotlib methods -# are called. Each function modifies the plot result or axes in-place. -# -# All default values are loaded from SCITEX_STYLE.yaml via presets.py. -# Delegates to figrecipe styling functions when available. -# """ -# -# from matplotlib.category import StrCategoryConverter, UnitData -# from matplotlib.ticker import FixedLocator, MaxNLocator -# -# from scitex.plt.styles._postprocess_helpers import ( -# calculate_cap_width_from_bar, -# calculate_cap_width_from_box, -# make_errorbar_one_sided, -# ) -# from scitex.plt.styles.presets import SCITEX_STYLE -# from scitex.plt.utils import mm_to_pt -# -# # ============================================================================ -# # Constants (loaded from centralized SCITEX_STYLE.yaml) -# # ============================================================================ -# DEFAULT_LINE_WIDTH_MM = SCITEX_STYLE.get("trace_thickness_mm", 0.2) -# DEFAULT_MARKER_SIZE_MM = SCITEX_STYLE.get("marker_size_mm", 0.8) -# DEFAULT_N_TICKS = SCITEX_STYLE.get("n_ticks", 4) - 1 # nbins = n_ticks - 1 -# SPINE_ZORDER = 1000 -# -# -# # ============================================================================ -# # Main post-processing function -# # ============================================================================ -# def apply_plot_postprocess(method_name, result, ax, kwargs, args=None): -# """Apply post-processing styling after matplotlib method call. -# -# Args: -# method_name: Name of the matplotlib method that was called -# result: Return value from the matplotlib method -# ax: The matplotlib axes -# kwargs: Original kwargs passed to the method -# args: Original positional args passed to the method (needed for violinplot) -# -# Returns -# ------- -# The result (possibly modified) -# """ -# # Always ensure spines are on top -# _ensure_spines_on_top(ax) -# -# # Apply tick locator for numerical axes -# _apply_tick_locator(ax) -# -# # Method-specific post-processing -# if method_name == "pie" and result is not None: -# _postprocess_pie(result) -# elif method_name == "stem" and result is not None: -# _postprocess_stem(result) -# elif method_name == "violinplot" and result is not None: -# _postprocess_violin(result, ax, kwargs, args) -# elif method_name == "boxplot" and result is not None: -# _postprocess_boxplot(result, ax) -# elif method_name == "scatter" and result is not None: -# _postprocess_scatter(result, kwargs) -# elif method_name == "bar" and result is not None: -# _postprocess_bar(result, ax, kwargs) -# elif method_name == "barh" and result is not None: -# _postprocess_barh(result, ax, kwargs) -# elif method_name == "errorbar" and result is not None: -# _postprocess_errorbar(result) -# elif method_name == "hist" and result is not None: -# _postprocess_hist(result, ax) -# elif method_name == "fill_between" and result is not None: -# _postprocess_fill_between(result, kwargs) -# -# return result -# -# -# # ============================================================================ -# # General post-processing -# # ============================================================================ -# def _ensure_spines_on_top(ax): -# """Ensure axes spines are always drawn in front of plot elements.""" -# try: -# ax.set_axisbelow(False) -# -# # Set very high z-order for spines -# for spine in ax.spines.values(): -# spine.set_zorder(SPINE_ZORDER) -# -# # Set z-order for tick marks -# ax.tick_params(zorder=SPINE_ZORDER) -# -# # Ensure plot patches have lower z-order than spines -# # But preserve intentionally set z-orders (e.g., boxplot in violin) -# for patch in ax.patches: -# current_z = patch.get_zorder() -# # Only lower z-order if it's >= SPINE_ZORDER or is at matplotlib default (1) -# if current_z >= SPINE_ZORDER: -# patch.set_zorder(current_z - SPINE_ZORDER) -# elif current_z == 1: -# # Default matplotlib z-order, lower it -# patch.set_zorder(0.5) -# # Otherwise, preserve the intentionally set z-order -# -# # Set axes patch behind everything -# ax.patch.set_zorder(-1) -# except Exception: -# pass -# -# -# def _apply_tick_locator(ax): -# """Apply MaxNLocator only to numerical (non-categorical) axes. -# -# Target: 3-4 ticks per axis for clean publication figures. -# MaxNLocator's nbins=3 gives approximately 3-4 tick marks. -# min_n_ticks=3 ensures at least 3 ticks (never 2). -# """ -# try: -# -# def is_categorical_axis(axis): -# # Use get_converter() for matplotlib 3.10+ compatibility -# converter = getattr(axis, "get_converter", lambda: axis.converter)() -# if isinstance(converter, StrCategoryConverter): -# return True -# if hasattr(axis, "units") and isinstance(axis.units, UnitData): -# return True -# if isinstance(axis.get_major_locator(), FixedLocator): -# return True -# return False -# -# if not is_categorical_axis(ax.xaxis): -# ax.xaxis.set_major_locator( -# MaxNLocator( -# nbins=DEFAULT_N_TICKS, min_n_ticks=3, integer=False, prune=None -# ) -# ) -# -# if not is_categorical_axis(ax.yaxis): -# ax.yaxis.set_major_locator( -# MaxNLocator( -# nbins=DEFAULT_N_TICKS, min_n_ticks=3, integer=False, prune=None -# ) -# ) -# except Exception: -# pass -# -# -# # ============================================================================ -# # Method-specific post-processing -# # ============================================================================ -# def _postprocess_pie(result): -# """Apply styling for pie charts.""" -# # pie returns (wedges, texts, autotexts) when autopct is used -# if len(result) >= 3: -# autotexts = result[2] -# for autotext in autotexts: -# autotext.set_fontsize(6) # 6pt for inline percentages -# -# -# def _postprocess_stem(result): -# """Apply styling for stem plots.""" -# baseline = result.baseline -# if baseline is not None: -# baseline.set_color("black") -# baseline.set_linestyle("--") -# -# -# def _postprocess_errorbar(result): -# """Apply styling for errorbar plots. -# -# Simplifies the legend to show only a line (no caps/bars). -# """ -# import matplotlib.legend as mlegend -# from matplotlib.container import ErrorbarContainer -# from matplotlib.legend_handler import HandlerErrorbar, HandlerLine2D -# -# # Custom handler that shows only a simple line for errorbar -# class SimpleLineHandler(HandlerErrorbar): -# def create_artists( -# self, -# legend, -# orig_handle, -# xdescent, -# ydescent, -# width, -# height, -# fontsize, -# trans, -# ): -# # Use HandlerLine2D to create just a line -# line_handler = HandlerLine2D() -# # Get the data line from the ErrorbarContainer -# data_line = orig_handle[0] -# if data_line is not None: -# return line_handler.create_artists( -# legend, -# data_line, -# xdescent, -# ydescent, -# width, -# height, -# fontsize, -# trans, -# ) -# return [] -# -# # Register the handler globally for ErrorbarContainer -# mlegend.Legend.update_default_handler_map({ErrorbarContainer: SimpleLineHandler()}) -# -# -# def _postprocess_violin(result, ax, kwargs, args): -# """Apply styling for violin plots with optional boxplot overlay.""" -# # Get scitex palette for coloring -# from scitex.plt.color._PARAMS import HEX -# -# palette = [ -# HEX["blue"], -# HEX["red"], -# HEX["green"], -# HEX["yellow"], -# HEX["purple"], -# HEX["orange"], -# HEX["lightblue"], -# HEX["pink"], -# ] -# -# if "bodies" in result: -# for i, body in enumerate(result["bodies"]): -# body.set_facecolor(palette[i % len(palette)]) -# body.set_edgecolor("black") -# body.set_linewidth(mm_to_pt(DEFAULT_LINE_WIDTH_MM)) -# body.set_alpha(1.0) -# -# # Add boxplot overlay by default (disable with boxplot=False) -# add_boxplot = kwargs.pop("boxplot", True) -# if add_boxplot and args: -# try: -# # Get data from first positional argument -# data = args[0] -# # Get positions if specified, otherwise use default -# positions = kwargs.get("positions", None) -# if positions is None: -# positions = range(1, len(data) + 1) -# -# # Calculate boxplot width dynamically from violin width -# # Get violin width from kwargs or use matplotlib default (0.5) -# violin_widths = kwargs.get("widths", 0.5) -# if hasattr(violin_widths, "__iter__"): -# violin_widths = violin_widths[0] if len(violin_widths) > 0 else 0.5 -# # Boxplot width = 20% of violin width -# boxplot_widths = violin_widths * 0.2 -# -# # Draw boxplot overlay with styling -# line_width = mm_to_pt(DEFAULT_LINE_WIDTH_MM) -# marker_size = mm_to_pt(DEFAULT_MARKER_SIZE_MM) -# -# # Call matplotlib's boxplot directly to avoid recursive post-processing -# # which would override our gray styling with the default blue -# if hasattr(ax, "_axes_mpl"): -# mpl_ax = ax._axes_mpl -# else: -# mpl_ax = ax -# bp = mpl_ax.boxplot( -# data, -# positions=list(positions), -# widths=boxplot_widths, -# patch_artist=True, -# manage_ticks=False, # Don't modify existing ticks -# ) -# -# # Style the boxplot: scitex gray fill with black edges for visibility -# # Set high z-order so boxplot appears on top of violin bodies -# boxplot_zorder = 10 -# for box in bp.get("boxes", []): -# box.set_facecolor(HEX["gray"]) # Scitex gray fill -# box.set_edgecolor("black") -# box.set_alpha(1.0) -# box.set_linewidth(line_width) -# box.set_zorder(boxplot_zorder) -# for median in bp.get("medians", []): -# median.set_color("black") # Black median line -# median.set_linewidth(line_width) # 0.2mm thickness -# median.set_zorder(boxplot_zorder + 1) -# for whisker in bp.get("whiskers", []): -# whisker.set_color("black") -# whisker.set_linewidth(line_width) -# whisker.set_zorder(boxplot_zorder) -# for cap in bp.get("caps", []): -# cap.set_color("black") -# cap.set_linewidth(line_width) -# cap.set_zorder(boxplot_zorder) -# for flier in bp.get("fliers", []): -# flier.set_markerfacecolor("none") # No fill (open circles) -# flier.set_markeredgecolor("black") -# flier.set_markersize(marker_size) # 0.8mm -# flier.set_markeredgewidth(line_width) # 0.2mm -# flier.set_zorder(boxplot_zorder + 2) -# except Exception: -# pass # Silently continue if boxplot overlay fails -# -# -# def _postprocess_boxplot(result, ax): -# """Apply styling for boxplots (standalone, not violin overlay).""" -# # Use the centralized style_boxplot function for consistent styling -# from scitex.plt.ax import style_boxplot -# -# style_boxplot(result) -# -# # Cap width: 33% of box width -# if "caps" in result and "boxes" in result and len(result["boxes"]) > 0: -# try: -# cap_width_pts = calculate_cap_width_from_box(result["boxes"][0], ax) -# for cap in result["caps"]: -# cap.set_markersize(cap_width_pts) -# except Exception: -# pass -# -# -# def _postprocess_scatter(result, kwargs): -# """Apply styling for scatter plots.""" -# # Apply default 0.8mm marker size if 's' not specified -# if "s" not in kwargs: -# size_pt = mm_to_pt(DEFAULT_MARKER_SIZE_MM) -# marker_area = size_pt**2 -# result.set_sizes([marker_area]) -# -# -# def _postprocess_hist(result, ax): -# """Apply styling for histogram plots. -# -# Ensures histogram bars have proper edge color and alpha for visibility. -# Delegates edge styling to figrecipe when available. -# """ -# # Delegate edge styling to figrecipe with fallback -# from scitex.plt.styles._postprocess_helpers import apply_hist_edge_style -# -# apply_hist_edge_style(ax, DEFAULT_LINE_WIDTH_MM) -# -# # Additionally ensure alpha is at least 0.7 for visibility -# if len(result) >= 3: -# patches = result[2] -# if hasattr(patches, "__iter__"): -# for patch_group in patches: -# if hasattr(patch_group, "__iter__"): -# for patch in patch_group: -# if patch.get_alpha() is None or patch.get_alpha() < 0.7: -# patch.set_alpha(1.0) -# else: -# if patch_group.get_alpha() is None or patch_group.get_alpha() < 0.7: -# patch_group.set_alpha(1.0) -# -# -# def _postprocess_fill_between(result, kwargs): -# """Apply styling for fill_between plots. -# -# Ensures shaded regions have proper alpha for visibility. -# """ -# # result is a PolyCollection -# if result is not None: -# # Only set edge if not already specified -# if "edgecolor" not in kwargs and "ec" not in kwargs: -# result.set_edgecolor("none") -# -# # Ensure alpha is reasonable (default 0.3 is common for fill_between) -# if "alpha" not in kwargs: -# result.set_alpha(0.3) -# -# -# def _postprocess_bar(result, ax, kwargs): -# """Apply styling for bar plots with colors and error bars.""" -# # Apply scitex palette only if color not explicitly set -# if "color" not in kwargs and "c" not in kwargs: -# from scitex.plt.color._PARAMS import HEX -# -# palette = [ -# HEX["blue"], -# HEX["red"], -# HEX["green"], -# HEX["yellow"], -# HEX["purple"], -# HEX["orange"], -# HEX["lightblue"], -# HEX["pink"], -# ] -# -# for i, patch in enumerate(result.patches): -# patch.set_facecolor(palette[i % len(palette)]) -# -# # Always apply SCITEX edge styling (black, 0.2mm) - delegate to figrecipe -# from scitex.plt.styles._postprocess_helpers import apply_bar_edge_style -# -# apply_bar_edge_style(ax, DEFAULT_LINE_WIDTH_MM) -# -# if "yerr" not in kwargs or kwargs["yerr"] is None: -# return -# -# try: -# errorbar = result.errorbar -# if errorbar is None: -# return -# -# lines = errorbar.lines -# if not lines or len(lines) < 3: -# return -# -# caplines = lines[1] -# if caplines and len(caplines) >= 2: -# # Hide lower caps (one-sided error bars) -# caplines[0].set_visible(False) -# -# # Adjust cap width to 33% of bar width -# if len(result.patches) > 0: -# cap_width_pts = calculate_cap_width_from_bar( -# result.patches[0], ax, "width" -# ) -# for cap in caplines[1:]: -# cap.set_markersize(cap_width_pts) -# -# # Make error bar lines one-sided -# barlinecols = lines[2] -# make_errorbar_one_sided(barlinecols, "vertical") -# except Exception: -# pass -# -# -# def _postprocess_barh(result, ax, kwargs): -# """Apply styling for horizontal bar plots with colors and error bars.""" -# # Apply scitex palette only if color not explicitly set -# if "color" not in kwargs and "c" not in kwargs: -# from scitex.plt.color._PARAMS import HEX -# -# palette = [ -# HEX["blue"], -# HEX["red"], -# HEX["green"], -# HEX["yellow"], -# HEX["purple"], -# HEX["orange"], -# HEX["lightblue"], -# HEX["pink"], -# ] -# -# for i, patch in enumerate(result.patches): -# patch.set_facecolor(palette[i % len(palette)]) -# -# # Always apply SCITEX edge styling (black, 0.2mm) - delegate to figrecipe -# from scitex.plt.styles._postprocess_helpers import apply_bar_edge_style -# -# apply_bar_edge_style(ax, DEFAULT_LINE_WIDTH_MM) -# -# if "xerr" not in kwargs or kwargs["xerr"] is None: -# return -# -# try: -# errorbar = result.errorbar -# if errorbar is None: -# return -# -# lines = errorbar.lines -# if not lines or len(lines) < 3: -# return -# -# caplines = lines[1] -# if caplines and len(caplines) >= 2: -# # Hide left caps (one-sided error bars) -# caplines[0].set_visible(False) -# -# # Adjust cap width to 33% of bar height -# if len(result.patches) > 0: -# cap_width_pts = calculate_cap_width_from_bar( -# result.patches[0], ax, "height" -# ) -# for cap in caplines[1:]: -# cap.set_markersize(cap_width_pts) -# -# # Make error bar lines one-sided -# barlinecols = lines[2] -# make_errorbar_one_sided(barlinecols, "horizontal") -# except Exception: -# pass -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_plot_postprocess.py -# -------------------------------------------------------------------------------- diff --git a/tests/scitex/plt/styles/test__postprocess_helpers.py b/tests/scitex/plt/styles/test__postprocess_helpers.py deleted file mode 100644 index e1a735ca3..000000000 --- a/tests/scitex/plt/styles/test__postprocess_helpers.py +++ /dev/null @@ -1,174 +0,0 @@ -# Add your tests here - -if __name__ == "__main__": - import os - - import pytest - - pytest.main([os.path.abspath(__file__)]) - -# -------------------------------------------------------------------------------- -# Start of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_postprocess_helpers.py -# -------------------------------------------------------------------------------- -# #!/usr/bin/env python3 -# # Timestamp: "2026-01-13 (ywatanabe)" -# # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_postprocess_helpers.py -# -# """Helper functions for plot post-processing. -# -# Extracted from _plot_postprocess.py to keep modules within line limits. -# Delegates to figrecipe styling functions when available. -# """ -# -# import numpy as np -# -# # Try to import figrecipe styling (delegate when available) -# try: -# from figrecipe.styles._plot_styles import ( -# apply_barplot_style as _fr_apply_barplot_style, -# ) -# from figrecipe.styles._plot_styles import ( -# apply_histogram_style as _fr_apply_histogram_style, -# ) -# -# FIGRECIPE_AVAILABLE = True -# except ImportError: -# FIGRECIPE_AVAILABLE = False -# -# # ============================================================================ -# # Constants -# # ============================================================================ -# CAP_WIDTH_RATIO = 1 / 3 # 33% of bar/box width -# -# -# # ============================================================================ -# # Helper functions -# # ============================================================================ -# def calculate_cap_width_from_box(box, ax): -# """Calculate cap width as 33% of box width in points.""" -# # Get box width from path -# if hasattr(box, "get_path"): -# path = box.get_path() -# vertices = path.vertices -# x_coords = vertices[:, 0] -# box_width_data = x_coords.max() - x_coords.min() -# elif hasattr(box, "get_xdata"): -# x_data = box.get_xdata() -# box_width_data = max(x_data) - min(x_data) -# else: -# box_width_data = 0.5 # Default -# -# return data_width_to_points(box_width_data, ax, "x") * CAP_WIDTH_RATIO -# -# -# def calculate_cap_width_from_bar(patch, ax, dimension): -# """Calculate cap width as 33% of bar width/height in points.""" -# if dimension == "width": -# bar_size = patch.get_width() -# return data_width_to_points(bar_size, ax, "x") * CAP_WIDTH_RATIO -# else: # height -# bar_size = patch.get_height() -# return data_width_to_points(bar_size, ax, "y") * CAP_WIDTH_RATIO -# -# -# def data_width_to_points(data_size, ax, axis="x"): -# """Convert a data-space size to points.""" -# fig = ax.get_figure() -# bbox = ax.get_position() -# -# if axis == "x": -# ax_size_inches = bbox.width * fig.get_figwidth() -# lim = ax.get_xlim() -# else: -# ax_size_inches = bbox.height * fig.get_figheight() -# lim = ax.get_ylim() -# -# data_range = lim[1] - lim[0] -# size_inches = (data_size / data_range) * ax_size_inches -# return size_inches * 72 # 72 points per inch -# -# -# def make_errorbar_one_sided(barlinecols, direction): -# """Make error bar line segments one-sided (outward only).""" -# if not barlinecols or len(barlinecols) == 0: -# return -# -# for lc in barlinecols: -# if not hasattr(lc, "get_segments"): -# continue -# -# segs = lc.get_segments() -# new_segs = [] -# for seg in segs: -# if len(seg) < 2: -# continue -# -# if direction == "vertical": -# # Keep upper half -# bottom_y = min(seg[0][1], seg[1][1]) -# top_y = max(seg[0][1], seg[1][1]) -# mid_y = (bottom_y + top_y) / 2 -# new_seg = np.array([[seg[0][0], mid_y], [seg[0][0], top_y]]) -# else: # horizontal -# # Keep right half -# left_x = min(seg[0][0], seg[1][0]) -# right_x = max(seg[0][0], seg[1][0]) -# mid_x = (left_x + right_x) / 2 -# new_seg = np.array([[mid_x, seg[0][1]], [right_x, seg[0][1]]]) -# -# new_segs.append(new_seg) -# -# if new_segs: -# lc.set_segments(new_segs) -# -# -# def apply_bar_edge_style(ax, line_width_mm): -# """Apply bar edge styling, delegating to figrecipe if available. -# -# Parameters -# ---------- -# ax : matplotlib Axes or AxisWrapper -# The axes containing bar patches. -# line_width_mm : float -# Line width in millimeters. -# """ -# from scitex.plt.utils import mm_to_pt -# -# ax_mpl = getattr(ax, "_axis_mpl", ax) -# -# if FIGRECIPE_AVAILABLE: -# _fr_apply_barplot_style(ax_mpl, {"barplot_edge_mm": line_width_mm}) -# else: -# # Fallback: apply edge styling directly -# from matplotlib.patches import Rectangle -# -# line_width_pt = mm_to_pt(line_width_mm) -# for patch in ax_mpl.patches: -# if isinstance(patch, Rectangle): -# patch.set_edgecolor("black") -# patch.set_linewidth(line_width_pt) -# -# -# def apply_hist_edge_style(ax, line_width_mm): -# """Apply histogram edge styling, delegating to figrecipe if available.""" -# from scitex.plt.utils import mm_to_pt -# -# ax_mpl = getattr(ax, "_axis_mpl", ax) -# -# if FIGRECIPE_AVAILABLE: -# _fr_apply_histogram_style(ax_mpl, {"histogram_edge_mm": line_width_mm}) -# else: -# from matplotlib.patches import Rectangle -# -# line_width_pt = mm_to_pt(line_width_mm) -# for patch in ax_mpl.patches: -# if isinstance(patch, Rectangle): -# patch.set_edgecolor("black") -# patch.set_linewidth(line_width_pt) -# -# -# # EOF - -# -------------------------------------------------------------------------------- -# End of Source Code from: /home/ywatanabe/proj/scitex-code/src/scitex/plt/styles/_postprocess_helpers.py -# -------------------------------------------------------------------------------- From 2a9cceabfa0d2ff2c141f117d0246a1d9df83b9d Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Thu, 19 Feb 2026 09:46:33 +1100 Subject: [PATCH 06/17] =?UTF-8?q?refactor(io):=20Remove=20AxisWrapper=20CS?= =?UTF-8?q?V=20export=20=E2=80=94=20Phase=203=20of=20figrecipe=20migration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace export_as_csv() AxisWrapper calls with figrecipe RecordingFigure proxy: - _figure_utils.py: New _RecordingFigureDataProxy reads from figrecipe recorder - _image_csv.py: Remove dead SigmaPlot export block - _plot_bundle.py: Simplify (CSV from explicit data only) - _plot_scitex.py: Remove AxisWrapper fallback, keep matplotlib line extractor - _save.py: Remove AxisWrapper CSV block - bundle/_mpl_helpers.py: _get_scitex_axes() → stub returning None CSV data still exported for figrecipe RecordingFigure objects via proxy. Co-Authored-By: Claude Sonnet 4.6 --- src/scitex/io/_save.py | 10 +- src/scitex/io/_save_modules/_figure_utils.py | 124 ++++++++----------- src/scitex/io/_save_modules/_image_csv.py | 51 -------- src/scitex/io/_save_modules/_plot_bundle.py | 11 +- src/scitex/io/_save_modules/_plot_scitex.py | 14 +-- src/scitex/io/bundle/_mpl_helpers.py | 20 +-- 6 files changed, 57 insertions(+), 173 deletions(-) diff --git a/src/scitex/io/_save.py b/src/scitex/io/_save.py index a9efe8621..d2e99e09d 100755 --- a/src/scitex/io/_save.py +++ b/src/scitex/io/_save.py @@ -364,7 +364,6 @@ def _save_scitex_bundle( from scitex.io.bundle import from_matplotlib - from ._save_modules._figure_utils import get_figure_with_data if isinstance(obj, matplotlib.figure.Figure): fig = obj @@ -381,14 +380,7 @@ def _save_scitex_bundle( dpi = kwargs.get("dpi", 300) name = kwargs.get("name") or Path(spath).stem - # Extract CSV data from scitex.plt tracking if available - scitex_source = get_figure_with_data(obj) - if csv_df is None and scitex_source is not None: - if hasattr(scitex_source, "export_as_csv"): - try: - csv_df = scitex_source.export_as_csv() - except Exception: - pass + # CSV via AxisWrapper removed (figrecipe migration); csv_df from explicit kwarg only # Delegate to Bundle (single source of truth) # Encoding is built from CSV columns directly for consistency diff --git a/src/scitex/io/_save_modules/_figure_utils.py b/src/scitex/io/_save_modules/_figure_utils.py index 45890b6fe..a775a906d 100755 --- a/src/scitex/io/_save_modules/_figure_utils.py +++ b/src/scitex/io/_save_modules/_figure_utils.py @@ -1,88 +1,64 @@ #!/usr/bin/env python3 -# Timestamp: 2025-12-19 -# File: /home/ywatanabe/proj/scitex-code/src/scitex/io/_save_modules/_figure_utils.py +# File: /home/ywatanabe/proj/scitex-python/src/scitex/io/_save_modules/_figure_utils.py """Utility functions for extracting figure data for CSV export.""" -def get_figure_with_data(obj): - """ - Extract figure or axes object that may contain plotting data for CSV export. +class _RecordingFigureDataProxy: + """Proxy providing export_as_csv() from a figrecipe RecordingFigure.""" + + def __init__(self, fig): + self._fig = fig + + def export_as_csv(self): + """Extract recorded plot data as a flat DataFrame.""" + try: + import pandas as pd + + rec = self._fig._recorder.figure_record + columns = {} + for ax_key, ax_rec in rec.axes.items(): + for call in ax_rec.calls: + for arg in call.args: + # args are dicts: {'name': ..., '_array': ..., ...} + if isinstance(arg, dict): + arr = arg.get("_array") + name = arg.get("name", "val") + else: + arr = getattr(arg, "_array", None) + name = getattr(arg, "name", "val") + if arr is not None: + col_name = f"{ax_key}_{call.id}_{name}" + data = arr.tolist() if hasattr(arr, "tolist") else list(arr) + columns[col_name] = data + + if not columns: + return None + + max_len = max(len(v) for v in columns.values()) + padded = { + k: list(v) + [float("nan")] * (max_len - len(v)) + for k, v in columns.items() + } + return pd.DataFrame(padded) + except Exception: + return None - Parameters - ---------- - obj : various matplotlib objects - Could be Figure, Axes, FigWrapper, AxisWrapper, or other matplotlib objects + +def get_figure_with_data(obj): + """Return a proxy with export_as_csv() if the object has figrecipe recording data. Returns ------- - object or None - Figure or axes object that has export_as_csv methods, or None if not found + _RecordingFigureDataProxy or None """ - import matplotlib.axes - import matplotlib.figure - import matplotlib.pyplot as plt - - # Check if object already has export methods (SciTeX wrapped objects) - if hasattr(obj, "export_as_csv"): - return obj - - # Handle matplotlib Figure objects - if isinstance(obj, matplotlib.figure.Figure): - # Get the current axes that might be wrapped with SciTeX functionality - current_ax = plt.gca() - if hasattr(current_ax, "export_as_csv"): - return current_ax - - # Check all axes in the figure - for ax in obj.axes: - if hasattr(ax, "export_as_csv"): - return ax - - return None - - # Handle matplotlib Axes objects - if isinstance(obj, matplotlib.axes.Axes): - if hasattr(obj, "export_as_csv"): - return obj - return None - - # Handle FigWrapper or similar SciTeX objects - if hasattr(obj, "figure") and hasattr(obj.figure, "axes"): - # Check if the wrapper itself has export methods - if hasattr(obj, "export_as_csv"): - return obj - - # Check the underlying figure's axes - for ax in obj.figure.axes: - if hasattr(ax, "export_as_csv"): - return ax - - return None - - # Handle AxisWrapper or similar SciTeX objects - if hasattr(obj, "_axis_mpl") or hasattr(obj, "_ax"): - if hasattr(obj, "export_as_csv"): - return obj - return None - - # Try to get the current figure and its axes as fallback - try: - current_fig = plt.gcf() - current_ax = plt.gca() - - if hasattr(current_ax, "export_as_csv"): - return current_ax - elif hasattr(current_fig, "export_as_csv"): - return current_fig - - # Check all axes in current figure - for ax in current_fig.axes: - if hasattr(ax, "export_as_csv"): - return ax + # figrecipe RecordingFigure directly + if hasattr(obj, "_recorder") and hasattr(obj._recorder, "figure_record"): + return _RecordingFigureDataProxy(obj) - except: - pass + # figrecipe RecordingFigure via .fig attribute (e.g. bundle objects) + if hasattr(obj, "fig") and hasattr(getattr(obj, "fig", None), "_recorder"): + return _RecordingFigureDataProxy(obj.fig) return None diff --git a/src/scitex/io/_save_modules/_image_csv.py b/src/scitex/io/_save_modules/_image_csv.py index 308a669a1..8fa0ab2b4 100755 --- a/src/scitex/io/_save_modules/_image_csv.py +++ b/src/scitex/io/_save_modules/_image_csv.py @@ -260,7 +260,6 @@ def _export_csv_data( if fig_obj is not None and hasattr(fig_obj, "export_as_csv"): csv_data = fig_obj.export_as_csv() if csv_data is not None and not csv_data.empty: - # Determine CSV path if parent_name.lower() in image_extensions: grandparent_dir = os.path.dirname(parent_dir) csv_dir = os.path.join(grandparent_dir, "csv") @@ -269,35 +268,17 @@ def _export_csv_data( csv_path = os.path.splitext(spath)[0] + ".csv" os.makedirs(os.path.dirname(csv_path), exist_ok=True) - - # Import here to avoid circular import from . import save_csv save_csv(csv_data, csv_path) - # Update metadata with CSV info if collected_metadata is not None: _update_metadata_with_csv(collected_metadata, csv_data, csv_path) - # Handle symlinks for CSV _create_csv_symlinks( csv_path, spath, symlink_from_cwd, symlink_to_path, image_extensions ) - # Also export SigmaPlot format if available - if fig_obj is not None and hasattr(fig_obj, "export_as_csv_for_sigmaplot"): - _export_sigmaplot_csv( - fig_obj, - spath, - parent_name, - parent_dir, - filename_without_ext, - symlink_from_cwd, - symlink_to_path, - image_extensions, - dry_run, - ) - except Exception as e: logger.warning(f"CSV export failed: {e}") @@ -375,38 +356,6 @@ def _create_csv_symlinks( symlink(csv_path, csv_cwd, True, True) -def _export_sigmaplot_csv( - fig_obj, - spath, - parent_name, - parent_dir, - filename_without_ext, - symlink_from_cwd, - symlink_to_path, - image_extensions, - dry_run, -): - """Export SigmaPlot-formatted CSV.""" - sigmaplot_data = fig_obj.export_as_csv_for_sigmaplot() - if sigmaplot_data is not None and not sigmaplot_data.empty: - if parent_name.lower() in image_extensions: - grandparent_dir = os.path.dirname(parent_dir) - csv_dir = os.path.join(grandparent_dir, "csv") - csv_sigmaplot_path = os.path.join( - csv_dir, filename_without_ext + "_for_sigmaplot.csv" - ) - else: - ext = os.path.splitext(spath)[1].lower().replace(".", "") - csv_sigmaplot_path = spath.replace(ext, "csv").replace( - ".csv", "_for_sigmaplot.csv" - ) - - os.makedirs(os.path.dirname(csv_sigmaplot_path), exist_ok=True) - from . import save_csv - - save_csv(sigmaplot_data, csv_sigmaplot_path) - - def _save_metadata_json( spath, collected_metadata, diff --git a/src/scitex/io/_save_modules/_plot_bundle.py b/src/scitex/io/_save_modules/_plot_bundle.py index e886a3432..9168e531b 100755 --- a/src/scitex/io/_save_modules/_plot_bundle.py +++ b/src/scitex/io/_save_modules/_plot_bundle.py @@ -9,8 +9,6 @@ from scitex import logging -from ._figure_utils import get_figure_with_data - logger = logging.getLogger() @@ -77,15 +75,8 @@ def save_plot_bundle(obj, spath, as_zip=False, data=None, layered=True, **kwargs bundle_dir = p if str(p).endswith(".plot") else Path(str(p) + ".plot") temp_dir = None - # Get CSV data from figure if not provided + # CSV via AxisWrapper removed (figrecipe migration); data passed explicitly or None csv_df = data - if csv_df is None: - csv_source = get_figure_with_data(obj) - if csv_source is not None and hasattr(csv_source, "export_as_csv"): - try: - csv_df = csv_source.export_as_csv() - except Exception: - pass from scitex.plt.io import save_layered_plot_bundle diff --git a/src/scitex/io/_save_modules/_plot_scitex.py b/src/scitex/io/_save_modules/_plot_scitex.py index 9e9c207d3..4ac973292 100755 --- a/src/scitex/io/_save_modules/_plot_scitex.py +++ b/src/scitex/io/_save_modules/_plot_scitex.py @@ -12,8 +12,6 @@ import numpy as np -from ._figure_utils import get_figure_with_data - def _create_stx_spec(bundle_type, title, size): """Create a spec dictionary for .stx bundle. @@ -247,16 +245,8 @@ def save_plot_as_scitex(obj, spath, as_zip=True, basename=None, **kwargs): # Get CSV data from figure if not provided csv_df = data if csv_df is None: - # Try SciTeX wrapped objects first - csv_source = get_figure_with_data(obj) - if csv_source is not None and hasattr(csv_source, "export_as_csv"): - try: - csv_df = csv_source.export_as_csv() - except Exception: - pass - # Fall back to extracting from matplotlib lines - if csv_df is None: - csv_df = _extract_data_from_figure(fig) + # CSV via AxisWrapper removed (figrecipe migration); fall back to matplotlib lines + csv_df = _extract_data_from_figure(fig) # Create spec for .stx format fig_width_inch, fig_height_inch = fig.get_size_inches() diff --git a/src/scitex/io/bundle/_mpl_helpers.py b/src/scitex/io/bundle/_mpl_helpers.py index 266b63340..18a3873a3 100755 --- a/src/scitex/io/bundle/_mpl_helpers.py +++ b/src/scitex/io/bundle/_mpl_helpers.py @@ -14,22 +14,7 @@ def _get_scitex_axes(fig: "MplFigure") -> Optional[Any]: - """Find scitex.plt wrapped axes with tracking data. - - Uses the same helper as sio.save to find objects with export_as_csv. - """ - try: - from scitex.io._save_modules._figure_utils import get_figure_with_data - - return get_figure_with_data(fig) - except ImportError: - pass - - # Fallback: check figure axes directly - axes_list = list(fig.axes) if hasattr(fig.axes, "__iter__") else [fig.axes] - for ax in axes_list: - if hasattr(ax, "export_as_csv") and hasattr(ax, "history"): - return ax + """Legacy stub — AxisWrapper CSV export removed in figrecipe migration.""" return None @@ -127,7 +112,8 @@ def validate_encoding_csv_link(encoding: "Encoding", csv_df: "Any") -> list: encoding: Encoding object with trace definitions csv_df: DataFrame with CSV data - Returns: + Returns + ------- List of validation errors (empty if valid) """ errors = [] From 336193a646820d62fa04c55c0c638865ded2c95e Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Thu, 19 Feb 2026 09:56:12 +1100 Subject: [PATCH 07/17] =?UTF-8?q?fix(io):=20Wire=20figrecipe=20recorder=20?= =?UTF-8?q?data=20into=20CSV=20export=20=E2=80=94=20Phase=203=20completion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The RecordingFigureDataProxy now handles both states: - Pre-savefig: reads _array from in-memory recorder args - Post-savefig: reads from figrecipe-saved _data/*.csv files Pass spath through _export_csv_data → get_figure_with_data so the proxy can locate the data directory. Also removed dead _export_sigmaplot_csv(). Result: stx.io.save(fr_fig, "plot.png") now creates plot.csv alongside the figrecipe plot.yaml and plot_data/ directory. Co-Authored-By: Claude Sonnet 4.6 --- src/scitex/io/_save_modules/_figure_utils.py | 66 +++++++++++++++----- src/scitex/io/_save_modules/_image_csv.py | 2 +- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/scitex/io/_save_modules/_figure_utils.py b/src/scitex/io/_save_modules/_figure_utils.py index a775a906d..33dc04ce1 100755 --- a/src/scitex/io/_save_modules/_figure_utils.py +++ b/src/scitex/io/_save_modules/_figure_utils.py @@ -7,11 +7,17 @@ class _RecordingFigureDataProxy: """Proxy providing export_as_csv() from a figrecipe RecordingFigure.""" - def __init__(self, fig): + def __init__(self, fig, spath=None): self._fig = fig + self._spath = spath def export_as_csv(self): - """Extract recorded plot data as a flat DataFrame.""" + """Extract recorded plot data as a flat DataFrame. + + Handles two states: + - Pre-savefig: args have '_array' with in-memory data + - Post-savefig: args reference CSV files in _data/ directory + """ try: import pandas as pd @@ -20,17 +26,47 @@ def export_as_csv(self): for ax_key, ax_rec in rec.axes.items(): for call in ax_rec.calls: for arg in call.args: - # args are dicts: {'name': ..., '_array': ..., ...} - if isinstance(arg, dict): - arr = arg.get("_array") - name = arg.get("name", "val") - else: - arr = getattr(arg, "_array", None) - name = getattr(arg, "name", "val") + arr = ( + arg.get("_array") + if isinstance(arg, dict) + else getattr(arg, "_array", None) + ) + name = ( + arg.get("name") + if isinstance(arg, dict) + else getattr(arg, "name", None) + ) or "val" + data_path = ( + arg.get("data") + if isinstance(arg, dict) + else getattr(arg, "data", None) + ) + + col_name = f"{ax_key}_{call.id}_{name}" + if arr is not None: - col_name = f"{ax_key}_{call.id}_{name}" - data = arr.tolist() if hasattr(arr, "tolist") else list(arr) - columns[col_name] = data + # Pre-savefig: use in-memory array + columns[col_name] = ( + arr.tolist() if hasattr(arr, "tolist") else list(arr) + ) + elif ( + isinstance(data_path, str) + and data_path.endswith(".csv") + and self._spath + ): + # Post-savefig: data was serialized to a CSV file + import os + + stem = os.path.splitext(os.path.basename(self._spath))[0] + data_dir = os.path.join( + os.path.dirname(self._spath), f"{stem}_data" + ) + full_path = os.path.join( + data_dir, os.path.basename(data_path) + ) + if os.path.exists(full_path): + sub_df = pd.read_csv(full_path, header=None) + columns[col_name] = sub_df.iloc[:, 0].tolist() if not columns: return None @@ -45,7 +81,7 @@ def export_as_csv(self): return None -def get_figure_with_data(obj): +def get_figure_with_data(obj, spath=None): """Return a proxy with export_as_csv() if the object has figrecipe recording data. Returns @@ -54,11 +90,11 @@ def get_figure_with_data(obj): """ # figrecipe RecordingFigure directly if hasattr(obj, "_recorder") and hasattr(obj._recorder, "figure_record"): - return _RecordingFigureDataProxy(obj) + return _RecordingFigureDataProxy(obj, spath=spath) # figrecipe RecordingFigure via .fig attribute (e.g. bundle objects) if hasattr(obj, "fig") and hasattr(getattr(obj, "fig", None), "_recorder"): - return _RecordingFigureDataProxy(obj.fig) + return _RecordingFigureDataProxy(obj.fig, spath=spath) return None diff --git a/src/scitex/io/_save_modules/_image_csv.py b/src/scitex/io/_save_modules/_image_csv.py index 8fa0ab2b4..bf944c4c3 100755 --- a/src/scitex/io/_save_modules/_image_csv.py +++ b/src/scitex/io/_save_modules/_image_csv.py @@ -255,7 +255,7 @@ def _export_csv_data( csv_path = None try: - fig_obj = get_figure_with_data(obj) + fig_obj = get_figure_with_data(obj, spath=spath) if fig_obj is not None and hasattr(fig_obj, "export_as_csv"): csv_data = fig_obj.export_as_csv() From 6067e04f1ecb970574cab9146a2f037554ad2c21 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Thu, 19 Feb 2026 10:01:05 +1100 Subject: [PATCH 08/17] =?UTF-8?q?fix(tests):=20Update=20plt=20import=20tes?= =?UTF-8?q?ts=20+=20fix=20diagram=20compile=20imports=20=E2=80=94=20Phase?= =?UTF-8?q?=204?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_imports.py: Replace 'ax' submodule assertion with figrecipe API checks (ax submodule deleted in AxisWrapper migration) - diagram/__init__.py: Fix broken figrecipe._diagram._compile import → figrecipe._diagram._graphviz + figrecipe._diagram._mermaid Co-Authored-By: Claude Sonnet 4.6 --- src/scitex/diagram/__init__.py | 3 ++- tests/custom/test_imports.py | 12 ++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/scitex/diagram/__init__.py b/src/scitex/diagram/__init__.py index b927c6f78..c481fb209 100755 --- a/src/scitex/diagram/__init__.py +++ b/src/scitex/diagram/__init__.py @@ -46,7 +46,8 @@ get_preset, list_presets, ) -from figrecipe._diagram._compile import compile_to_graphviz, compile_to_mermaid +from figrecipe._diagram._graphviz import compile_to_graphviz +from figrecipe._diagram._mermaid import compile_to_mermaid __all__ = [ "Diagram", diff --git a/tests/custom/test_imports.py b/tests/custom/test_imports.py index 4e960b758..cbf53c79f 100755 --- a/tests/custom/test_imports.py +++ b/tests/custom/test_imports.py @@ -90,14 +90,10 @@ def test_scitex_plt_import(self): """Test that scitex.plt module imports correctly.""" from scitex import plt - assert hasattr(plt, "ax") - - def test_scitex_plt_ax_import(self): - """Test that scitex.plt.ax submodule imports correctly.""" - from scitex.plt import ax # noqa: F401 - - assert hasattr(ax, "stx_heatmap") - assert hasattr(ax, "stx_joyplot") + # ax submodule removed in figrecipe migration (AxisWrapper deleted) + assert hasattr(plt, "subplots") + assert hasattr(plt, "save") + assert hasattr(plt, "color") def test_scitex_session_import(self): """Test that scitex.session module imports without circular imports.""" From d804fcfd734e76cdc5b73bd2b77985638f5ead69 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Thu, 19 Feb 2026 10:09:49 +1100 Subject: [PATCH 09/17] =?UTF-8?q?refactor(plt):=20Remove=20dead=20format?= =?UTF-8?q?=5Frecord=20imports=20=E2=80=94=20Phase=205=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three files had try-except blocks importing format_record from the deleted _subplots._export_as_csv module. They already had correct fallback paths, so remove the dead try blocks and go directly to the pattern-based column name generation. - plt/utils/_metadata/_csv.py - plt/utils/metadata/_data_linkage.py - plt/utils/_collect_figure_metadata.py Co-Authored-By: Claude Sonnet 4.6 --- .../plt/utils/_collect_figure_metadata.py | 37 ++----------------- src/scitex/plt/utils/_metadata/_csv.py | 21 +---------- .../plt/utils/metadata/_data_linkage.py | 32 +--------------- 3 files changed, 5 insertions(+), 85 deletions(-) diff --git a/src/scitex/plt/utils/_collect_figure_metadata.py b/src/scitex/plt/utils/_collect_figure_metadata.py index 1ab8ee945..53bf4ae4f 100755 --- a/src/scitex/plt/utils/_collect_figure_metadata.py +++ b/src/scitex/plt/utils/_collect_figure_metadata.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # Timestamp: "2025-11-19 13:00:00 (ywatanabe)" # File: /home/ywatanabe/proj/scitex-code/src/scitex/plt/utils/_collect_figure_metadata.py @@ -2376,8 +2375,6 @@ def _extract_csv_columns_from_history(ax) -> list: List of dictionaries containing CSV column mappings for each tracked plot, e.g., [{"id": "boxplot_0", "method": "boxplot", "columns": ["ax_00_boxplot_0_boxplot_0", "ax_00_boxplot_0_boxplot_1"]}] """ - from ._csv_column_naming import get_csv_column_name - # Get axes position for CSV column naming ax_row, ax_col = 0, 0 # Default for single axes if hasattr(ax, "_scitex_metadata") and "position_in_grid" in ax._scitex_metadata: @@ -2680,36 +2677,8 @@ def _get_csv_columns_for_method( List of column names that will be in the CSV (exact match) """ # Import the actual formatters to ensure consistency - # This is the single source of truth - we use the same code path as CSV export - try: - import pandas as pd - - from scitex.plt._subplots._export_as_csv import format_record - - # Construct the record tuple as used in tracking - record = (id_val, method, tracked_dict, kwargs) - - # Call the actual formatter to get the DataFrame - df = format_record(record) - - if df is not None and not df.empty: - # Add the axis prefix (this is what FigWrapper.export_as_csv does) - # Uses zero-padded index: ax_00_, ax_01_, etc. - prefix = f"ax_{ax_index:02d}_" - columns = [] - for col in df.columns: - col_str = str(col) - if not col_str.startswith(prefix): - col_str = f"{prefix}{col_str}" - columns.append(col_str) - return columns - - except Exception: - # If formatters fail, fall back to pattern-based generation - pass - - # Fallback: Pattern-based column name generation - # This should rarely be used since we prefer the actual formatter + # format_record removed in figrecipe migration; use pattern-based generation directly + # Pattern-based column name generation import numpy as np prefix = f"ax_{ax_index:02d}_" @@ -2953,7 +2922,7 @@ def verify_csv_json_consistency(csv_path: str, json_path: str = None) -> dict: try: # Read JSON metadata - with open(json_path, "r") as f: + with open(json_path) as f: metadata = json.load(f) # Get columns_actual from data section diff --git a/src/scitex/plt/utils/_metadata/_csv.py b/src/scitex/plt/utils/_metadata/_csv.py index 00163d284..104d53559 100755 --- a/src/scitex/plt/utils/_metadata/_csv.py +++ b/src/scitex/plt/utils/_metadata/_csv.py @@ -199,26 +199,7 @@ def _get_csv_columns_for_method( list List of column names that will be in the CSV """ - try: - from scitex.plt._subplots._export_as_csv import format_record - - record = (id_val, method, tracked_dict, kwargs) - df = format_record(record) - - if df is not None and not df.empty: - prefix = f"ax_{ax_index:02d}_" - columns = [] - for col in df.columns: - col_str = str(col) - if not col_str.startswith(prefix): - col_str = f"{prefix}{col_str}" - columns.append(col_str) - return columns - - except Exception: - pass - - # Fallback: Pattern-based column name generation + # format_record removed in figrecipe migration; use pattern-based fallback directly return _get_csv_columns_fallback(id_val, method, tracked_dict, kwargs, ax_index) diff --git a/src/scitex/plt/utils/metadata/_data_linkage.py b/src/scitex/plt/utils/metadata/_data_linkage.py index e7eb5dd69..b883e78b0 100755 --- a/src/scitex/plt/utils/metadata/_data_linkage.py +++ b/src/scitex/plt/utils/metadata/_data_linkage.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # File: scitex/plt/utils/metadata/_data_linkage.py """ @@ -23,9 +22,6 @@ from ._csv_hash import _compute_csv_hash, _compute_csv_hash_from_df from ._csv_verification import assert_csv_json_consistency, verify_csv_json_consistency from ._recipe_extraction import ( - _build_data_ref, - _extract_calls_from_history, - _filter_style_kwargs, collect_recipe_metadata, ) @@ -73,33 +69,7 @@ def _get_csv_columns_for_method( list List of column names that will be in the CSV (exact match) """ - # Import the actual formatters to ensure consistency - try: - import pandas as pd - - from scitex.plt._subplots._export_as_csv import format_record - - # Construct the record tuple as used in tracking - record = (id_val, method, tracked_dict, kwargs) - - # Call the actual formatter to get the DataFrame - df = format_record(record) - - if df is not None and not df.empty: - # Add the axis prefix (this is what FigWrapper.export_as_csv does) - prefix = f"ax_{ax_index:02d}_" - columns = [] - for col in df.columns: - col_str = str(col) - if not col_str.startswith(prefix): - col_str = f"{prefix}{col_str}" - columns.append(col_str) - return columns - - except Exception: - # If formatters fail, fall back to pattern-based generation - pass - + # format_record removed in figrecipe migration; use pattern-based generation directly # Fallback: Pattern-based column name generation prefix = f"ax_{ax_index:02d}_" columns = [] From 004110de96a93c1d4706c43d8b92373fd82a8992 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Thu, 19 Feb 2026 10:28:20 +1100 Subject: [PATCH 10/17] test(plt): Add migration stability tests and demo for figrecipe backend --- examples/plt/migration_demo.py | 197 ++++++++++++++ tests/scitex/plt/test__migration_stability.py | 247 ++++++++++++++++++ 2 files changed, 444 insertions(+) create mode 100755 examples/plt/migration_demo.py create mode 100755 tests/scitex/plt/test__migration_stability.py diff --git a/examples/plt/migration_demo.py b/examples/plt/migration_demo.py new file mode 100755 index 000000000..fed5a7fac --- /dev/null +++ b/examples/plt/migration_demo.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Timestamp: 2026-02-19 +# Author: ywatanabe +# File: examples/plt/migration_demo.py +""" +Migration demo: scitex.plt on figrecipe backend. + +Shows all key features after the AxisWrapper/FigWrapper -> figrecipe migration. +All plot calls are now handled by figrecipe's RecordingFigure / RecordingAxes. + +Run with: + python examples/plt/migration_demo.py +""" + +import matplotlib + +matplotlib.use("Agg") + +import os +import tempfile + +import numpy as np + +import scitex as stx + + +def demo_basic_line(plt): + """1. Basic line plot - confirm RecordingFigure/RecordingAxes returned.""" + fig, ax = plt.subplots() + ax.plot([1, 2, 3, 4, 5], [1, 4, 2, 3, 5], label="data") + ax.set_xyt("X", "Y", "Basic Line Plot") + ax.hide_spines() + + with tempfile.TemporaryDirectory() as tmpdir: + plt.save( + fig, os.path.join(tmpdir, "basic_line.png"), validate=False, verbose=False + ) + + print(f" fig type : {type(fig).__name__} (module: {type(fig).__module__})") + print(f" ax type : {type(ax).__name__} (module: {type(ax).__module__})") + print(f" has _recorder : {hasattr(fig, '_recorder')}") + plt.close("all") + + +def demo_scientific_methods(plt): + """2. stx_* scientific methods across a 2x3 grid of axes.""" + data = np.random.randn(50, 8) # 50 timepoints, 8 trials + + fig, axes = plt.subplots(2, 3, axes_width_mm=40, axes_height_mm=28) + axes_flat = axes.flatten() + + # Mean +/- SD + axes_flat[0].stx_mean_std(data) + axes_flat[0].set_xyt("Time", "Value", "Mean +/- SD") + + # Mean +/- 95 % CI + axes_flat[1].stx_mean_ci(data) + axes_flat[1].set_xyt("Time", "Value", "Mean +/- 95% CI") + + # Median +/- IQR + axes_flat[2].stx_median_iqr(data) + axes_flat[2].set_xyt("Time", "Value", "Median +/- IQR") + + # ECDF + axes_flat[3].stx_ecdf(data[:, 0]) + axes_flat[3].set_xyt("Value", "Probability", "ECDF") + + # Confusion matrix + conf_mat = np.array([[45, 5, 2], [3, 38, 4], [1, 2, 50]]) + axes_flat[4].stx_conf_mat( + conf_mat, + x_labels=["A", "B", "C"], + y_labels=["A", "B", "C"], + ) + axes_flat[4].set_xyt("Predicted", "True", "Confusion Matrix") + + # Raster plot + spike_times = [np.sort(np.random.uniform(0, 1, 10)) for _ in range(5)] + axes_flat[5].stx_raster(spike_times) + axes_flat[5].set_xyt("Time (s)", "Trial", "Raster Plot") + + with tempfile.TemporaryDirectory() as tmpdir: + plt.save( + fig, + os.path.join(tmpdir, "scientific_methods.png"), + validate=False, + verbose=False, + ) + print(" stx_mean_std / stx_mean_ci / stx_median_iqr : OK") + print(" stx_ecdf / stx_conf_mat / stx_raster : OK") + plt.close("all") + + +def demo_additional_methods(plt): + """3. Additional stx_* methods: violin, heatmap, fillv.""" + fig, axes = plt.subplots(1, 3, axes_width_mm=40, axes_height_mm=28) + + # Violin + axes[0].stx_violin([np.random.randn(30), np.random.randn(30)]) + axes[0].set_xyt("Group", "Value", "Violin") + + # Heatmap + axes[1].stx_heatmap(np.random.randn(6, 6)) + axes[1].set_xyt("Col", "Row", "Heatmap") + + # Shaded vertical region + axes[2].plot([0, 1, 2, 3], [0, 1, 0.5, 0.8]) + axes[2].stx_fillv([0.5], [1.5]) + axes[2].set_xyt("X", "Y", "Shaded Region") + + with tempfile.TemporaryDirectory() as tmpdir: + plt.save( + fig, + os.path.join(tmpdir, "additional_methods.png"), + validate=False, + verbose=False, + ) + print(" stx_violin / stx_heatmap / stx_fillv : OK") + plt.close("all") + + +def demo_recording_capability(plt): + """4. Verify figrecipe call-recording works.""" + fig, ax = plt.subplots() + ax.plot([1, 2, 3], [4, 5, 6]) + + fr = fig._recorder.figure_record + calls_total = sum(len(ar.calls) for ar in fr.axes.values()) + print(f" axes recorded : {len(fr.axes)}") + print(f" calls recorded : {calls_total}") + assert calls_total >= 1, "Recording failed" + plt.close("all") + + +def demo_io_save(plt): + """5. stx.io.save() integration with RecordingFigure.""" + fig, ax = plt.subplots() + ax.plot( + np.linspace(0, 2 * np.pi, 100), + np.sin(np.linspace(0, 2 * np.pi, 100)), + ) + ax.set_xyt("Phase (rad)", "Amplitude", "Sine Wave") + + with tempfile.TemporaryDirectory() as tmpdir: + out = os.path.join(tmpdir, "sine.png") + stx.io.save(fig, out) + exists = os.path.exists(out) + print(f" io.save created PNG : {exists}") + plt.close("all") + + +def demo_color_module(plt): + """6. Confirm scitex.plt.color submodule is still accessible.""" + assert hasattr(plt, "color"), "plt.color missing" + from scitex.plt.color import HEX + + assert isinstance(HEX, dict) and len(HEX) > 0 + sample = list(HEX.items())[:3] + print(f" HEX sample : {sample}") + + +@stx.session +def main( + CONFIG=stx.session.INJECTED, + plt=stx.session.INJECTED, + COLORS=stx.session.INJECTED, + rngg=stx.session.INJECTED, + logger=stx.session.INJECTED, +): + print("scitex.plt -> figrecipe migration demo") + print("=" * 50) + + print("\n[1] Basic line plot") + demo_basic_line(plt) + + print("\n[2] Scientific methods (2x3 grid)") + demo_scientific_methods(plt) + + print("\n[3] Additional methods (violin / heatmap / fillv)") + demo_additional_methods(plt) + + print("\n[4] Recording capability") + demo_recording_capability(plt) + + print("\n[5] stx.io.save integration") + demo_io_save(plt) + + print("\n[6] Color submodule") + demo_color_module(plt) + + print("\n" + "=" * 50) + print("Migration demo complete - all features working.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/scitex/plt/test__migration_stability.py b/tests/scitex/plt/test__migration_stability.py new file mode 100755 index 000000000..f856f11a6 --- /dev/null +++ b/tests/scitex/plt/test__migration_stability.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +# Timestamp: 2026-02-19 +# Author: ywatanabe +# File: tests/scitex/plt/test__migration_stability.py +""" +Integration tests confirming scitex.plt -> figrecipe migration stability. + +Verifies that: +- stx.plt.subplots() returns figrecipe RecordingFigure/RecordingAxes +- All stx_* scientific plot methods are accessible and callable +- AxisWrapper / FigWrapper are fully removed +- Recording capability is functional +- Color submodule and io.save integration still work +""" + +import matplotlib + +matplotlib.use("Agg") + +import matplotlib.pyplot as mplt +import numpy as np +import pytest + +import scitex as stx +import scitex.plt as plt + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fig_ax(): + fig, ax = plt.subplots() + yield fig, ax + mplt.close("all") + + +@pytest.fixture +def fig_ax_multi(): + fig, axes = plt.subplots(1, 3) + yield fig, axes + mplt.close("all") + + +# --------------------------------------------------------------------------- +# TestMigrationBackend +# --------------------------------------------------------------------------- + + +class TestMigrationBackend: + """Verify figrecipe classes are returned and legacy classes are gone.""" + + def test_subplots_returns_recording_figure(self, fig_ax): + """stx.plt.subplots() must return a figrecipe RecordingFigure.""" + fig, _ = fig_ax + assert hasattr(fig, "_recorder"), ( + f"Expected RecordingFigure with _recorder, got {type(fig)}" + ) + + def test_subplots_returns_recording_axes(self, fig_ax): + """Returned axes must be RecordingAxes (has SciTexMixin methods).""" + _, ax = fig_ax + assert hasattr(ax, "stx_mean_std"), ( + f"Expected RecordingAxes with stx_mean_std, got {type(ax)}" + ) + + def test_save_is_figrecipe_save(self): + """plt.save must resolve to figrecipe, not scitex internals.""" + assert plt.save.__module__.startswith("figrecipe"), ( + f"plt.save module is '{plt.save.__module__}', expected figrecipe.*" + ) + + def test_no_axiswrapper(self): + """AxisWrapper must be unimportable from scitex.plt after migration.""" + with pytest.raises(ImportError): + from scitex.plt import AxisWrapper # noqa: F401 + + def test_no_figwrapper(self): + """FigWrapper must be unimportable from scitex.plt after migration.""" + with pytest.raises(ImportError): + from scitex.plt import FigWrapper # noqa: F401 + + +# --------------------------------------------------------------------------- +# TestStyleMethods +# --------------------------------------------------------------------------- + + +class TestStyleMethods: + """Verify style/decoration methods provided by figrecipe mixins.""" + + def test_set_xyt(self, fig_ax): + """ax.set_xyt() should label axes and set title without error.""" + _, ax = fig_ax + assert callable(getattr(ax, "set_xyt", None)), "set_xyt not callable" + ax.set_xyt("X label", "Y label", "Plot Title") + + def test_hide_spines(self, fig_ax): + """ax.hide_spines() should run without error.""" + _, ax = fig_ax + assert callable(getattr(ax, "hide_spines", None)), "hide_spines not callable" + ax.hide_spines() + + def test_sci_note(self, fig_ax): + """ax.sci_note() should apply scientific notation without error.""" + _, ax = fig_ax + assert callable(getattr(ax, "sci_note", None)), "sci_note not callable" + ax.plot([1, 2], [1e6, 2e6]) + ax.sci_note() + + +# --------------------------------------------------------------------------- +# TestSciTexMethods +# --------------------------------------------------------------------------- + + +class TestSciTexMethods: + """Verify all stx_* scientific plotting methods are functional.""" + + def test_stx_mean_std(self, fig_ax): + """ax.stx_mean_std() should plot mean +/- SD bands.""" + _, ax = fig_ax + ax.stx_mean_std(np.random.randn(10, 5)) + + def test_stx_mean_ci(self, fig_ax): + """ax.stx_mean_ci() should plot mean +/- 95% CI bands.""" + _, ax = fig_ax + ax.stx_mean_ci(np.random.randn(10, 5)) + + def test_stx_median_iqr(self, fig_ax): + """ax.stx_median_iqr() should plot median +/- IQR bands.""" + _, ax = fig_ax + ax.stx_median_iqr(np.random.randn(10, 5)) + + def test_stx_ecdf(self, fig_ax): + """ax.stx_ecdf() should plot empirical CDF.""" + _, ax = fig_ax + ax.stx_ecdf(np.random.randn(100)) + + def test_stx_conf_mat(self, fig_ax): + """ax.stx_conf_mat() should render a 2x2 confusion matrix.""" + _, ax = fig_ax + ax.stx_conf_mat(np.array([[10, 2], [3, 15]])) + + def test_stx_heatmap(self, fig_ax): + """ax.stx_heatmap() should render a 4x4 heatmap.""" + _, ax = fig_ax + ax.stx_heatmap(np.random.randn(4, 4)) + + def test_stx_violin(self, fig_ax): + """ax.stx_violin() should render violin plots for two groups.""" + _, ax = fig_ax + ax.stx_violin([np.random.randn(20), np.random.randn(20)]) + + def test_stx_raster(self, fig_ax): + """ax.stx_raster() should render a spike raster plot.""" + _, ax = fig_ax + ax.stx_raster([[0.1, 0.5, 0.9], [0.2, 0.7]]) + + def test_stx_fillv(self, fig_ax): + """ax.stx_fillv() should add vertical shaded regions.""" + _, ax = fig_ax + ax.plot([0, 1], [0, 1]) + ax.stx_fillv([0.2], [0.5]) + + +# --------------------------------------------------------------------------- +# TestRecordingCapability +# --------------------------------------------------------------------------- + + +class TestRecordingCapability: + """Verify figrecipe's call-recording mechanism works end-to-end.""" + + def test_calls_recorded(self, fig_ax): + """plot() calls must be captured in _recorder.figure_record.""" + fig, ax = fig_ax + ax.plot([1, 2, 3], [4, 5, 6]) + figure_record = fig._recorder.figure_record + # At least one axes entry must have at least one recorded call + assert len(figure_record.axes) >= 1, "No axes entries recorded" + calls_total = sum(len(ar.calls) for ar in figure_record.axes.values()) + assert calls_total >= 1, "No calls recorded after ax.plot()" + + def test_recipe_structure(self, fig_ax): + """RecordingFigure._recorder.figure_record must have an .axes dict.""" + fig, ax = fig_ax + ax.plot([1, 2], [1, 2]) + recorder = fig._recorder + assert hasattr(recorder, "figure_record"), "_recorder missing figure_record" + fr = recorder.figure_record + assert hasattr(fr, "axes"), "figure_record missing .axes" + assert isinstance(fr.axes, dict), ( + f"figure_record.axes is {type(fr.axes)}, expected dict" + ) + + +# --------------------------------------------------------------------------- +# TestColorModule +# --------------------------------------------------------------------------- + + +class TestColorModule: + """Verify scitex.plt.color submodule is still accessible post-migration.""" + + def test_color_module_accessible(self): + """plt.color attribute must exist on the scitex.plt module.""" + assert hasattr(plt, "color"), "scitex.plt.color submodule not accessible" + + def test_hex_colors_accessible(self): + """HEX color dict must be importable from scitex.plt.color.""" + from scitex.plt.color import HEX + + assert isinstance(HEX, dict), f"HEX is {type(HEX)}, expected dict" + assert len(HEX) > 0, "HEX dict is empty" + + +# --------------------------------------------------------------------------- +# TestIoSaveIntegration +# --------------------------------------------------------------------------- + + +class TestIoSaveIntegration: + """Verify that figure saving works via both plt.save and stx.io.save.""" + + def test_io_save_fig_creates_png(self, tmp_path): + """stx.io.save(fig, path) must create a PNG file on disk.""" + fig, ax = plt.subplots() + ax.plot([1, 2, 3], [4, 5, 6]) + out = str(tmp_path / "test_io.png") + try: + stx.io.save(fig, out) + assert (tmp_path / "test_io.png").exists(), "PNG not created by io.save" + finally: + mplt.close("all") + + def test_plt_save_creates_png(self, tmp_path): + """plt.save(fig, path, validate=False) must create a PNG file.""" + fig, ax = plt.subplots() + ax.plot([1, 2], [1, 2]) + out = str(tmp_path / "test_plt.png") + try: + plt.save(fig, out, validate=False, verbose=False) + assert (tmp_path / "test_plt.png").exists(), "PNG not created by plt.save" + finally: + mplt.close("all") From 7a7cea13a08ac88b520119c963790effcef8c494 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Fri, 20 Feb 2026 07:52:05 +1100 Subject: [PATCH 11/17] chore: bump version to v2.18.1 Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e4b4afe87..dd4d12748 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ build-backend = "hatchling.build" [project] name = "scitex" -version = "2.18.0" +version = "2.18.1" description = "A comprehensive Python library for scientific computing and data analysis" readme = "README.md" requires-python = ">=3.10" @@ -116,7 +116,7 @@ ai = [ "ruamel.yaml", "xarray", "umap-learn", - "sktime", + "sktime>=0.21.0", # >=0.21 has proper build-system deps declared (0.8.1 was broken) "markdown2", "imbalanced-learn", # # Heavy dependencies handled by _AVAILABLE flags @@ -780,7 +780,7 @@ all = [ "imbalanced-learn", "umap-learn>=0.5.4", "llvmlite>=0.39.0", # Force Python 3.11+ compatible version - "sktime", + "sktime>=0.21.0", # >=0.21 has proper build-system deps declared (0.8.1 was broken) "catboost", "opencv-python", # neuro From 95be5dd72dc386a661145b2c6ef717fd22f1a66c Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Fri, 20 Feb 2026 16:46:14 +1100 Subject: [PATCH 12/17] feat(audio): Add output_path parameter to audio_speak MCP tool - Add optional output_path parameter to register_audio_tools() in _mcp_tools/audio.py - Add optional output_path parameter to speak_handler() in audio/_mcp/handlers.py - output_path allows callers to specify an explicit save path for audio files - When output_path is None and save=True, auto-generate timestamped path as before - Update docstrings to document save vs output_path distinction Co-Authored-By: Claude Sonnet 4.6 --- src/scitex/_mcp_tools/audio.py | 6 ++++++ src/scitex/audio/_mcp/handlers.py | 7 ++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/scitex/_mcp_tools/audio.py b/src/scitex/_mcp_tools/audio.py index 2c1ee3d77..adf7328ae 100755 --- a/src/scitex/_mcp_tools/audio.py +++ b/src/scitex/_mcp_tools/audio.py @@ -25,6 +25,7 @@ async def audio_speak( speed: float = 1.5, play: bool = True, save: bool = False, + output_path: Optional[str] = None, fallback: bool = True, agent_id: Optional[str] = None, wait: bool = True, @@ -37,6 +38,10 @@ async def audio_speak( - If local audio available -> uses local - If neither available -> returns error with instructions + Args: + save: Auto-save to timestamped file in SCITEX_DIR/audio/ if output_path not set. + output_path: Explicit path to save audio file (e.g. /tmp/notify.mp3). + Environment variables: - SCITEX_AUDIO_MODE: 'local', 'remote', or 'auto' (default: auto) - SCITEX_AUDIO_RELAY_URL: Relay server URL for remote playback @@ -51,6 +56,7 @@ async def audio_speak( speed=speed, play=play, save=save, + output_path=output_path, fallback=fallback, agent_id=agent_id, wait=wait, diff --git a/src/scitex/audio/_mcp/handlers.py b/src/scitex/audio/_mcp/handlers.py index 67f69ccec..626c18211 100755 --- a/src/scitex/audio/_mcp/handlers.py +++ b/src/scitex/audio/_mcp/handlers.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # Timestamp: "2026-02-06 23:02:36 (ywatanabe)" # File: /home/ywatanabe/proj/scitex-python/src/scitex/audio/_mcp/handlers.py @@ -301,6 +300,7 @@ async def speak_handler( speed: float = 1.5, play: bool = True, save: bool = False, + output_path: str | None = None, fallback: bool = True, agent_id: str | None = None, wait: bool = True, @@ -309,6 +309,8 @@ async def speak_handler( """Convert text to speech with fallback. Args: + save: If True and output_path is None, auto-generate a timestamped path. + output_path: Explicit path to save the audio file (overrides save flag). signature: If True, prepend hostname/project/branch to text. """ try: @@ -323,8 +325,7 @@ async def speak_handler( sig = _get_signature() final_text = sig + text - output_path = None - if save: + if output_path is None and save: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") output_path = str(_get_audio_dir() / f"tts_{timestamp}.mp3") From 6b103aaf6d3f036f68d11f7ed35638f0aca9bcfe Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Sat, 21 Feb 2026 17:45:24 +1100 Subject: [PATCH 13/17] feat(scholar): add HTTP mode to CitationGraphBuilder via crossref-local API Co-Authored-By: Claude Opus 4.6 --- src/scitex/scholar/citation_graph/__init__.py | 8 +- src/scitex/scholar/citation_graph/builder.py | 45 +++- .../scholar/citation_graph/database_http.py | 230 ++++++++++++++++++ 3 files changed, 267 insertions(+), 16 deletions(-) create mode 100755 src/scitex/scholar/citation_graph/database_http.py diff --git a/src/scitex/scholar/citation_graph/__init__.py b/src/scitex/scholar/citation_graph/__init__.py index 876515d2d..4311625af 100755 --- a/src/scitex/scholar/citation_graph/__init__.py +++ b/src/scitex/scholar/citation_graph/__init__.py @@ -9,12 +9,14 @@ - Build citation network graphs - Export for visualization (D3.js, vis.js, Cytoscape) -Example: +Example (local SQLite): >>> from scitex.scholar.citation_graph import CitationGraphBuilder - >>> >>> builder = CitationGraphBuilder(db_path="/path/to/crossref.db") >>> graph = builder.build("10.1038/s41586-020-2008-3", top_n=20) - >>> builder.export_json(graph, "network.json") + +Example (HTTP via crossref-local): + >>> builder = CitationGraphBuilder(api_url="http://localhost:31291") + >>> graph = builder.build("10.1038/s41586-020-2008-3", top_n=20) """ from .builder import CitationGraphBuilder diff --git a/src/scitex/scholar/citation_graph/builder.py b/src/scitex/scholar/citation_graph/builder.py index 660302614..ef1d34741 100755 --- a/src/scitex/scholar/citation_graph/builder.py +++ b/src/scitex/scholar/citation_graph/builder.py @@ -5,7 +5,6 @@ """ import json -from collections import Counter from pathlib import Path from typing import List, Optional @@ -17,21 +16,37 @@ class CitationGraphBuilder: """ Build citation network graphs for academic papers. - Example: - >>> builder = CitationGraphBuilder("/path/to/crossref.db") + Example (SQLite): + >>> builder = CitationGraphBuilder(db_path="/path/to/crossref.db") + >>> graph = builder.build("10.1038/s41586-020-2008-3", top_n=20) + + Example (HTTP via crossref-local): + >>> builder = CitationGraphBuilder(api_url="http://localhost:31291") >>> graph = builder.build("10.1038/s41586-020-2008-3", top_n=20) - >>> builder.export_json(graph, "network.json") """ - def __init__(self, db_path: str): + def __init__(self, db_path: str = None, api_url: str = None): """ - Initialize builder with database path. + Initialize builder with database path or HTTP API URL. Args: - db_path: Path to CrossRef SQLite database + db_path: Path to CrossRef SQLite database (local mode) + api_url: URL of crossref-local HTTP API (HTTP mode) + + Raises + ------ + ValueError: If neither db_path nor api_url is provided """ - self.db_path = db_path - self.db = CitationDatabase(db_path) + if api_url: + from .database_http import CitationDatabaseHTTP + + self.db_path = None + self.db = CitationDatabaseHTTP(api_url) + elif db_path: + self.db_path = db_path + self.db = CitationDatabase(db_path) + else: + raise ValueError("Either db_path or api_url is required") def build( self, @@ -51,7 +66,8 @@ def build( weight_cocitation: Weight for co-citation weight_direct: Weight for direct citations - Returns: + Returns + ------- CitationGraph object with nodes and edges """ with self.db: @@ -100,7 +116,8 @@ def _create_paper_node(self, doi: str, similarity_score: float) -> PaperNode: doi: DOI of the paper similarity_score: Calculated similarity score - Returns: + Returns + ------- PaperNode object """ metadata = self.db.get_paper_metadata(doi) @@ -142,7 +159,8 @@ def _build_citation_edges(self, dois: List[str]) -> List[CitationEdge]: Args: dois: List of DOIs in the network - Returns: + Returns + ------- List of CitationEdge objects """ edges = [] @@ -183,7 +201,8 @@ def get_paper_summary(self, doi: str) -> Optional[dict]: Args: doi: DOI of the paper - Returns: + Returns + ------- Dictionary with paper summary """ with self.db: diff --git a/src/scitex/scholar/citation_graph/database_http.py b/src/scitex/scholar/citation_graph/database_http.py new file mode 100755 index 000000000..7d8dcbfdd --- /dev/null +++ b/src/scitex/scholar/citation_graph/database_http.py @@ -0,0 +1,230 @@ +""" +HTTP-based database access layer for citation graph queries. + +Uses crossref-local HTTP API instead of direct SQLite access. +Implements the same interface as CitationDatabase for drop-in replacement. +""" + +from collections import Counter +from typing import Dict, List, Optional, Tuple + + +class CitationDatabaseHTTP: + """ + HTTP interface for citation graph operations via crossref-local API. + + Drop-in replacement for CitationDatabase that uses the crossref-local + HTTP server instead of direct SQLite database access. + + Example: + >>> db = CitationDatabaseHTTP() # auto-detects from env/config + >>> with db: + ... refs = db.get_references("10.1038/s41586-020-2008-3") + """ + + def __init__(self, api_url: str = None): + """ + Initialize HTTP database connection. + + Args: + api_url: URL of the crossref-local HTTP API server. + If None, auto-detects from CROSSREF_LOCAL_API_URL env var + or crossref_local config defaults. + """ + from crossref_local.remote import RemoteClient + + if api_url is None: + import os + + from crossref_local._core.config import DEFAULT_API_URL + + api_url = os.environ.get("CROSSREF_LOCAL_API_URL", DEFAULT_API_URL) + + self.api_url = api_url + self.client = RemoteClient(api_url) + + def connect(self, read_only: bool = True): + """No-op for HTTP mode (connection is stateless).""" + + def close(self): + """No-op for HTTP mode.""" + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + + def get_references(self, doi: str, limit: int = 100) -> List[str]: + """ + Get papers cited by this DOI (forward citations / references). + + Args: + doi: DOI of the paper + limit: Maximum number of references to return + + Returns + ------- + List of DOIs cited by the paper + """ + return self.client.get_cited(doi, limit=limit) + + def get_citations(self, doi: str, limit: int = 100) -> List[Tuple[str, int]]: + """ + Get papers that cite this DOI (reverse citations). + + Args: + doi: DOI of the paper + limit: Maximum number of citations to return + + Returns + ------- + List of (citing_doi, year) tuples. + Note: year is 0 since the HTTP API doesn't return year with citing DOIs. + """ + citing_dois = self.client.get_citing(doi, limit=limit) + return [(d, 0) for d in citing_dois] + + def get_cocited_papers(self, doi: str, limit: int = 50) -> List[Tuple[str, int]]: + """ + Find papers co-cited with this DOI. + + Computed client-side: find papers that cite the seed, + then count how often other papers appear in their reference lists. + + Args: + doi: DOI of the paper + limit: Maximum number of results + + Returns + ------- + List of (cocited_doi, cocitation_count) tuples + """ + # Get papers that cite this DOI + citing_dois = self.client.get_citing(doi, limit=50) + + # For each citing paper, get its references and count co-occurrences + cocitation_counts = Counter() + for citing_doi in citing_dois[:30]: # Limit HTTP calls + refs = self.client.get_cited(citing_doi, limit=100) + for ref_doi in refs: + if ref_doi.lower() != doi.lower(): + cocitation_counts[ref_doi] += 1 + + return cocitation_counts.most_common(limit) + + def get_bibliographic_coupled_papers( + self, doi: str, limit: int = 50 + ) -> List[Tuple[str, int]]: + """ + Find papers with similar references (bibliographic coupling). + + Computed client-side: get seed's references, then for each reference + find papers that also cite it, and count shared references. + + Args: + doi: DOI of the paper + limit: Maximum number of results + + Returns + ------- + List of (coupled_doi, shared_references_count) tuples + """ + # Get seed paper's references + seed_refs = self.client.get_cited(doi, limit=100) + + # For each reference, find other papers that also cite it + coupling_counts = Counter() + for ref_doi in seed_refs[:30]: # Limit HTTP calls + citers = self.client.get_citing(ref_doi, limit=100) + for citer_doi in citers: + if citer_doi.lower() != doi.lower(): + coupling_counts[citer_doi] += 1 + + return coupling_counts.most_common(limit) + + def get_paper_metadata(self, doi: str) -> Optional[Dict]: + """ + Get metadata for a paper from crossref-local API. + + Args: + doi: DOI of the paper + + Returns + ------- + Dictionary with paper metadata in CrossRef format, or None + """ + work = self.client.get(doi) + if work is None: + return None + + # Convert Work object to CrossRef-style metadata dict + # that CitationGraphBuilder._create_paper_node expects + metadata = { + "title": [work.title] if work.title else ["Unknown"], + "author": [], + "container-title": [work.journal] if work.journal else [], + } + + # Parse authors + if work.authors: + for author_str in work.authors: + parts = author_str.rsplit(" ", 1) + if len(parts) == 2: + metadata["author"].append({"given": parts[0], "family": parts[1]}) + else: + metadata["author"].append({"family": author_str, "given": ""}) + + # Add year in CrossRef date format + if work.year: + metadata["published"] = {"date-parts": [[work.year]]} + + return metadata + + def get_combined_similarity_scores( + self, + seed_doi: str, + weight_coupling: float = 2.0, + weight_cocitation: float = 2.0, + weight_direct: float = 1.0, + max_papers: int = 100, + ) -> Counter: + """ + Calculate combined similarity scores using multiple metrics. + + Same algorithm as CitationDatabase but using HTTP API. + + Args: + seed_doi: DOI of the seed paper + weight_coupling: Weight for bibliographic coupling score + weight_cocitation: Weight for co-citation score + weight_direct: Weight for direct citation score + max_papers: Maximum papers to consider per metric + + Returns + ------- + Counter with {doi: combined_score} + """ + scores = Counter() + + # 1. Bibliographic coupling + coupled = self.get_bibliographic_coupled_papers(seed_doi, limit=max_papers) + for doi, count in coupled: + scores[doi] += count * weight_coupling + + # 2. Co-citation + cocited = self.get_cocited_papers(seed_doi, limit=max_papers) + for doi, count in cocited: + scores[doi] += count * weight_cocitation + + # 3. Direct citations + refs = self.get_references(seed_doi, limit=50) + for doi in refs: + scores[doi] += weight_direct + + citations = self.get_citations(seed_doi, limit=50) + for doi, _ in citations: + scores[doi] += weight_direct + + return scores From 6cb67b3ceb96cca311bb0c1b0b2b5f8e268430c1 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Sat, 21 Feb 2026 18:27:34 +1100 Subject: [PATCH 14/17] refactor(scholar): auto-detect backend in CitationGraphBuilder via crossref_local.Config Co-Authored-By: Claude Opus 4.6 --- src/scitex/scholar/citation_graph/builder.py | 41 +++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/scitex/scholar/citation_graph/builder.py b/src/scitex/scholar/citation_graph/builder.py index ef1d34741..512ea7ac1 100755 --- a/src/scitex/scholar/citation_graph/builder.py +++ b/src/scitex/scholar/citation_graph/builder.py @@ -16,26 +16,32 @@ class CitationGraphBuilder: """ Build citation network graphs for academic papers. - Example (SQLite): - >>> builder = CitationGraphBuilder(db_path="/path/to/crossref.db") + Auto-detects backend via crossref_local.Config (DB → HTTP). + + Example (auto-detect): + >>> builder = CitationGraphBuilder() >>> graph = builder.build("10.1038/s41586-020-2008-3", top_n=20) - Example (HTTP via crossref-local): + Example (explicit SQLite): + >>> builder = CitationGraphBuilder(db_path="/path/to/crossref.db") + + Example (explicit HTTP): >>> builder = CitationGraphBuilder(api_url="http://localhost:31291") - >>> graph = builder.build("10.1038/s41586-020-2008-3", top_n=20) """ def __init__(self, db_path: str = None, api_url: str = None): """ - Initialize builder with database path or HTTP API URL. + Initialize builder with database path, HTTP API URL, or auto-detect. + + When no args given, delegates to crossref_local.Config for auto-detection: + 1. CROSSREF_LOCAL_MODE env var (explicit "db" or "http") + 2. CROSSREF_LOCAL_API_URL env var → HTTP mode + 3. Local DB file existence → DB mode + 4. Fallback to HTTP mode Args: db_path: Path to CrossRef SQLite database (local mode) api_url: URL of crossref-local HTTP API (HTTP mode) - - Raises - ------ - ValueError: If neither db_path nor api_url is provided """ if api_url: from .database_http import CitationDatabaseHTTP @@ -46,7 +52,22 @@ def __init__(self, db_path: str = None, api_url: str = None): self.db_path = db_path self.db = CitationDatabase(db_path) else: - raise ValueError("Either db_path or api_url is required") + self._auto_detect() + + def _auto_detect(self): + """Auto-detect backend via crossref_local.Config.""" + from crossref_local._core.config import Config + + mode = Config.get_mode() + + if mode == "db": + self.db_path = str(Config.get_db_path()) + self.db = CitationDatabase(self.db_path) + else: + from .database_http import CitationDatabaseHTTP + + self.db_path = None + self.db = CitationDatabaseHTTP(Config.get_api_url()) def build( self, From eab0e942c34ae2eeab124ae7ef74a265e377a2ed Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Sat, 21 Feb 2026 21:25:47 +1100 Subject: [PATCH 15/17] feat(mcp): auto-bridge all figrecipe tools programmatically MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace hard-coded plt/diagram wrappers with dynamic auto-bridge: - plt.py: registers all 50 plt_* tools from figrecipe automatically - diagram.py: registers all 9 diagram_* tools automatically - fr.py (new): registers all 14 fr_* tools automatically - __init__.py: adds register_fr_tools Total: 133 → 194 tools. Any future figrecipe tool additions are now picked up automatically without code changes. Co-Authored-By: Claude Sonnet 4.6 --- src/scitex/_mcp_tools/__init__.py | 4 + src/scitex/_mcp_tools/diagram.py | 258 ++-------------------------- src/scitex/_mcp_tools/fr.py | 32 ++++ src/scitex/_mcp_tools/plt.py | 269 ++---------------------------- 4 files changed, 57 insertions(+), 506 deletions(-) create mode 100755 src/scitex/_mcp_tools/fr.py diff --git a/src/scitex/_mcp_tools/__init__.py b/src/scitex/_mcp_tools/__init__.py index 6a0984cb4..778e18a56 100755 --- a/src/scitex/_mcp_tools/__init__.py +++ b/src/scitex/_mcp_tools/__init__.py @@ -11,9 +11,11 @@ from .dataset import register_dataset_tools from .dev import register_dev_tools from .diagram import register_diagram_tools +from .fr import register_fr_tools from .introspect import register_introspect_tools from .linter import register_linter_tools from .plt import register_plt_tools +from .project import register_project_tools from .scholar import register_scholar_tools from .social import register_social_tools from .stats import register_stats_tools @@ -33,9 +35,11 @@ def register_all_tools(mcp) -> None: register_dataset_tools(mcp) register_dev_tools(mcp) register_diagram_tools(mcp) + register_fr_tools(mcp) register_introspect_tools(mcp) register_linter_tools(mcp) register_plt_tools(mcp) + register_project_tools(mcp) register_scholar_tools(mcp) register_social_tools(mcp) register_stats_tools(mcp) diff --git a/src/scitex/_mcp_tools/diagram.py b/src/scitex/_mcp_tools/diagram.py index 46cf87407..1e5be687b 100755 --- a/src/scitex/_mcp_tools/diagram.py +++ b/src/scitex/_mcp_tools/diagram.py @@ -1,270 +1,32 @@ #!/usr/bin/env python3 -# Timestamp: 2026-01-24 +# Timestamp: 2026-02-21 # File: /home/ywatanabe/proj/scitex-code/src/scitex/_mcp_tools/diagram.py """Diagram module tools for FastMCP unified server. -This module delegates to figrecipe's diagram implementation for single source of truth. -All diagram_* tools are thin wrappers around figrecipe's canonical implementation. +Programmatically bridges all figrecipe diagram_* tools into scitex MCP. +No manual wrapping — any new figrecipe diagram tool appears automatically. """ from __future__ import annotations -import os -from typing import Any, Dict, Literal, Optional # noqa: F401 - def register_diagram_tools(mcp) -> None: - """Register diagram tools with FastMCP server. - - Delegates to figrecipe's diagram tools (canonical source). - Tools are prefixed with 'diagram_' for scitex namespace consistency. - """ - # Ensure branding is set before any figrecipe imports - os.environ.setdefault("FIGRECIPE_BRAND", "scitex.diagram") - os.environ.setdefault("FIGRECIPE_ALIAS", "diagram") - - # Check if figrecipe is available + """Register all figrecipe diagram_* tools with the FastMCP server.""" try: - from figrecipe import Diagram - - _FIGRECIPE_AVAILABLE = True + from figrecipe._mcp import server as fr_mcp except ImportError: - _FIGRECIPE_AVAILABLE = False - - if not _FIGRECIPE_AVAILABLE: @mcp.tool() def diagram_not_available() -> str: """[diagram] figrecipe not installed.""" - return "figrecipe is required for diagram tools. Install with: pip install figrecipe" + return "figrecipe is required. Install with: pip install figrecipe" return - def _load_diagram(spec_dict, spec_path): - """Load a Diagram from spec_dict or spec_path.""" - if spec_path: - return Diagram.from_yaml(spec_path) - elif spec_dict: - return Diagram.from_dict(spec_dict) - else: - raise ValueError("Either spec_dict or spec_path must be provided") - - @mcp.tool() - def diagram_create( - spec_dict: Optional[Dict[str, Any]] = None, - spec_path: Optional[str] = None, - ) -> Dict[str, Any]: - """Create a diagram from a YAML specification file or dictionary. - - **PRIMARY DIAGRAM TOOL - Use this for creating diagrams.** - - This is the recommended entry point for diagram creation using SciTeX's - publication-optimized diagram system. It generates both Mermaid and - Graphviz representations from your specification. - - **Available Themes:** - - MATPLOTLIB: Matplotlib color scheme - - SCITEX: SciTeX publication theme (RECOMMENDED, DEFAULT) - - **Available Presets:** - - workflow: Left-to-right flow, rounded boxes (process diagrams) - - decision: Top-down flow, diamond nodes (flowcharts) - - pipeline: Left-to-right flow, data cylinders (data pipelines) - - scientific: Top-down flow, clean academic style (methods diagrams) - - Parameters - ---------- - spec_dict : dict, optional - Diagram specification as dictionary. Required keys: nodes, edges. - Optional keys: metadata, preset, groups, theme. - - spec_path : str, optional - Path to YAML specification file. Alternative to spec_dict. - - Returns - ------- - dict - Dictionary with 'mermaid' and 'graphviz' string representations. - - Examples - -------- - Create a simple workflow diagram: - - >>> spec = { - ... "preset": "workflow", - ... "theme": "SCITEX", - ... "nodes": [ - ... {"id": "input", "label": "Raw Data"}, - ... {"id": "process", "label": "Analysis"}, - ... {"id": "output", "label": "Results"} - ... ], - ... "edges": [ - ... {"from": "input", "to": "process"}, - ... {"from": "process", "to": "output"} - ... ] - ... } - >>> diagram_create(spec_dict=spec) - """ - d = _load_diagram(spec_dict, spec_path) - return { - "mermaid": d.to_mermaid(), - "graphviz": d.to_graphviz(), - "nodes": len(d.spec.nodes), - "edges": len(d.spec.edges), - "success": True, - } - - @mcp.tool() - def diagram_compile_mermaid( - spec_dict: Optional[Dict[str, Any]] = None, - spec_path: Optional[str] = None, - output_path: Optional[str] = None, - ) -> Dict[str, Any]: - """Compile diagram specification to Mermaid format. - - Parameters - ---------- - spec_dict : dict, optional - Diagram specification as dictionary. - - spec_path : str, optional - Path to YAML specification file. - - output_path : str, optional - Path to save .mmd file. If not specified, returns the Mermaid string only. - - Returns - ------- - dict - Dictionary with 'mermaid' string and 'output_path' (if saved). - """ - d = _load_diagram(spec_dict, spec_path) - mermaid = d.to_mermaid(output_path) - return {"mermaid": mermaid, "output_path": output_path, "success": True} - - @mcp.tool() - def diagram_compile_graphviz( - spec_dict: Optional[Dict[str, Any]] = None, - spec_path: Optional[str] = None, - output_path: Optional[str] = None, - ) -> Dict[str, Any]: - """Compile diagram specification to Graphviz DOT format. - - Parameters - ---------- - spec_dict : dict, optional - Diagram specification as dictionary. - - spec_path : str, optional - Path to YAML specification file. - - output_path : str, optional - Path to save .dot file. If not specified, returns the DOT string only. - - Returns - ------- - dict - Dictionary with 'graphviz' string and 'output_path' (if saved). - """ - d = _load_diagram(spec_dict, spec_path) - graphviz = d.to_graphviz(output_path) - return {"graphviz": graphviz, "output_path": output_path, "success": True} - - @mcp.tool() - def diagram_render( - spec_dict: Optional[Dict[str, Any]] = None, - spec_path: Optional[str] = None, - output_path: str = "", - format: Literal["png", "svg", "pdf"] = "png", - backend: Literal["auto", "mermaid-cli", "graphviz", "mermaid.ink"] = "auto", - scale: float = 2.0, - ) -> Dict[str, Any]: - """Render diagram to image file (PNG, SVG, PDF). - - Parameters - ---------- - spec_dict : dict, optional - Diagram specification as dictionary. - - spec_path : str, optional - Path to YAML specification file. - - output_path : str - Path to save the rendered image. - - format : str - Output format: png, svg, or pdf. - - backend : str - Rendering backend: - - auto: Automatically choose best available backend - - mermaid-cli: Use mermaid-cli (requires npm install -g @mermaid-js/mermaid-cli) - - graphviz: Use Graphviz dot command - - mermaid.ink: Use online Mermaid.ink service (no local install required) - - scale : float - Scale factor for rendering (default: 2.0 for high DPI). - - Returns - ------- - dict - Dictionary with 'output_path' and 'success' status. - """ - if not output_path: - raise ValueError("output_path is required") - d = _load_diagram(spec_dict, spec_path) - result_path = d.render(output_path, format=format, backend=backend, scale=scale) - return { - "output_path": str(result_path), - "format": format, - "backend": backend, - "success": True, - } - - @mcp.tool() - def diagram_split( - spec_path: str, - max_nodes_per_part: int = 10, - strategy: Literal["by_groups", "by_articulation"] = "by_groups", - ) -> Dict[str, Any]: - """Split a large diagram into smaller parts for multi-column layouts. - - Useful for complex diagrams that need to be broken down into - manageable pieces for publication layouts. - - Parameters - ---------- - spec_path : str - Path to the YAML specification file. - - max_nodes_per_part : int - Maximum number of nodes per split part (default: 10). - - strategy : str - Splitting strategy: - - by_groups: Split based on node groups defined in spec - - by_articulation: Split at articulation points in the graph - - Returns - ------- - dict - Dictionary with split parts and metadata about the splitting. - """ - d = Diagram.from_yaml(spec_path) - parts = d.split(max_nodes=max_nodes_per_part, strategy=strategy) - return { - "parts": [ - { - "title": p.spec.title, - "nodes": len(p.spec.nodes), - "edges": len(p.spec.edges), - "mermaid": p.to_mermaid(), - } - for p in parts - ], - "num_parts": len(parts), - "success": True, - } + tools = fr_mcp.mcp._tool_manager._tools + for name, tool in tools.items(): + if name.startswith("diagram_"): + mcp.add_tool(tool) # EOF diff --git a/src/scitex/_mcp_tools/fr.py b/src/scitex/_mcp_tools/fr.py new file mode 100755 index 000000000..7e48d771b --- /dev/null +++ b/src/scitex/_mcp_tools/fr.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Timestamp: 2026-02-21 +# File: /home/ywatanabe/proj/scitex-code/src/scitex/_mcp_tools/fr.py +"""FigRecipe specialized plot tools for FastMCP unified server. + +Programmatically bridges all figrecipe fr_* tools into scitex MCP. +No manual wrapping — any new figrecipe fr tool appears automatically. +""" + +from __future__ import annotations + + +def register_fr_tools(mcp) -> None: + """Register all figrecipe fr_* tools with the FastMCP server.""" + try: + from figrecipe._mcp import server as fr_mcp + except ImportError: + + @mcp.tool() + def fr_not_available() -> str: + """[fr] figrecipe not installed.""" + return "figrecipe is required. Install with: pip install figrecipe" + + return + + tools = fr_mcp.mcp._tool_manager._tools + for name, tool in tools.items(): + if name.startswith("fr_"): + mcp.add_tool(tool) + + +# EOF diff --git a/src/scitex/_mcp_tools/plt.py b/src/scitex/_mcp_tools/plt.py index 9f59b7f1e..e0649fa14 100755 --- a/src/scitex/_mcp_tools/plt.py +++ b/src/scitex/_mcp_tools/plt.py @@ -1,281 +1,34 @@ #!/usr/bin/env python3 -# Timestamp: 2026-01-24 +# Timestamp: 2026-02-21 # File: /home/ywatanabe/proj/scitex-code/src/scitex/_mcp_tools/plt.py """Plt module tools for FastMCP unified server. -This module delegates to figrecipe's MCP tools for single source of truth. -All plt_* tools are thin wrappers around figrecipe's canonical implementation. +Programmatically bridges all figrecipe plt_* tools into scitex MCP. +No manual wrapping — any new figrecipe plt tool appears automatically. """ from __future__ import annotations -import os -from typing import Any, Dict, List, Literal, Optional, Tuple # noqa: F401 - def register_plt_tools(mcp) -> None: - """Register plt tools with FastMCP server. - - Delegates to figrecipe's plt tools (canonical source). - Tools are prefixed with 'plt_' for scitex namespace consistency. - """ - # Ensure branding is set before any figrecipe imports - os.environ.setdefault("FIGRECIPE_BRAND", "scitex.plt") - os.environ.setdefault("FIGRECIPE_ALIAS", "plt") - - # Check if figrecipe is available + """Register all figrecipe plt_* tools with the FastMCP server.""" try: from figrecipe._mcp import server as fr_mcp - - # Access underlying functions from FunctionTool objects - # figrecipe tools are named plt_* for proper MCP categorization - _plot = fr_mcp.plt_plot.fn - _reproduce = fr_mcp.plt_reproduce.fn - _compose = fr_mcp.plt_compose.fn - _info = fr_mcp.plt_info.fn - _validate = fr_mcp.plt_validate.fn - _crop = fr_mcp.plt_crop.fn - _extract_data = fr_mcp.plt_extract_data.fn - - _FIGRECIPE_AVAILABLE = True except ImportError: - _FIGRECIPE_AVAILABLE = False - - if not _FIGRECIPE_AVAILABLE: @mcp.tool() def plt_not_available() -> str: """[plt] figrecipe not installed.""" - return "figrecipe is required for plt tools. Install with: pip install figrecipe" + return "figrecipe is required. Install with: pip install figrecipe" return - # Delegate to figrecipe's MCP tools with plt_ prefix - # Each wrapper simply calls the figrecipe function - - @mcp.tool() - def plt_plot( - spec: Dict[str, Any], - output_path: str, - dpi: int = 300, - save_recipe: bool = True, - ) -> Dict[str, Any]: - """[plt] Create a matplotlib figure from a declarative specification. - - Parameters - ---------- - spec : dict - Declarative specification. Key sections: figure, plots, stat_annotations, - xlabel, ylabel, title, legend, xlim, ylim. - - Plot types: line, plot, step, fill, fill_between, fill_betweenx, errorbar, - scatter, bar, barh, hist, hist2d, boxplot, box, violinplot, violin, imshow, - matshow, pcolormesh, contour, contourf, pie, stem, eventplot, hexbin, - specgram, psd, heatmap. - - Style presets: MATPLOTLIB, SCITEX (set via figure.style). - output_path : str - Path to save the output figure. - dpi : int - DPI for raster output (default: 300). - save_recipe : bool - If True, also save as figrecipe YAML recipe. - - Returns - ------- - dict - Result with 'image_path' and 'recipe_path'. - """ - return _plot(spec, output_path, dpi, save_recipe) - - @mcp.tool() - def plt_reproduce( - recipe_path: str, - output_path: Optional[str] = None, - format: Literal["png", "pdf", "svg"] = "png", - dpi: int = 300, - ) -> Dict[str, Any]: - """[plt] Reproduce a figure from a saved YAML recipe. - - Parameters - ---------- - recipe_path : str - Path to the .yaml recipe file. - - output_path : str, optional - Output path for the reproduced figure. - Defaults to recipe_path with .reproduced.{format} suffix. - - format : str - Output format: png, pdf, or svg. - - dpi : int - DPI for raster output. - - Returns - ------- - dict - Result with 'output_path' and 'success'. - """ - return _reproduce(recipe_path, output_path, format, dpi) - - @mcp.tool() - def plt_compose( - sources: List[str], - output_path: str, - layout: Literal["horizontal", "vertical", "grid"] = "horizontal", - gap_mm: float = 5.0, - dpi: int = 300, - panel_labels: bool = True, - label_style: Literal["uppercase", "lowercase", "numeric"] = "uppercase", - caption: Optional[str] = None, - create_symlinks: bool = True, - canvas_size_mm: Optional[Tuple[float, float]] = None, - facecolor: str = "white", - ) -> Dict[str, Any]: - """[plt] Compose multiple figures into a single figure with panel labels. - - Supports two modes: - 1. Grid-based layout (list sources): automatic arrangement with layout parameter - 2. Free-form positioning (dict sources): precise mm-based positioning - - Parameters - ---------- - sources : list of str or dict - Either: - - List of paths to source images or recipe files (grid-based layout) - - Dict mapping source paths to positioning specs with 'xy_mm' and 'size_mm': - {"panel_a.yaml": {"xy_mm": [0, 0], "size_mm": [80, 50]}, ...} - output_path : str - Path to save the composed figure. - layout : str - Layout mode for list sources: 'horizontal', 'vertical', or 'grid'. - Ignored when using dict sources with mm positioning. - gap_mm : float - Gap between panels in millimeters (for grid-based layout only). - dpi : int - DPI for output. - panel_labels : bool - If True, add panel labels (A, B, C, D) automatically. - label_style : str - Style: 'uppercase' (A,B,C), 'lowercase' (a,b,c), 'numeric' (1,2,3). - caption : str, optional - Figure caption to add below. - create_symlinks : bool - If True (default), create symlinks to source files for traceability. - canvas_size_mm : tuple of (float, float), optional - Canvas size as (width_mm, height_mm) for free-form positioning. - Required when sources is a dict with mm positioning. - facecolor : str - Background color for the composed figure. Default is 'white'. - All source panels are flattened onto this background to ensure - consistent appearance regardless of original panel transparency. - - Returns - ------- - dict - Result with 'output_path', 'success', and 'sources_dir' (if symlinks created). - """ - return _compose( - sources, - output_path, - layout, - gap_mm, - dpi, - panel_labels, - label_style, - caption, - create_symlinks, - canvas_size_mm, - facecolor, - ) - - @mcp.tool() - def plt_info(recipe_path: str, verbose: bool = False) -> Dict[str, Any]: - """[plt] Get information about a recipe file. - - Parameters - ---------- - recipe_path : str - Path to the .yaml recipe file. - - verbose : bool - If True, include detailed call information. - - Returns - ------- - dict - Recipe information including figure dimensions, call counts, etc. - """ - return _info(recipe_path, verbose) - - @mcp.tool() - def plt_validate( - recipe_path: str, - mse_threshold: float = 100.0, - ) -> Dict[str, Any]: - """[plt] Validate that a recipe can reproduce its original figure. - - Parameters - ---------- - recipe_path : str - Path to the .yaml recipe file. - - mse_threshold : float - Maximum acceptable mean squared error (default: 100). - - Returns - ------- - dict - Validation result with 'passed', 'mse', and details. - """ - return _validate(recipe_path, mse_threshold) - - @mcp.tool() - def plt_crop( - input_path: str, - output_path: Optional[str] = None, - margin_mm: float = 1.0, - overwrite: bool = False, - ) -> Dict[str, Any]: - """[plt] Crop whitespace from a figure image. - - Parameters - ---------- - input_path : str - Path to the input image. - - output_path : str, optional - Path for cropped output. Defaults to input with .cropped suffix. - - margin_mm : float - Margin to keep around content in millimeters. - - overwrite : bool - If True, overwrite the input file. - - Returns - ------- - dict - Result with 'output_path' and 'success'. - """ - return _crop(input_path, output_path, margin_mm, overwrite) - - @mcp.tool() - def plt_extract_data(recipe_path: str) -> Dict[str, Dict[str, Any]]: - """[plt] Extract plotted data arrays from a saved recipe. - - Parameters - ---------- - recipe_path : str - Path to the .yaml recipe file. - - Returns - ------- - dict - Nested dict: {call_id: {'x': list, 'y': list, ...}} - """ - return _extract_data(recipe_path) + tools = fr_mcp.mcp._tool_manager._tools + registered = 0 + for name, tool in tools.items(): + if name.startswith("plt_"): + mcp.add_tool(tool) + registered += 1 # EOF From 9cfdc4185acf2646a08778b78d8bac2994505ed0 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Sat, 21 Feb 2026 21:31:55 +1100 Subject: [PATCH 16/17] feat(mcp): rename figrecipe tools with scitex branding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply scitex branding via prefix renaming: - fr_* → plt_stx_* (e.g. fr_conf_mat → plt_stx_conf_mat) - diagram_* → plt_diagram_* (e.g. diagram_create → plt_diagram_create) All 73 figrecipe tools now appear under a single 'plt' module category. Co-Authored-By: Claude Sonnet 4.6 --- src/scitex/_mcp_tools/diagram.py | 16 ++++++++++------ src/scitex/_mcp_tools/fr.py | 16 ++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/scitex/_mcp_tools/diagram.py b/src/scitex/_mcp_tools/diagram.py index 1e5be687b..552ef2445 100755 --- a/src/scitex/_mcp_tools/diagram.py +++ b/src/scitex/_mcp_tools/diagram.py @@ -3,22 +3,25 @@ # File: /home/ywatanabe/proj/scitex-code/src/scitex/_mcp_tools/diagram.py """Diagram module tools for FastMCP unified server. -Programmatically bridges all figrecipe diagram_* tools into scitex MCP. -No manual wrapping — any new figrecipe diagram tool appears automatically. +Programmatically bridges all figrecipe diagram_* tools into scitex MCP, +renamed as plt_diagram_* for consistent scitex branding. + diagram_create → plt_diagram_create + diagram_render → plt_diagram_render + ... """ from __future__ import annotations def register_diagram_tools(mcp) -> None: - """Register all figrecipe diagram_* tools with the FastMCP server.""" + """Register figrecipe diagram_* tools as plt_diagram_* in the FastMCP server.""" try: from figrecipe._mcp import server as fr_mcp except ImportError: @mcp.tool() - def diagram_not_available() -> str: - """[diagram] figrecipe not installed.""" + def plt_diagram_not_available() -> str: + """[plt] figrecipe not installed.""" return "figrecipe is required. Install with: pip install figrecipe" return @@ -26,7 +29,8 @@ def diagram_not_available() -> str: tools = fr_mcp.mcp._tool_manager._tools for name, tool in tools.items(): if name.startswith("diagram_"): - mcp.add_tool(tool) + new_name = "plt_diagram_" + name[len("diagram_") :] + mcp.add_tool(tool.model_copy(update={"name": new_name})) # EOF diff --git a/src/scitex/_mcp_tools/fr.py b/src/scitex/_mcp_tools/fr.py index 7e48d771b..2a475b6f1 100755 --- a/src/scitex/_mcp_tools/fr.py +++ b/src/scitex/_mcp_tools/fr.py @@ -3,22 +3,25 @@ # File: /home/ywatanabe/proj/scitex-code/src/scitex/_mcp_tools/fr.py """FigRecipe specialized plot tools for FastMCP unified server. -Programmatically bridges all figrecipe fr_* tools into scitex MCP. -No manual wrapping — any new figrecipe fr tool appears automatically. +Programmatically bridges all figrecipe fr_* tools into scitex MCP, +renamed as plt_stx_* for consistent scitex branding. + fr_conf_mat → plt_stx_conf_mat + fr_ecdf → plt_stx_ecdf + ... """ from __future__ import annotations def register_fr_tools(mcp) -> None: - """Register all figrecipe fr_* tools with the FastMCP server.""" + """Register figrecipe fr_* tools as plt_stx_* in the FastMCP server.""" try: from figrecipe._mcp import server as fr_mcp except ImportError: @mcp.tool() - def fr_not_available() -> str: - """[fr] figrecipe not installed.""" + def plt_stx_not_available() -> str: + """[plt] figrecipe not installed.""" return "figrecipe is required. Install with: pip install figrecipe" return @@ -26,7 +29,8 @@ def fr_not_available() -> str: tools = fr_mcp.mcp._tool_manager._tools for name, tool in tools.items(): if name.startswith("fr_"): - mcp.add_tool(tool) + new_name = "plt_stx_" + name[len("fr_") :] + mcp.add_tool(tool.model_copy(update={"name": new_name})) # EOF From 165dd35ff22d5a2ec27179d20e1f7868259af6c7 Mon Sep 17 00:00:00 2001 From: Yusuke Watanabe Date: Sat, 21 Feb 2026 23:47:01 +1100 Subject: [PATCH 17/17] fix(plt/hitmap): handle RecordingFigure nested axes in _get_flat_axes stx.plt.subplots() returns a RecordingFigure whose .axes property yields nested lists ([[ax1], [ax2]]) instead of flat [ax1, ax2]. Add _get_flat_axes() helper that flattens nested structures, and use it in all three functions (get_all_artists, detect_logical_groups, get_all_artists_with_groups). This fixes AttributeError: 'list' object has no attribute 'get_lines' when generate_hitmap_id_colors() is called with a RecordingFigure. Color map size now correctly reflects the number of plotted elements (was always 0). Co-Authored-By: Claude Sonnet 4.6 --- .../plt/utils/_hitmap/_artist_extraction.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/scitex/plt/utils/_hitmap/_artist_extraction.py b/src/scitex/plt/utils/_hitmap/_artist_extraction.py index f46cc52bc..ae4ce8bd6 100755 --- a/src/scitex/plt/utils/_hitmap/_artist_extraction.py +++ b/src/scitex/plt/utils/_hitmap/_artist_extraction.py @@ -11,6 +11,8 @@ from typing import Any, Dict, List, Optional, Tuple +import matplotlib.axes as _mpl_axes + __all__ = [ "get_all_artists", "get_all_artists_with_groups", @@ -18,6 +20,25 @@ ] +def _get_flat_axes(fig) -> List[Any]: + """Return a flat list of matplotlib Axes from a figure. + + RecordingFigure (from stx.plt.subplots) may expose fig.axes as a nested + list such as [[ax1], [ax2]], whereas plain matplotlib uses [ax1, ax2]. + This helper normalises both cases. + """ + raw = fig.axes if hasattr(fig, "axes") else [] + flat: List[Any] = [] + for item in raw: + if isinstance(item, (list, tuple)): + for subitem in item: + if isinstance(subitem, _mpl_axes.Axes): + flat.append(subitem) + elif isinstance(item, _mpl_axes.Axes): + flat.append(item) + return flat + + def get_all_artists(fig, include_text: bool = False) -> List[Tuple[Any, int, str]]: """ Extract all selectable artists from a figure. @@ -36,7 +57,7 @@ def get_all_artists(fig, include_text: bool = False) -> List[Tuple[Any, int, str """ artists = [] - for ax_idx, ax in enumerate(fig.axes): + for ax_idx, ax in enumerate(_get_flat_axes(fig)): # Lines (Line2D) for line in ax.get_lines(): label = line.get_label() @@ -114,7 +135,7 @@ def get_group_id(group_type: str, ax_idx: int) -> str: group_counter[key] += 1 return f"{group_type}_{ax_idx}_{idx}" - for ax_idx, ax in enumerate(fig.axes): + for ax_idx, ax in enumerate(_get_flat_axes(fig)): # Detect BarContainers (covers bar charts and histograms) bar_containers = [ c for c in ax.containers if "BarContainer" in type(c).__name__ @@ -277,7 +298,7 @@ def get_all_artists_with_groups( artists_with_groups = [] - for ax_idx, ax in enumerate(fig.axes): + for ax_idx, ax in enumerate(_get_flat_axes(fig)): # Lines for line in ax.get_lines(): label = line.get_label()