diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml
new file mode 100644
index 0000000..fedcecb
--- /dev/null
+++ b/.github/workflows/deploy.yml
@@ -0,0 +1,48 @@
+# This file was created automatically with `myst init --gh-pages` 🪄 💚
+# Ensure your GitHub Pages settings for this repository are set to deploy with **GitHub Actions**.
+
+name: MyST GitHub Pages Deploy
+on:
+ push:
+ # Runs on pushes targeting the default branch
+ branches: [main]
+env:
+ # `BASE_URL` determines, relative to the root of the domain, the URL that your site is served from.
+ # E.g., if your site lives at `https://mydomain.org/myproject`, set `BASE_URL=/myproject`.
+ # If, instead, your site lives at the root of the domain, at `https://mydomain.org`, set `BASE_URL=''`.
+ BASE_URL: /${{ github.event.repository.name }}
+
+# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
+permissions:
+ contents: read
+ pages: write
+ id-token: write
+# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
+# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
+concurrency:
+ group: 'pages'
+ cancel-in-progress: false
+jobs:
+ deploy:
+ environment:
+ name: github-pages
+ url: ${{ steps.deployment.outputs.page_url }}
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Setup Pages
+ uses: actions/configure-pages@v3
+ - uses: actions/setup-node@v4
+ with:
+ node-version: 18.x
+ - name: Install MyST
+ run: npm install -g mystmd
+ - name: Build HTML Assets
+ run: myst build --html
+ - name: Upload artifact
+ uses: actions/upload-pages-artifact@v3
+ with:
+ path: './_build/html'
+ - name: Deploy to GitHub Pages
+ id: deployment
+ uses: actions/deploy-pages@v4
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 3d7daae..8a9085b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -30,7 +30,7 @@ jobs:
sudo Rscript -e 'install.packages("clue", repos="https://cloud.r-project.org")'
- name: Install Python dependencies
run: |
- uv pip install -e ".[dev,docs,matching]" --system
+ uv pip install -e ".[dev,docs,matching,images]" --system
- name: Run tests with coverage
run: make test
- name: Upload coverage to Codecov
@@ -39,20 +39,3 @@ jobs:
file: ./coverage.xml
fail_ci_if_error: false
verbose: true
- - name: Test documentation builds
- run: make documentation
- - name: Check documentation build
- run: |
- for notebook in $(find docs/_build/jupyter_execute -name "*.ipynb"); do
- if grep -q '"output_type": "error"' "$notebook"; then
- echo "Error found in $notebook"
- cat "$notebook"
- exit 1
- fi
- done
- - name: Deploy documentation
- uses: JamesIves/github-pages-deploy-action@releases/v3
- with:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- BRANCH: gh-pages # The branch the action should deploy to.
- FOLDER: docs/_build/html # The folder the action should deploy.
diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr_code_changes.yaml
similarity index 58%
rename from .github/workflows/pr.yaml
rename to .github/workflows/pr_code_changes.yaml
index 3081697..13feb80 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr_code_changes.yaml
@@ -57,40 +57,4 @@ jobs:
with:
file: ./coverage.xml
fail_ci_if_error: false
- verbose: true
-
- Build:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout repo
- uses: actions/checkout@v3
- - name: Install uv
- uses: astral-sh/setup-uv@v5
- - name: Set up Python
- uses: actions/setup-python@v4
- with:
- python-version: "3.13"
- - name: Install R and dependencies
- run: |
- sudo apt-get update
- sudo apt-get install -y r-base r-base-dev libtirpc-dev
- - name: Install R packages
- run: |
- sudo Rscript -e 'install.packages("StatMatch", repos="https://cloud.r-project.org")'
- sudo Rscript -e 'install.packages("clue", repos="https://cloud.r-project.org")'
- - name: Install dependencies
- run: |
- uv pip install -e ".[dev,docs,matching]" --system
- - name: Build package
- run: make build
- - name: Test documentation builds
- run: make documentation
- - name: Check documentation build
- run: |
- for notebook in $(find docs/_build/jupyter_execute -name "*.ipynb"); do
- if grep -q '"output_type": "error"' "$notebook"; then
- echo "Error found in $notebook"
- cat "$notebook"
- exit 1
- fi
- done
+ verbose: true
\ No newline at end of file
diff --git a/.github/workflows/pr_docs_changes.yaml b/.github/workflows/pr_docs_changes.yaml
new file mode 100644
index 0000000..4481634
--- /dev/null
+++ b/.github/workflows/pr_docs_changes.yaml
@@ -0,0 +1,52 @@
+# Workflow that runs on code changes to a pull request.
+
+name: Docs changes
+on:
+ pull_request:
+ branches:
+ - main
+
+ paths:
+ - docs/**
+ - .github/**
+ workflow_dispatch:
+
+jobs:
+ Test:
+ runs-on: ubuntu-latest
+ name: Test documentation builds
+ steps:
+ - name: Checkout repo
+ uses: actions/checkout@v4
+ - name: Install uv
+ uses: astral-sh/setup-uv@v5
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.13'
+
+ - name: Install R and dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y r-base r-base-dev libtirpc-dev
+ - name: Install R packages
+ run: |
+ sudo Rscript -e 'install.packages("StatMatch", repos="https://cloud.r-project.org")'
+ sudo Rscript -e 'install.packages("clue", repos="https://cloud.r-project.org")'
+ - name: Install dependencies
+ run: |
+ uv pip install -e ".[dev,docs,matching,images]" --system
+ - name: Install JB
+ run: uv pip install "jupyter-book>=2.0.0a0" --system
+ - name: Test documentation builds
+ run: make documentation
+ - name: Check documentation build
+ run: |
+ for notebook in $(find docs/_build/jupyter_execute -name "*.ipynb"); do
+ if grep -q '"output_type": "error"' "$notebook"; then
+ echo "Error found in $notebook"
+ cat "$notebook"
+ exit 1
+ fi
+ done
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 5979578..555168e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -78,4 +78,6 @@ celerybeat.pid
*.csv
*.jpg
*.html
-*.h5
\ No newline at end of file
+*.h5
+# MyST build outputs
+_build
diff --git a/Makefile b/Makefile
index d6f633c..ebc2af6 100644
--- a/Makefile
+++ b/Makefile
@@ -15,8 +15,9 @@ format:
black . -l 79
documentation:
- cd docs && jupyter-book build .
- python docs/add_plotly_to_book.py docs/_build/html
+ cd docs && jupyter book clean . --all
+ cd docs && jupyter book build .
+ python docs/add_plotly_to_book.py docs/_build
build:
pip install build
diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29..03c1e59 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,5 @@
+- bump: patch
+ changes:
+ added:
+ - Moved data loading utilities into utils and removed scf downloading functionality.
+ - Updated documentation to reflect the new structure and created a myst.yml file to deploy documentation with new jb v2.
diff --git a/docs/_static/__init__.py b/docs/_static/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/docs/_static/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/docs/_toc.yml b/docs/_toc.yml
index c349757..a8b558d 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -1,8 +1,6 @@
format: jb-book
root: index
-
parts:
-
- caption: Models
chapters:
- file: models/imputer/index
diff --git a/docs/autoimpute/autoimpute.ipynb b/docs/autoimpute/autoimpute.ipynb
index e727c8b..c2efa86 100644
--- a/docs/autoimpute/autoimpute.ipynb
+++ b/docs/autoimpute/autoimpute.ipynb
@@ -18,11 +18,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
+ "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <668E1903-F0E7-30D5-BA27-15F8287F87F7> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
"Trying to import in ABI mode.\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"PWD\" redefined by R and overriding existing variable. Current: \"/\", R: \"/Users/movil1/Desktop/PYTHONJOBS/PolicyEngine/microimpute/docs/autoimpute\"\n",
+ "/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable \"PWD\" redefined by R and overriding existing variable. Current: \"/\", R: \"/Users/movil1/Desktop/PYTHONJOBS/PolicyEngine/microimpute/docs/autoimpute\"\n",
" warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpzb89RA\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpoTaONA\"\n",
+ "/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpjTEAjd\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpSiMg9b\"\n",
" warnings.warn(\n"
]
}
@@ -248,7 +248,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "1160ab751cc6443e8b95ce12a6895d49",
+ "model_id": "8d2fbfaa071a4c6592a774c35a323c83",
"version_major": 2,
"version_minor": 0
},
@@ -263,21 +263,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 3.1s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Batch computation too fast (0.07565808296203613s.) Setting batch_size=2.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.1s finished\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 3.4s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.7s finished\n",
+ "[Parallel(n_jobs=-1)]: Batch computation too fast (0.1890571117401123s.) Setting batch_size=2.\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.2s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 3.0s finished\n"
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.9s finished\n",
+ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 3.6s finished\n"
]
},
{
@@ -323,7 +317,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -370,59 +364,59 @@
"
\n",
" \n",
" | QRF | \n",
- " 0.005829 | \n",
- " 0.008561 | \n",
- " 0.011626 | \n",
- " 0.014301 | \n",
- " 0.015824 | \n",
+ " 0.005206 | \n",
+ " 0.007714 | \n",
+ " 0.010652 | \n",
+ " 0.013726 | \n",
+ " 0.017578 | \n",
" ... | \n",
- " 0.015397 | \n",
- " 0.013911 | \n",
- " 0.011076 | \n",
- " 0.007627 | \n",
- " 0.016249 | \n",
+ " 0.016469 | \n",
+ " 0.013081 | \n",
+ " 0.010507 | \n",
+ " 0.006769 | \n",
+ " 0.016371 | \n",
"
\n",
" \n",
" | OLS | \n",
- " 0.003848 | \n",
- " 0.006656 | \n",
- " 0.009042 | \n",
- " 0.011071 | \n",
- " 0.012733 | \n",
+ " 0.003969 | \n",
+ " 0.006727 | \n",
+ " 0.009112 | \n",
+ " 0.011131 | \n",
+ " 0.012739 | \n",
" ... | \n",
- " 0.013373 | \n",
- " 0.011476 | \n",
- " 0.009000 | \n",
- " 0.005670 | \n",
- " 0.012827 | \n",
+ " 0.013079 | \n",
+ " 0.011218 | \n",
+ " 0.008765 | \n",
+ " 0.005431 | \n",
+ " 0.012667 | \n",
"
\n",
" \n",
" | QuantReg | \n",
- " 0.003581 | \n",
- " 0.006514 | \n",
- " 0.008995 | \n",
- " 0.011158 | \n",
- " 0.012975 | \n",
+ " 0.003882 | \n",
+ " 0.006577 | \n",
+ " 0.009084 | \n",
+ " 0.011261 | \n",
+ " 0.012793 | \n",
" ... | \n",
- " 0.013403 | \n",
- " 0.011501 | \n",
- " 0.009050 | \n",
- " 0.005945 | \n",
- " 0.012834 | \n",
+ " 0.013183 | \n",
+ " 0.011169 | \n",
+ " 0.008899 | \n",
+ " 0.005258 | \n",
+ " 0.012713 | \n",
"
\n",
" \n",
" | Matching | \n",
- " 0.023600 | \n",
- " 0.023571 | \n",
- " 0.023543 | \n",
- " 0.023514 | \n",
- " 0.023486 | \n",
+ " 0.024895 | \n",
+ " 0.024740 | \n",
+ " 0.024585 | \n",
+ " 0.024430 | \n",
+ " 0.024275 | \n",
" ... | \n",
- " 0.023172 | \n",
- " 0.023144 | \n",
- " 0.023115 | \n",
- " 0.023087 | \n",
- " 0.023343 | \n",
+ " 0.022570 | \n",
+ " 0.022415 | \n",
+ " 0.022260 | \n",
+ " 0.022105 | \n",
+ " 0.023500 | \n",
"
\n",
" \n",
"\n",
@@ -431,15 +425,15 @@
],
"text/plain": [
" 0.05 0.1 0.15 0.2 0.25 ... 0.8 0.85 0.9 0.95 mean_loss\n",
- "QRF 0.005829 0.008561 0.011626 0.014301 0.015824 ... 0.015397 0.013911 0.011076 0.007627 0.016249\n",
- "OLS 0.003848 0.006656 0.009042 0.011071 0.012733 ... 0.013373 0.011476 0.009000 0.005670 0.012827\n",
- "QuantReg 0.003581 0.006514 0.008995 0.011158 0.012975 ... 0.013403 0.011501 0.009050 0.005945 0.012834\n",
- "Matching 0.023600 0.023571 0.023543 0.023514 0.023486 ... 0.023172 0.023144 0.023115 0.023087 0.023343\n",
+ "QRF 0.005206 0.007714 0.010652 0.013726 0.017578 ... 0.016469 0.013081 0.010507 0.006769 0.016371\n",
+ "OLS 0.003969 0.006727 0.009112 0.011131 0.012739 ... 0.013079 0.011218 0.008765 0.005431 0.012667\n",
+ "QuantReg 0.003882 0.006577 0.009084 0.011261 0.012793 ... 0.013183 0.011169 0.008899 0.005258 0.012713\n",
+ "Matching 0.024895 0.024740 0.024585 0.024430 0.024275 ... 0.022570 0.022415 0.022260 0.022105 0.023500\n",
"\n",
"[4 rows x 20 columns]"
]
},
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -459,7 +453,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -467,7 +461,7 @@
"output_type": "stream",
"text": [
"Best performing method: OLS\n",
- "Average loss: 0.0128\n"
+ "Average loss: 0.0127\n"
]
}
],
@@ -489,7 +483,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -538,25 +532,25 @@
],
"xaxis": "x",
"y": [
- 0.005829386076662433,
- 0.00856105395153346,
- 0.011626175046873516,
- 0.014300647406418544,
- 0.015824304502979173,
- 0.018433289344375187,
- 0.02060722953506577,
- 0.020760025411998718,
- 0.022565452074453166,
- 0.022894803477996075,
- 0.021925716974908502,
- 0.02106669659131533,
- 0.01995566361352964,
- 0.019014029498789217,
- 0.017346619969805206,
- 0.015397201876470158,
- 0.01391050378078308,
- 0.011076481953005384,
- 0.007627050864379574
+ 0.005205574757339217,
+ 0.007714403027329015,
+ 0.01065194199838141,
+ 0.013726455833237,
+ 0.01757780514534066,
+ 0.018107434762520774,
+ 0.019844256732387433,
+ 0.022552991818725138,
+ 0.021767176700496924,
+ 0.02176271884036238,
+ 0.022431454234331308,
+ 0.022519291751133385,
+ 0.02141376217036338,
+ 0.020211412017013857,
+ 0.018733449365168162,
+ 0.016468815556855313,
+ 0.013081288782524089,
+ 0.010506640243809581,
+ 0.006768809609731342
],
"yaxis": "y"
},
@@ -599,25 +593,25 @@
],
"xaxis": "x",
"y": [
- 0.0038477993149538254,
- 0.006656275252079852,
- 0.00904195546246103,
- 0.011070661084408441,
- 0.012732521787295892,
- 0.01415364523812077,
- 0.015265140752399263,
- 0.01607393376072333,
- 0.016645247530872858,
- 0.01703916972578955,
- 0.01718839322103158,
- 0.017071707772460466,
- 0.016674601673214974,
- 0.015917739067342566,
- 0.014815834740048467,
- 0.01337329286563062,
- 0.011475808671119468,
- 0.008999927337815994,
- 0.005669603175410814
+ 0.003968866261144534,
+ 0.006727149092118921,
+ 0.009111943845003724,
+ 0.011131059169560876,
+ 0.012738883924364994,
+ 0.01409263239295278,
+ 0.015194280261362608,
+ 0.0160051275445406,
+ 0.01652413484199293,
+ 0.01685683616057514,
+ 0.016917118207723066,
+ 0.01668155786641982,
+ 0.016236073474765234,
+ 0.015506066708203278,
+ 0.01447908846019153,
+ 0.013079021820500435,
+ 0.01121849349164269,
+ 0.008764648482194161,
+ 0.005431025111958136
],
"yaxis": "y"
},
@@ -660,25 +654,25 @@
],
"xaxis": "x",
"y": [
- 0.0035810390219439096,
- 0.006514223340283585,
- 0.008995351643073387,
- 0.011158246242958816,
- 0.012974970538910937,
- 0.01423268933158089,
- 0.015330047470700136,
- 0.01602476890041046,
- 0.0166477989417812,
- 0.017016090357039843,
- 0.017035494544465558,
- 0.01699546306281755,
- 0.016673802095929002,
- 0.015887762592297763,
- 0.014890055237687114,
- 0.013402811076696474,
- 0.011500537353479856,
- 0.009049504374390572,
- 0.005944598811241848
+ 0.0038819401091077133,
+ 0.00657749357289137,
+ 0.009083755453766702,
+ 0.011261246050560414,
+ 0.012792735643008287,
+ 0.014031169768181771,
+ 0.015238913580143431,
+ 0.016190000344508862,
+ 0.016718503781175408,
+ 0.017002911225970307,
+ 0.017055534163794735,
+ 0.01685878414761417,
+ 0.016313712268620943,
+ 0.01555684766405394,
+ 0.014467656699950203,
+ 0.013182914422505103,
+ 0.011169418659965487,
+ 0.008898965006706804,
+ 0.00525814858573891
],
"yaxis": "y"
},
@@ -721,25 +715,25 @@
],
"xaxis": "x",
"y": [
- 0.02359958308435484,
- 0.023571102327493376,
- 0.023542621570631906,
- 0.023514140813770443,
- 0.023485660056908973,
- 0.023457179300047506,
- 0.02342869854318604,
- 0.023400217786324572,
- 0.023371737029463106,
- 0.023343256272601642,
- 0.023314775515740172,
- 0.023286294758878705,
- 0.023257814002017235,
- 0.023229333245155775,
- 0.023200852488294302,
- 0.02317237173143284,
- 0.02314389097457137,
- 0.02311541021770991,
- 0.02308692946084844
+ 0.024895202639483433,
+ 0.02474020209681031,
+ 0.024585201554137192,
+ 0.024430201011464064,
+ 0.024275200468790944,
+ 0.02412019992611782,
+ 0.023965199383444703,
+ 0.023810198840771582,
+ 0.02365519829809846,
+ 0.023500197755425337,
+ 0.023345197212752217,
+ 0.0231901966700791,
+ 0.02303519612740598,
+ 0.022880195584732855,
+ 0.02272519504205973,
+ 0.022570194499386614,
+ 0.022415193956713497,
+ 0.022260193414040372,
+ 0.02210519287136725
],
"yaxis": "y"
}
@@ -766,8 +760,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 0.01624854378691274,
- "y1": 0.01624854378691274
+ "y0": 0.016370825439318438,
+ "y1": 0.016370825439318438
},
{
"line": {
@@ -779,8 +773,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 0.012827013601746304,
- "y1": 0.012827013601746304
+ "y0": 0.012666526690379761,
+ "y1": 0.012666526690379761
},
{
"line": {
@@ -792,8 +786,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 0.012834487101983627,
- "y1": 0.012834487101983627
+ "y0": 0.01271266584990866,
+ "y1": 0.01271266584990866
},
{
"line": {
@@ -805,8 +799,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 0.02334325627260164,
- "y1": 0.02334325627260164
+ "y0": 0.02350019775542534,
+ "y1": 0.02350019775542534
}
],
"template": {
@@ -1694,7 +1688,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -1730,38 +1724,38 @@
" \n",
" | 1 | \n",
" OLS | \n",
- " 0.012827 | \n",
+ " 0.012667 | \n",
" 0.05 | \n",
- " 0.003848 | \n",
+ " 0.003969 | \n",
" 0.55 | \n",
- " 0.017188 | \n",
+ " 0.016917 | \n",
"
\n",
" \n",
" | 2 | \n",
" QuantReg | \n",
- " 0.012834 | \n",
+ " 0.012713 | \n",
" 0.05 | \n",
- " 0.003581 | \n",
+ " 0.003882 | \n",
" 0.55 | \n",
- " 0.017035 | \n",
+ " 0.017056 | \n",
"
\n",
" \n",
" | 0 | \n",
" QRF | \n",
- " 0.016249 | \n",
+ " 0.016371 | \n",
" 0.05 | \n",
- " 0.005829 | \n",
- " 0.50 | \n",
- " 0.022895 | \n",
+ " 0.005206 | \n",
+ " 0.40 | \n",
+ " 0.022553 | \n",
"
\n",
" \n",
" | 3 | \n",
" Matching | \n",
- " 0.023343 | \n",
+ " 0.023500 | \n",
" 0.95 | \n",
- " 0.023087 | \n",
+ " 0.022105 | \n",
" 0.05 | \n",
- " 0.023600 | \n",
+ " 0.024895 | \n",
"
\n",
" \n",
"\n",
@@ -1769,13 +1763,13 @@
],
"text/plain": [
" Method Mean Test Quantile Loss Best Quantile Best Test Quantile Loss Worst Quantile Worst Test Quantile Loss\n",
- "1 OLS 0.012827 0.05 0.003848 0.55 0.017188\n",
- "2 QuantReg 0.012834 0.05 0.003581 0.55 0.017035\n",
- "0 QRF 0.016249 0.05 0.005829 0.50 0.022895\n",
- "3 Matching 0.023343 0.95 0.023087 0.05 0.023600"
+ "1 OLS 0.012667 0.05 0.003969 0.55 0.016917\n",
+ "2 QuantReg 0.012713 0.05 0.003882 0.55 0.017056\n",
+ "0 QRF 0.016371 0.05 0.005206 0.40 0.022553\n",
+ "3 Matching 0.023500 0.95 0.022105 0.05 0.024895"
]
},
- "execution_count": 10,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -1802,7 +1796,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -1840,28 +1834,28 @@
" \n",
" \n",
" | 0 | \n",
- " -0.010049 | \n",
- " -0.033896 | \n",
+ " 0.015336 | \n",
+ " 0.038018 | \n",
"
\n",
" \n",
" | 1 | \n",
- " 0.023339 | \n",
- " 0.036040 | \n",
+ " 0.019831 | \n",
+ " 0.036004 | \n",
"
\n",
" \n",
" | 2 | \n",
- " 0.000684 | \n",
- " -0.025150 | \n",
+ " -0.020689 | \n",
+ " -0.005872 | \n",
"
\n",
" \n",
" | 3 | \n",
- " -0.022801 | \n",
- " -0.009154 | \n",
+ " 0.015436 | \n",
+ " 0.021340 | \n",
"
\n",
" \n",
" | 4 | \n",
- " -0.010737 | \n",
- " -0.023260 | \n",
+ " -0.029310 | \n",
+ " -0.050130 | \n",
"
\n",
" \n",
"\n",
@@ -1869,14 +1863,14 @@
],
"text/plain": [
" s1 s4\n",
- "0 -0.010049 -0.033896\n",
- "1 0.023339 0.036040\n",
- "2 0.000684 -0.025150\n",
- "3 -0.022801 -0.009154\n",
- "4 -0.010737 -0.023260"
+ "0 0.015336 0.038018\n",
+ "1 0.019831 0.036004\n",
+ "2 -0.020689 -0.005872\n",
+ "3 0.015436 0.021340\n",
+ "4 -0.029310 -0.050130"
]
},
- "execution_count": 11,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -1892,7 +1886,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -1938,16 +1932,16 @@
" \n",
" \n",
" | 0 | \n",
- " -0.001882 | \n",
- " -0.044642 | \n",
- " -0.051474 | \n",
- " -0.026328 | \n",
- " -0.019163 | \n",
- " 0.074412 | \n",
- " -0.068332 | \n",
- " -0.092204 | \n",
- " -0.010049 | \n",
- " -0.033896 | \n",
+ " 0.038076 | \n",
+ " 0.050680 | \n",
+ " 0.061696 | \n",
+ " 0.021872 | \n",
+ " -0.034821 | \n",
+ " -0.043401 | \n",
+ " 0.019907 | \n",
+ " -0.017646 | \n",
+ " 0.015336 | \n",
+ " 0.038018 | \n",
"
\n",
" \n",
" | 1 | \n",
@@ -1959,24 +1953,11 @@
" -0.032356 | \n",
" 0.002861 | \n",
" -0.025930 | \n",
- " 0.023339 | \n",
- " 0.036040 | \n",
+ " 0.019831 | \n",
+ " 0.036004 | \n",
"
\n",
" \n",
" | 2 | \n",
- " 0.005383 | \n",
- " -0.044642 | \n",
- " -0.036385 | \n",
- " 0.021872 | \n",
- " 0.015596 | \n",
- " 0.008142 | \n",
- " -0.031988 | \n",
- " -0.046641 | \n",
- " 0.000684 | \n",
- " -0.025150 | \n",
- "
\n",
- " \n",
- " | 3 | \n",
" -0.045472 | \n",
" 0.050680 | \n",
" -0.047163 | \n",
@@ -1985,21 +1966,34 @@
" 0.000779 | \n",
" -0.062917 | \n",
" -0.038357 | \n",
- " -0.022801 | \n",
- " -0.009154 | \n",
+ " -0.020689 | \n",
+ " -0.005872 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 0.063504 | \n",
+ " 0.050680 | \n",
+ " -0.001895 | \n",
+ " 0.066629 | \n",
+ " 0.108914 | \n",
+ " 0.022869 | \n",
+ " -0.035816 | \n",
+ " 0.003064 | \n",
+ " 0.015436 | \n",
+ " 0.021340 | \n",
"
\n",
" \n",
" | 4 | \n",
- " -0.027310 | \n",
+ " -0.096328 | \n",
" -0.044642 | \n",
- " -0.018062 | \n",
- " -0.040099 | \n",
- " -0.011335 | \n",
- " 0.037595 | \n",
- " -0.008943 | \n",
- " -0.054925 | \n",
- " -0.010737 | \n",
- " -0.023260 | \n",
+ " -0.083808 | \n",
+ " 0.008101 | \n",
+ " -0.090561 | \n",
+ " -0.013948 | \n",
+ " -0.062917 | \n",
+ " -0.034215 | \n",
+ " -0.029310 | \n",
+ " -0.050130 | \n",
"
\n",
" \n",
"\n",
@@ -2007,14 +2001,14 @@
],
"text/plain": [
" age sex bmi bp s2 s3 s5 s6 s1 s4\n",
- "0 -0.001882 -0.044642 -0.051474 -0.026328 -0.019163 0.074412 -0.068332 -0.092204 -0.010049 -0.033896\n",
- "1 0.085299 0.050680 0.044451 -0.005670 -0.034194 -0.032356 0.002861 -0.025930 0.023339 0.036040\n",
- "2 0.005383 -0.044642 -0.036385 0.021872 0.015596 0.008142 -0.031988 -0.046641 0.000684 -0.025150\n",
- "3 -0.045472 0.050680 -0.047163 -0.015999 -0.024800 0.000779 -0.062917 -0.038357 -0.022801 -0.009154\n",
- "4 -0.027310 -0.044642 -0.018062 -0.040099 -0.011335 0.037595 -0.008943 -0.054925 -0.010737 -0.023260"
+ "0 0.038076 0.050680 0.061696 0.021872 -0.034821 -0.043401 0.019907 -0.017646 0.015336 0.038018\n",
+ "1 0.085299 0.050680 0.044451 -0.005670 -0.034194 -0.032356 0.002861 -0.025930 0.019831 0.036004\n",
+ "2 -0.045472 0.050680 -0.047163 -0.015999 -0.024800 0.000779 -0.062917 -0.038357 -0.020689 -0.005872\n",
+ "3 0.063504 0.050680 -0.001895 0.066629 0.108914 0.022869 -0.035816 0.003064 0.015436 0.021340\n",
+ "4 -0.096328 -0.044642 -0.083808 0.008101 -0.090561 -0.013948 -0.062917 -0.034215 -0.029310 -0.050130"
]
},
- "execution_count": 12,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -2036,7 +2030,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -2190,139 +2184,139 @@
132
],
"y": [
- -0.008448724111216851,
+ -0.04422349842444599,
-0.04559945128264711,
- 0.003934851612593237,
- -0.04009563984984263,
- -0.0029449126784123676,
- -0.015328488402222454,
- -0.04284754556624487,
-0.04009563984984263,
- -0.038719686991641515,
- -0.04422349842444599,
- 0.0080627101871966,
- 0.08786797596286161,
- -0.04972730985725048,
+ 0.09061988167926385,
+ -0.10338947132709418,
+ -0.007072771253015731,
+ 0.08924392882106273,
+ 0.0342058144930179,
+ -0.037343734133440394,
+ -0.037343734133440394,
0.0025588987543921156,
- 0.039709625925822375,
+ 0.05484510736603471,
+ 0.020446285911006685,
+ -0.0841261313122785,
+ -0.04147159270804375,
0.01219056876179996,
- -0.007072771253015731,
-0.0318399227006359,
- 0.06447677737344255,
+ -0.06623874415566393,
-0.04422349842444599,
- 0.005310804470794357,
- -0.005696818394814609,
- -0.00019300696201012598,
+ 0.024574144485610048,
0.027326050202012293,
-0.011200629827619093,
- -0.12678066991651324,
+ -0.038719686991641515,
-0.06761469701386505,
+ 0.010814615903598841,
+ 0.03282986163481677,
-0.029088016984233665,
- -0.00019300696201012598,
- -0.04559945128264711,
- 0.03695772020942014,
-0.0318399227006359,
- -0.04559945128264711,
- 0.014942474478202204,
+ 0.02182223876920781,
+ 0.027326050202012293,
+ 0.010814615903598841,
+ 0.041085578784023497,
+ -0.0249601584096303,
+ -0.0318399227006359,
+ -0.0029449126784123676,
+ 0.001182945896190995,
0.013566521620001083,
- 0.02319819162740893,
- -0.019456346976825818,
0.0080627101871966,
- 0.04383748450042574,
- -0.05523112129005496,
- 0.00943866304539772,
- -0.0029449126784123676,
- 0.01219056876179996,
- 0.019070333052805567,
+ -0.034591828417038145,
-0.033215875558837024,
- 0.01219056876179996,
- -0.02083229983502694,
+ 0.03833367306762126,
+ -0.005696818394814609,
-0.034591828417038145,
+ 0.0342058144930179,
+ 0.05484510736603471,
+ 0.03695772020942014,
0.039709625925822375,
- 0.08511607024645937,
+ -0.02358420555142918,
+ 0.1332744202834986,
+ -0.02358420555142918,
-0.037343734133440394,
- 0.020446285911006685,
- 0.02319819162740893,
+ -0.08825398988688185,
+ 0.05897296594063807,
+ -0.07587041416307178,
+ -0.0318399227006359,
+ 0.045213437358626866,
0.06447677737344255,
-0.004320865536613489,
- 0.039709625925822375,
+ -0.019456346976825818,
-0.018080394118624697,
- -0.009824676969417972,
- 0.126394655992493,
- 0.024574144485610048,
-0.051103262715451604,
- -0.0249601584096303,
+ -0.05523112129005496,
+ 0.04383748450042574,
+ 0.05209320164963247,
+ -0.016704441260423575,
+ 0.024574144485610048,
+ 0.04934129593323023,
-0.001568959820211247,
- -0.06623874415566393,
+ 0.053469154507833586,
+ 0.04658939021682799,
-0.04835135699904936,
- 0.003934851612593237,
- 0.0080627101871966,
- 0.039709625925822375,
- -0.02358420555142918,
- 0.006686757328995478,
- 0.08374011738825825,
- -0.009824676969417972,
+ -0.05385516843185383,
-0.04972730985725048,
- 0.006686757328995478,
- -0.04559945128264711,
- -0.008448724111216851,
- 0.020446285911006685,
- -0.04284754556624487,
- 0.0741084473808504,
- -0.02220825269322806,
+ -0.00019300696201012598,
+ 0.04796534307502911,
+ 0.024574144485610048,
+ 0.04934129593323023,
-0.007072771253015731,
- -0.026336111267831423,
- 0.08786797596286161,
- 0.01219056876179996,
- 0.001182945896190995,
- -0.004320865536613489,
+ 0.07548440023905152,
+ -0.07587041416307178,
+ 0.12501870313429186,
+ -0.007072771253015731,
+ -0.06623874415566393,
+ 0.04383748450042574,
+ 0.053469154507833586,
+ -0.04972730985725048,
+ 0.020446285911006685,
+ -0.0910058956032841,
+ -0.035967781275239266,
+ 0.05759701308243695,
-0.038719686991641515,
- -0.07311850844666953,
+ 0.03282986163481677,
-0.009824676969417972,
- 0.03145390877661565,
- 0.001182945896190995,
0.030077955918414535,
- -0.018080394118624697,
- 0.07823630595545376,
+ 0.08236416453005713,
0.07273249452264928,
- 0.024574144485610048,
- -0.029088016984233665,
- -0.027712064126032544,
+ -0.0029449126784123676,
-0.011200629827619093,
- 0.07548440023905152,
-0.009824676969417972,
0.04934129593323023,
- 0.04246153164222462,
- -0.08137422559587626,
- 0.0080627101871966,
- -0.04284754556624487,
+ 0.01219056876179996,
+ -0.051103262715451604,
+ 0.10988322169407955,
+ -0.037343734133440394,
+ -0.016704441260423575,
+ -0.046975404140848234,
-0.009824676969417972,
- 0.041085578784023497,
-0.001568959820211247,
- 0.05071724879143135,
- 0.013566521620001083,
- -0.06348683843926169,
- 0.03558176735121902,
+ -0.06623874415566393,
0.11951489170148738,
- 0.014942474478202204,
+ -0.005696818394814609,
-0.060734932722859444,
- 0.0933717873956661,
- -0.04972730985725048,
- -0.007072771253015731,
+ -0.02358420555142918,
+ -0.0318399227006359,
+ -0.008448724111216851,
+ -0.011200629827619093,
+ 0.017694380194604446,
+ -0.029088016984233665,
-0.06348683843926169,
- -0.046975404140848234,
-0.004320865536613489,
- -0.005696818394814609,
- -0.033215875558837024,
- -0.10476542418529532,
- 0.0080627101871966,
- 0.05622106022423583,
+ 0.04796534307502911,
+ 0.014942474478202204,
+ 0.10988322169407955,
0.02595009734381117,
- 0.001182945896190995,
- 0.03833367306762126,
+ -0.0029449126784123676,
+ -0.015328488402222454,
+ -0.06623874415566393,
+ 0.039709625925822375,
+ -0.016704441260423575,
+ 0.04658939021682799,
+ 0.04934129593323023,
-0.037343734133440394,
- 0.05759701308243695,
- -0.037343734133440394
+ 0.08374011738825825
]
},
{
@@ -2469,139 +2463,139 @@
132
],
"y": [
- -0.010048593007542262,
- 0.023338501758184967,
- 0.0006835338812098971,
- -0.02280104909751903,
- -0.010737111654463158,
- -0.023882530312442665,
- 0.004905462672456875,
- 0.00075209010919026,
- -0.024500914859212054,
- 0.00012343310011733236,
- 0.017633611347778456,
- 0.023923325909469433,
- -0.018645344540876234,
- 0.005277811601810739,
- -0.009855274714457844,
- 0.007914569040499487,
- -0.019304583858471332,
- -0.0005811599955468477,
- 0.007623403906902181,
- -0.017091036710617992,
- -0.008518459689484597,
- 0.0013922353480573715,
- -0.011756137000993299,
- 0.008291759189449511,
- -0.013497022547325079,
- 0.005991608353077426,
- -0.009325678474084179,
- -0.016809920256381464,
- -0.017594407659192728,
- -0.03287919566104813,
- -0.01228816787738256,
- -0.00870882658768158,
- -0.036714525841408316,
- 0.012981449114547638,
- 0.02840884952911296,
- -0.012616482722592536,
- -0.002279548231524427,
- 0.023114612439785107,
- 0.026424972870186483,
- -0.030049015469396378,
- 0.009311384481937881,
- 0.016733643303268906,
- -0.007873673097017036,
- -0.002877579080630571,
- 0.027914446691046705,
- 0.0029507383043251914,
- -0.019559820445164654,
- 0.025383751899413443,
- -0.00034674185383667434,
- 0.010239162490165103,
- -0.0016001853420769877,
- 0.01614303798145745,
- 0.01577545568149537,
- 0.015714382411644745,
- 0.008761174602915344,
- -0.013649928238251268,
- 0.0068917739673989425,
- -0.022187686180788476,
- 0.01973441993892279,
- 0.0019879644636025436,
- 0.011009557988734181,
- 0.03288282796707179,
- -0.012562620640996298,
- -0.025423633262118394,
- -0.005637687543304472,
- -0.014833299997513729,
- 0.0024731522961210953,
- -0.0010095782030263402,
- 0.018103345950304158,
- 0.02663366844269967,
- 0.010695697631391417,
- 0.03105841128947713,
- -0.006840667932786598,
- 0.0037729961806200942,
- 0.006040903192333032,
- -0.02359255783423953,
- 0.03722140747675881,
- 0.013015647431272856,
- 0.015443116588012482,
- -0.004049208088396334,
- -0.006340081548047081,
- 0.0016357617392529404,
- 0.016669502416500982,
- 0.0003262643150171215,
- 0.02831722421055538,
- 0.015758213452194043,
- 0.0021904385141604116,
- -0.00916266715190667,
- 0.0047985432615569295,
- -0.005399860495909248,
- 0.0040924003461566515,
- 0.026789563503795355,
- 0.014726621146321683,
- 0.012682366090675398,
- 0.004940893437097496,
- 0.0448942004438042,
- 0.0006184458069888334,
- -0.015845406652774032,
- 0.008716128273819948,
- 0.001312190890282829,
- 0.01494969167615913,
- -0.02612449460649421,
- 0.009870452202166206,
- -0.032133690363723176,
- -0.011799189045501394,
- 0.014727918261594447,
- -0.005570390455314822,
- 0.012555002558503524,
- -0.010820878330820053,
- 0.013427673818188376,
- -0.01447323863096344,
- -0.034081236099599406,
- 0.02083799231821074,
- 0.015612904166299195,
- 0.0062804053352556784,
- 0.019764291977684523,
- -0.0001538566756582576,
- -0.012625076509389786,
- -0.012462224781742222,
- -0.028917306700808585,
- 0.009785151048203338,
- 0.01148322191236023,
- 0.017416359963750055,
- 0.01984237241068698,
- 0.04262186680767208,
- 0.02023818373093776,
- -0.008000811006941897,
- -0.00238874268391004,
- 0.0016024256056552983,
- -0.004142487686568597,
- 0.009456771535636004,
- 0.015399224939029224,
- 0.005703937608435265
+ 0.015335893412261787,
+ 0.019831397350219125,
+ -0.020689466568013702,
+ 0.015435738201011403,
+ -0.029309854520789096,
+ 0.001686771135886331,
+ -0.009061951453460775,
+ 0.017965463348520977,
+ -0.01116463504524719,
+ -0.02358678648754828,
+ -0.008453630992034087,
+ -0.0013484009848432387,
+ -0.004486328029351374,
+ -0.030134425277439904,
+ 0.013845928682104743,
+ 0.0037734064556851406,
+ -0.0034911810657077486,
+ -0.017143571390884297,
+ -0.0160012916239723,
+ 0.017466302024966504,
+ 0.004847975953561153,
+ -0.011989139860143723,
+ -0.030107640332910232,
+ -0.007934290311649194,
+ -0.006669897776088721,
+ -0.009386840890569972,
+ -0.013233743591168474,
+ -0.025991127042483353,
+ 0.006375064425027523,
+ -0.0071520669761420365,
+ -0.0006150573929569955,
+ 0.0158274431299741,
+ 0.005079664165358413,
+ -0.006823312453897693,
+ 0.009020653324495854,
+ 0.0032865690880503444,
+ 0.027895300019423832,
+ 0.01878628758446477,
+ 0.004528605497318604,
+ -0.005771084427924661,
+ 0.0022644415676482,
+ -0.020593372075096025,
+ 0.019955105623094695,
+ -0.0015368678676446357,
+ 0.010515982284226087,
+ -0.014932491554899617,
+ -0.0038689227974084916,
+ -0.01625354129726306,
+ 0.003248955549460316,
+ -0.011691583556349312,
+ -0.02092413840684404,
+ -0.026657353626542733,
+ 0.0017785880135172188,
+ -0.018347371917627454,
+ 0.0069951411155479305,
+ 0.0001685296476830805,
+ 0.01190634535010246,
+ 0.006642342151338026,
+ -0.020132468838705737,
+ 0.004087188214181843,
+ -0.01765370850954273,
+ -0.0038256317515334304,
+ 0.00655293873409814,
+ 0.0036972779500178687,
+ 0.0032971798652427706,
+ 0.000026399604419164216,
+ 0.01877197402368051,
+ -0.010529883551155916,
+ 0.009025181168390975,
+ 0.03031336842418389,
+ -0.00473132401057213,
+ -0.007785326656716772,
+ -0.024033788693305965,
+ 0.011282834980223191,
+ 0.012474738954915887,
+ 0.00591482039160062,
+ 0.01818064545897285,
+ -0.014454422787418411,
+ 0.005560172146212504,
+ -0.00022774353327546836,
+ 0.013587732397295984,
+ -0.002786630452729916,
+ 0.004594127000918377,
+ 0.010594741840370153,
+ -0.01402530370721945,
+ -0.002927039096417097,
+ 0.0017026251384077134,
+ -0.008235354013588735,
+ 0.0009712485029294681,
+ 0.030143381703275265,
+ 0.0003156511255766694,
+ 0.02134981224601213,
+ 0.0012876315570140252,
+ 0.022679174548044474,
+ 0.0005260771488358384,
+ 0.0027198782265173563,
+ 0.006880628627366828,
+ 0.02985555227875509,
+ 0.008340713771782858,
+ -0.025555479188149915,
+ 0.009942404961568604,
+ -0.015401106312717697,
+ 0.0005543539068820487,
+ -0.01825465671447129,
+ -0.006467607129783769,
+ -0.016677090336012463,
+ -0.005584941969192943,
+ -0.008183032813576964,
+ -0.015628837074044038,
+ 0.010355132956507512,
+ -0.0001413286341740566,
+ 0.013897464783803484,
+ 0.017934095497606217,
+ -0.007159189011485633,
+ -0.0018843794232999262,
+ -0.01200286015971886,
+ 0.014059268615517779,
+ -0.024248978354007685,
+ -0.02705399182541282,
+ 0.01219108745687901,
+ -0.008344582868340144,
+ 0.03457954609658779,
+ 0.013769796401312541,
+ -0.0030489273968929195,
+ 0.01657664072438172,
+ -0.005791823072132839,
+ -0.015457004400971645,
+ 0.008815105508654105,
+ -0.003100832291628117,
+ -0.013886025407480058,
+ -0.012897576623653605,
+ 0.004467132249727546,
+ -0.025766976944524116
]
}
],
@@ -3598,139 +3592,139 @@
132
],
"y": [
- -0.03949338287409329,
-0.002592261998183278,
-0.002592261998183278,
-0.03949338287409329,
- -0.03949338287409329,
+ 0.01770335448356722,
+ -0.0763945037500033,
+ 0.07120997975363674,
+ 0.10811110062954676,
0.03430885887772673,
+ -0.002592261998183278,
-0.03949338287409329,
- -0.03949338287409329,
- -0.0763945037500033,
+ -0.002592261998183278,
+ 0.03430885887772673,
-0.0763945037500033,
-0.03949338287409329,
- 0.07120997975363674,
- 0.01585829843977173,
- -0.03949338287409329,
-0.002592261998183278,
-0.002592261998183278,
- -0.03949338287409329,
0.0029429061332032365,
- -0.002592261998183278,
- 0.07120997975363674,
+ -0.0763945037500033,
0.07120997975363674,
- -0.0018542395806650938,
- -0.03949338287409329,
- -0.03949338287409329,
-0.002592261998183278,
- -0.047980640675552584,
- -0.002592261998183278,
- -0.05019470792810719,
-0.03949338287409329,
-0.002592261998183278,
-0.03949338287409329,
-0.002592261998183278,
- -0.03949338287409329,
0.03430885887772673,
+ -0.03949338287409329,
+ -0.05019470792810719,
+ -0.03949338287409329,
0.03430885887772673,
0.03430885887772673,
-0.03949338287409329,
0.03430885887772673,
- 0.07120997975363674,
- -0.0763945037500033,
- 0.03430885887772673,
+ -0.03949338287409329,
-0.002592261998183278,
+ 0.05275941931568174,
+ 0.05017634085436802,
+ 0.020655444153640023,
+ 0.03430885887772673,
-0.03949338287409329,
+ -0.002592261998183278,
-0.03949338287409329,
-0.03949338287409329,
- 0.03430885887772673,
- 0.07120997975363674,
-0.002592261998183278,
- 0.10811110062954676,
- 0.03430885887772673,
- -0.03949338287409329,
+ 0.13025177315509276,
-0.002592261998183278,
- 0.08006624876385515,
- 0.07120997975363674,
-0.002592261998183278,
0.10811110062954676,
- 0.03430885887772673,
- -0.03949338287409329,
- 0.03430885887772673,
- 0.03430885887772673,
- 0.03430885887772673,
-0.03949338287409329,
+ 0.10811110062954676,
+ -0.002592261998183278,
-0.03949338287409329,
-0.03949338287409329,
+ 0.07120997975363674,
-0.03949338287409329,
-0.03949338287409329,
-0.002592261998183278,
- 0.056080520194513636,
- -0.03949338287409329,
- -0.012555564634678981,
+ 0.07120997975363674,
-0.002592261998183278,
0.03430885887772673,
+ 0.03430885887772673,
-0.03949338287409329,
- 0.08080427118137334,
- -0.002592261998183278,
- -0.0763945037500033,
- 0.07120997975363674,
- -0.002592261998183278,
- 0.07120997975363674,
+ -0.03949338287409329,
+ -0.014400620678474476,
+ -0.021411833644897377,
-0.002592261998183278,
0.03430885887772673,
-0.03949338287409329,
- 0.03430885887772673,
- 0.10811110062954676,
- 0.03430885887772673,
- -0.0011162171631468765,
-0.03949338287409329,
- -0.06938329078358041,
+ 0.14501222150545676,
-0.002592261998183278,
-0.03949338287409329,
+ -0.03949338287409329,
+ -0.03949338287409329,
+ -0.05056371913686628,
0.03430885887772673,
+ 0.05091436327188625,
+ 0.07120997975363674,
-0.03949338287409329,
+ 0.03430885887772673,
+ -0.03764832683029779,
+ -0.002592261998183278,
+ 0.03430885887772673,
+ -0.002592261998183278,
+ -0.002592261998183278,
+ 0.07120997975363674,
+ -0.0708593356186168,
+ 0.0003598276718895252,
-0.03949338287409329,
- 0.10811110062954676,
- 0.08486339447772344,
-0.002592261998183278,
-0.002592261998183278,
-0.03949338287409329,
0.03430885887772673,
-0.002592261998183278,
-0.03949338287409329,
- 0.03430885887772673,
+ 0.07120997975363674,
+ 0.08486339447772344,
+ 0.07120997975363674,
0.03430885887772673,
-0.03949338287409329,
+ 0.03430885887772673,
-0.002592261998183278,
-0.0763945037500033,
0.03430885887772673,
- 0.07120997975363674,
- -0.03949338287409329,
- 0.03430885887772673,
-0.002592261998183278,
- -0.03949338287409329,
-0.002592261998183278,
+ 0.028404679537581124,
+ 0.03430885887772673,
+ -0.03949338287409329,
+ -0.03949338287409329,
0.08670845052151895,
- 0.07120997975363674,
+ 0.03430885887772673,
-0.03395821474270679,
-0.002592261998183278,
- -0.002592261998183278,
- -0.002592261998183278,
-0.03949338287409329,
- -0.02583996815000658,
- 0.07120997975363674,
- -0.002592261998183278,
-0.03949338287409329,
+ -0.03949338287409329,
+ 0.03430885887772673,
+ -0.03949338287409329,
+ -0.03949338287409329,
+ 0.07120997975363674,
+ 0.03430885887772673,
-0.002592261998183278,
- 0.021024455362399115,
+ 0.09187460744414634,
0.07120997975363674,
+ 0.003311917341962329,
+ -0.021411833644897377,
+ -0.03949338287409329,
0.07120997975363674,
0.03430885887772673,
- 0.10811110062954676,
- -0.03949338287409329,
- 0.023238522614953735,
- -0.011079519799642579
+ -0.024732934523729287,
+ 0.03430885887772673,
+ -0.011079519799642579,
+ -0.03949338287409329
]
},
{
@@ -3877,139 +3871,139 @@
132
],
"y": [
- -0.033896113833800104,
- 0.03603982124583388,
- -0.02514959495753533,
- -0.009153673337028049,
- -0.023259537860115424,
- 0.0016458158473403005,
- 0.007286405897723118,
- -0.02075686822686397,
- -0.043549940538604455,
- -0.030065670331970884,
- 7.311121081554719e-06,
- 0.045670417658070464,
- 0.00421840989426087,
- 0.001979143923846953,
- -0.019941854302657065,
- -0.015785560034542314,
- -0.03628544223067562,
- -0.000431905226498143,
- 0.008043512609159614,
- 0.009438827081986647,
- 0.0033913922375727427,
- 0.00936105052206243,
- -0.04037565420945434,
- 0.017707113648686896,
- -0.0008696581328993848,
- 0.025507659225543027,
- -0.035756904959039895,
- -0.03267194180654841,
- -0.035269639644668144,
- -0.043825087161866595,
- -0.008444512462475219,
- 0.00026512218814421853,
- -0.05085905862010786,
- 0.023168611445820483,
- 0.04274087981128667,
- 0.006433139152016789,
- -0.029794262363819303,
- -0.00014511911203389622,
- 0.04921815584950953,
- -0.04362948712793756,
- -0.005124988093455922,
- 0.05147428487867014,
- -0.0021571756146030055,
- -0.012732564467696118,
- 0.038035850525582234,
- 0.0343401025197168,
- 0.005493144500231844,
- 0.017917736211547595,
- 0.013898083972782695,
- 0.03736836078924559,
- -0.017548021575837063,
- 0.025827083473416213,
- 0.0003470634633979322,
- 0.02119932556311277,
- 0.005545713799734401,
- 0.023127315143638298,
- -0.015306961231096025,
- -0.011776876399901732,
- 0.02098233604817692,
- 0.02301918688005211,
- 0.01195168623367922,
- 0.009411461487704654,
- -0.04105435737677416,
- -0.035717840307218605,
- -0.0009950380414296307,
- 0.0036143360552140687,
- -0.02279779986281323,
- 0.01355012448789025,
- 0.004137638543490342,
- 0.008212121884840636,
- -0.0037142954743162285,
- 0.0488844529369689,
- 0.02924273012093785,
- 0.006814163451597154,
- 0.02491074264006944,
- -0.014192098607728538,
- 0.05551452570218897,
- 0.021292042993315847,
- 0.01257631309674684,
- 0.008806690221627451,
- -0.0006362069927716291,
- 0.006392942331479835,
- -0.009834980456394235,
- -0.020196010625864586,
- 0.048446735251986645,
- 0.015426586597509753,
- 0.014090629649528215,
- -0.03188141815776567,
- 0.0010602561704424955,
- -0.022297811578360135,
- 0.019886962527918113,
- -0.0024503722447135885,
- -0.006988019211208833,
- -0.0075094456665051905,
- 0.00537611358032769,
- 0.038807979765950414,
- 0.00822974836581868,
- -0.009616152664245956,
- 0.012985506139464696,
- -0.015840792852428412,
- -0.012405557187776682,
- -0.026185280019799972,
- 0.015928109105293353,
- -0.01019686839658426,
- -0.0009578712331579572,
- -0.010973132938666662,
- 0.009889292081968478,
- 0.06280452698226915,
- -0.03645138545362458,
- 0.0154664197978276,
- 0.0013428775695772284,
- -0.03637064966358713,
- 0.020361105691256345,
- 0.013824893401455799,
- 0.019294232468343733,
- 0.001798967852404379,
- 0.00404114723675774,
- -0.013138751843066891,
- -0.017409128677473316,
- -0.005903945567537133,
- -0.011888412385449814,
- 0.006633162138983448,
- 0.023947943755935668,
- 0.011071463951143837,
- 0.06866356998876627,
- 0.023181440419472853,
- 0.0077770324968229505,
- -0.011758457364823017,
- -0.02497774386911474,
- -0.0003005575468449512,
- 0.006940629091609234,
- 0.009518457582575158,
- 0.010625739426498998
+ 0.03801815115296861,
+ 0.03600433856025695,
+ -0.005872054203347856,
+ 0.021339979233694583,
+ -0.05013020933635276,
+ 0.021103966187970995,
+ 0.005792824629814879,
+ 0.026193855290159787,
+ -0.022587022691146815,
+ -0.0010815380121506797,
+ 0.01608793607263828,
+ 0.01248313194439169,
+ -0.0224189561986319,
+ -0.04701354762403807,
+ 0.025713164081000814,
+ -0.018310642590766937,
+ -0.005986054582544699,
+ -0.02364900535375452,
+ 0.009907036694734271,
+ -0.000002404753059959564,
+ 0.017687309793477286,
+ 0.0023696060773003933,
+ -0.03746171227414108,
+ -0.033349180448551084,
+ 0.007053876868125156,
+ -0.02820745943938836,
+ -0.030222789694490354,
+ -0.006071293198440249,
+ 0.01062949653330887,
+ -0.03503182585650625,
+ 0.0030053110241476232,
+ 0.033141470802637005,
+ 0.0037258039133428148,
+ 0.004628378087605184,
+ 0.0022951113220639097,
+ 0.014563485262689625,
+ 0.025012862840052986,
+ -0.0023358269407980706,
+ -0.007087563979847968,
+ 0.010724496636703454,
+ 0.016187322405989208,
+ -0.04023158643704103,
+ 0.009113556784833656,
+ -0.0032022106628799476,
+ 0.022964274184683983,
+ 0.012752745314570452,
+ 0.012696059442091116,
+ -0.03768422189149042,
+ 0.032457011495066905,
+ -0.011954506767732818,
+ -0.04124662078961619,
+ -0.011678196886871616,
+ -0.02173977357692884,
+ -0.03424992605422941,
+ -0.021496789629415827,
+ -0.025420892467095186,
+ 0.021758213602930666,
+ 0.008816463115973653,
+ 0.0015276912335713561,
+ -0.016869485965228165,
+ -0.038724458728373726,
+ -0.002842482193406489,
+ -0.024044611982656255,
+ -0.028169635441315906,
+ -0.001962254525920726,
+ 0.022262923016894892,
+ -0.006072092930318232,
+ -0.03784044954011754,
+ 0.02820569241279132,
+ 0.00563274933449468,
+ 0.0032733856212008557,
+ 0.0048896372256658676,
+ -0.04192883734954138,
+ 0.012390566525925578,
+ 0.005398119619099769,
+ 0.0033530001917489788,
+ 0.037322726395419926,
+ -0.005674946121132441,
+ -0.011870036824520752,
+ -0.012936726530537759,
+ -0.003001151640504267,
+ 0.005897798534626533,
+ 0.01156550845508734,
+ 0.0157782038280125,
+ 0.023573253528940488,
+ -0.021873063986186398,
+ 0.009146401006039789,
+ 0.0026837259590410424,
+ 0.0126092595351763,
+ 0.006457660483236803,
+ 0.014897228577978438,
+ 0.038231670790419496,
+ -0.0042202789536772145,
+ -0.0033324742680576162,
+ -0.007122362560399039,
+ 0.0006276621718468464,
+ 0.025924116108620117,
+ 0.02438308430344615,
+ -0.01636902440042491,
+ -0.029441614977468097,
+ 0.034278259085474876,
+ -0.00940529967700519,
+ 0.027097509022699355,
+ 0.00451092959112535,
+ 0.025006211365158015,
+ 0.0015600565363524804,
+ 0.011556548106624813,
+ -0.03306985914736149,
+ -0.02499472658610263,
+ 0.007016021831458234,
+ 0.016101158246854777,
+ -0.0028063213932619465,
+ 0.00877540472169476,
+ 0.004943881299674521,
+ -0.02624868211126258,
+ 0.011276154360700235,
+ 0.038921338557117616,
+ -0.03759438804635347,
+ -0.004598964070324002,
+ 0.012902550593130674,
+ 0.004476567707528896,
+ 0.0217673068762963,
+ 0.02358583143467586,
+ -0.013750537897239642,
+ 0.01170608470046173,
+ 0.001200918799471937,
+ -0.03563060130992543,
+ 0.023477943947110724,
+ 0.017095674661237334,
+ -0.029766700309187353,
+ 0.006271215506707029,
+ 0.013286911904454558,
+ -0.04472909282444158
]
}
],
@@ -4915,13 +4909,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "19a986cc55354dc580e74ae26a478eb3",
+ "model_id": "efa4d6a9d4154b209f0c9545323e8cf0",
"version_major": 2,
"version_minor": 0
},
@@ -4937,12 +4931,12 @@
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.5s finished\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.4s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Batch computation too fast (0.06139373779296875s.) Setting batch_size=2.\n",
+ "[Parallel(n_jobs=-1)]: Batch computation too fast (0.051928043365478516s.) Setting batch_size=2.\n",
"[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.1s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 2.7s finished\n"
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 2.4s finished\n"
]
},
{
@@ -4973,28 +4967,28 @@
" \n",
" \n",
" | 0 | \n",
- " -0.010049 | \n",
- " -0.033896 | \n",
+ " 0.015336 | \n",
+ " 0.038018 | \n",
"
\n",
" \n",
" | 1 | \n",
- " 0.023339 | \n",
- " 0.036040 | \n",
+ " 0.019831 | \n",
+ " 0.036004 | \n",
"
\n",
" \n",
" | 2 | \n",
- " 0.000684 | \n",
- " -0.025150 | \n",
+ " -0.020689 | \n",
+ " -0.005872 | \n",
"
\n",
" \n",
" | 3 | \n",
- " -0.022801 | \n",
- " -0.009154 | \n",
+ " 0.015436 | \n",
+ " 0.021340 | \n",
"
\n",
" \n",
" | 4 | \n",
- " -0.010737 | \n",
- " -0.023260 | \n",
+ " -0.029310 | \n",
+ " -0.050130 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -5003,28 +4997,28 @@
"
\n",
" \n",
" | 128 | \n",
- " 0.001602 | \n",
- " -0.024978 | \n",
+ " -0.003101 | \n",
+ " 0.017096 | \n",
"
\n",
" \n",
" | 129 | \n",
- " -0.004142 | \n",
- " -0.000301 | \n",
+ " -0.013886 | \n",
+ " -0.029767 | \n",
"
\n",
" \n",
" | 130 | \n",
- " 0.009457 | \n",
- " 0.006941 | \n",
+ " -0.012898 | \n",
+ " 0.006271 | \n",
"
\n",
" \n",
" | 131 | \n",
- " 0.015399 | \n",
- " 0.009518 | \n",
+ " 0.004467 | \n",
+ " 0.013287 | \n",
"
\n",
" \n",
" | 132 | \n",
- " 0.005704 | \n",
- " 0.010626 | \n",
+ " -0.025767 | \n",
+ " -0.044729 | \n",
"
\n",
" \n",
"\n",
@@ -5033,22 +5027,22 @@
],
"text/plain": [
" s1 s4\n",
- "0 -0.010049 -0.033896\n",
- "1 0.023339 0.036040\n",
- "2 0.000684 -0.025150\n",
- "3 -0.022801 -0.009154\n",
- "4 -0.010737 -0.023260\n",
+ "0 0.015336 0.038018\n",
+ "1 0.019831 0.036004\n",
+ "2 -0.020689 -0.005872\n",
+ "3 0.015436 0.021340\n",
+ "4 -0.029310 -0.050130\n",
".. ... ...\n",
- "128 0.001602 -0.024978\n",
- "129 -0.004142 -0.000301\n",
- "130 0.009457 0.006941\n",
- "131 0.015399 0.009518\n",
- "132 0.005704 0.010626\n",
+ "128 -0.003101 0.017096\n",
+ "129 -0.013886 -0.029767\n",
+ "130 -0.012898 0.006271\n",
+ "131 0.004467 0.013287\n",
+ "132 -0.025767 -0.044729\n",
"\n",
"[133 rows x 2 columns]"
]
},
- "execution_count": 14,
+ "execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -5090,13 +5084,13 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "8ef1605cda164f92898a8c9931062bed",
+ "model_id": "420d1a493db340ee9cc62e2ef7a8a158",
"version_major": 2,
"version_minor": 0
},
@@ -5112,14 +5106,56 @@
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.8s finished\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.0s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Batch computation too fast (0.05136895179748535s.) Setting batch_size=2.\n",
+ "[Parallel(n_jobs=-1)]: Batch computation too fast (0.06503176689147949s.) Setting batch_size=2.\n",
"[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.1s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
"[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.5s finished\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.2s finished\n"
+ "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 2.3s finished\n",
+ "{0.5: s1 s4\n",
+ "0 -0.034592 -0.002592\n",
+ "1 0.052093 0.071210\n",
+ "2 -0.000193 -0.002592\n",
+ "3 0.035582 0.034309\n",
+ "4 0.002559 -0.002592\n",
+ ".. ... ...\n",
+ "128 0.061725 0.108111\n",
+ "129 -0.056607 -0.076395\n",
+ "130 -0.038720 -0.039493\n",
+ "131 0.006687 0.034309\n",
+ "132 -0.051103 -0.076395\n",
+ "\n",
+ "[133 rows x 2 columns]}\n",
+ "{0.5: s1 s4\n",
+ "0 0.011052 0.029581\n",
+ "1 0.015192 0.026147\n",
+ "2 -0.020610 -0.011374\n",
+ "3 0.015005 0.010960\n",
+ "4 -0.026301 -0.053540\n",
+ ".. ... ...\n",
+ "128 -0.006031 0.010273\n",
+ "129 -0.014371 -0.034410\n",
+ "130 -0.015924 0.000246\n",
+ "131 0.003446 0.004326\n",
+ "132 -0.026093 -0.048635\n",
+ "\n",
+ "[133 rows x 2 columns]}\n",
+ "{0.5: s1 s4\n",
+ "0 0.028702 0.071210\n",
+ "1 0.020446 -0.002592\n",
+ "2 -0.009825 -0.039493\n",
+ "3 -0.042848 -0.076395\n",
+ "4 -0.089630 -0.076395\n",
+ ".. ... ...\n",
+ "128 -0.015328 -0.039493\n",
+ "129 -0.015328 -0.002592\n",
+ "130 0.069981 0.071210\n",
+ "131 -0.026336 -0.039493\n",
+ "132 -0.037344 -0.039493\n",
+ "\n",
+ "[133 rows x 2 columns]}\n"
]
},
{
@@ -5154,7 +5190,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "pe",
+ "display_name": "pe3.13",
"language": "python",
"name": "python3"
},
@@ -5168,7 +5204,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.13.0"
}
},
"nbformat": 4,
diff --git a/docs/examples/scf_to_cps/imputing-from-scf-to-cps.md b/docs/examples/scf_to_cps/imputing-from-scf-to-cps.md
index f769dad..487a245 100644
--- a/docs/examples/scf_to_cps/imputing-from-scf-to-cps.md
+++ b/docs/examples/scf_to_cps/imputing-from-scf-to-cps.md
@@ -26,6 +26,7 @@ from microimpute.config import (
)
from microimpute.comparisons import *
from microimpute.visualizations import *
+from microimpute.utils.data import preprocess_data
logger = logging.getLogger(__name__)
```
diff --git a/docs/imputation-benchmarking/benchmarking-methods.ipynb b/docs/imputation-benchmarking/benchmarking-methods.ipynb
index 3a6b2b8..7a0c26a 100644
--- a/docs/imputation-benchmarking/benchmarking-methods.ipynb
+++ b/docs/imputation-benchmarking/benchmarking-methods.ipynb
@@ -20,22 +20,12 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
+ "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <668E1903-F0E7-30D5-BA27-15F8287F87F7> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
"Trying to import in ABI mode.\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"PWD\" redefined by R and overriding existing variable. Current: \"/\", R: \"/Users/movil1/Desktop/PYTHONJOBS/PolicyEngine/microimpute/docs/imputation-benchmarking\"\n",
+ "/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable \"PWD\" redefined by R and overriding existing variable. Current: \"/\", R: \"/Users/movil1/Desktop/PYTHONJOBS/PolicyEngine/microimpute/docs/imputation-benchmarking\"\n",
" warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpS4rNst\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpXjStes\"\n",
- " warnings.warn(\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n"
+ "/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmp5Lbp6u\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpVj8E2X\"\n",
+ " warnings.warn(\n"
]
},
{
@@ -150,20 +140,20 @@
0.009564564121272505,
0.011532085657157347,
0.013053903571819542,
- 0.014291605945810025,
- 0.015216700879810455,
- 0.015781289327923692,
- 0.016173141510520364,
- 0.01641601671465607,
+ 0.014291605945810021,
+ 0.015216700879810452,
+ 0.015781289327923685,
+ 0.01617314151052036,
+ 0.016416016714656072,
0.016480303014993068,
0.01632579975820032,
0.015891620338453375,
- 0.01523896098391175,
+ 0.015238960983911747,
0.014242468395712094,
0.012965220703710137,
- 0.011254457382706138,
- 0.008994380201344713,
- 0.005659521998029093
+ 0.011254457382706137,
+ 0.008994380201344711,
+ 0.005659521998029092
],
"yaxis": "y"
},
@@ -206,25 +196,25 @@
],
"xaxis": "x",
"y": [
- 0.005239275243674502,
- 0.00799603543382727,
- 0.0117399011273074,
- 0.013957266185122393,
- 0.017263499527228415,
- 0.018511119785826192,
- 0.02185473619066202,
- 0.02026064478603208,
- 0.020381045270475285,
- 0.02069131528786059,
- 0.02213152108169212,
- 0.021121538017694896,
- 0.020614221114154535,
- 0.019629395497276665,
- 0.01836918512076797,
- 0.015989935715202692,
- 0.014778348614227641,
- 0.012133777014980232,
- 0.0090018693320282
+ 0.004291351787466083,
+ 0.007272110073722339,
+ 0.011676671678615531,
+ 0.013027275015182326,
+ 0.016803790338788354,
+ 0.018110390197887235,
+ 0.022578661550766956,
+ 0.021111029043520753,
+ 0.02147035392015208,
+ 0.0211277021386684,
+ 0.022135149000879355,
+ 0.021675884069504915,
+ 0.020352699968171607,
+ 0.019579226557658855,
+ 0.018201264289815798,
+ 0.015392883871817184,
+ 0.014727143126270617,
+ 0.011637270360498182,
+ 0.007386201435700083
],
"yaxis": "y"
},
@@ -267,25 +257,25 @@
],
"xaxis": "x",
"y": [
- 0.004020869986270967,
- 0.006849660847954891,
- 0.009413850902507178,
- 0.011365912151055294,
- 0.01278641721054225,
- 0.014161926009552468,
- 0.015223350993261655,
- 0.01586450534456412,
- 0.016277396409780275,
- 0.016604962119646828,
- 0.016673676369518477,
- 0.01650625461447753,
- 0.01606565134149738,
- 0.015456297387661702,
- 0.014538800843641784,
- 0.013230135012355737,
- 0.01128507945452094,
- 0.008998790611303901,
- 0.005404152791496269
+ 0.0040208699862709935,
+ 0.0068496608479549055,
+ 0.009413850902584769,
+ 0.011365912150398267,
+ 0.012786417210542328,
+ 0.014161926009550265,
+ 0.015223350993261645,
+ 0.01586450534456375,
+ 0.01627739640978032,
+ 0.016604962119648296,
+ 0.01667367636951848,
+ 0.01650625461447754,
+ 0.01606565134149755,
+ 0.015456297387661706,
+ 0.014538800843642214,
+ 0.013230135012355235,
+ 0.011285079454520975,
+ 0.008998790611307789,
+ 0.0054041527914963005
],
"yaxis": "y"
}
@@ -325,8 +315,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 0.012657845989966958,
- "y1": 0.012657845989966958
+ "y0": 0.012657845989966957,
+ "y1": 0.012657845989966957
},
{
"line": {
@@ -338,8 +328,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 0.01640340159716006,
- "y1": 0.01640340159716006
+ "y0": 0.016239845180267713,
+ "y1": 0.016239845180267713
},
{
"line": {
@@ -351,8 +341,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 0.012669878442189983,
- "y1": 0.012669878442189983
+ "y0": 0.01266987844215965,
+ "y1": 0.01266987844215965
}
],
"template": {
@@ -1275,21 +1265,6 @@
"id": "82cd40c6",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 1/1 [00:00<00:00, 1.01it/s]\n",
- "Found 3 numeric columns with unique values < 10, treating as categorical: ['hhsex', 'married', 'race']. Converting to dummy variables.\n",
- "Found 3 numeric columns with unique values < 10, treating as categorical: ['hhsex', 'married', 'race']. Converting to dummy variables.\n",
- "Found 3 numeric columns with unique values < 10, treating as categorical: ['hhsex', 'married', 'race']. Converting to dummy variables.\n",
- "Found 3 numeric columns with unique values < 10, treating as categorical: ['hhsex', 'married', 'race']. Converting to dummy variables.\n",
- "Found 3 numeric columns with unique values < 10, treating as categorical: ['hhsex', 'married', 'race']. Converting to dummy variables.\n",
- "Found 3 numeric columns with unique values < 10, treating as categorical: ['hhsex', 'married', 'race']. Converting to dummy variables.\n",
- "Found 3 numeric columns with unique values < 10, treating as categorical: ['hhsex', 'married', 'race']. Converting to dummy variables.\n",
- "Found 3 numeric columns with unique values < 10, treating as categorical: ['hhsex', 'married', 'race']. Converting to dummy variables.\n"
- ]
- },
{
"data": {
"application/vnd.plotly.v1+json": {
@@ -1336,25 +1311,25 @@
],
"xaxis": "x",
"y": [
- 3034815.8529673447,
- 3033868.616352592,
- 3032921.37973784,
- 3031974.1431230884,
- 3031026.906508336,
- 3030079.6698935837,
- 3029132.4332788303,
- 3028185.1966640786,
- 3027237.960049326,
- 3026290.723434574,
- 3025343.4868198223,
- 3024396.25020507,
- 3023449.013590318,
- 3022501.776975565,
- 3021554.5403608135,
- 3020607.303746061,
- 3019660.067131309,
- 3018712.8305165567,
- 3017765.5939018037
+ 12257052.081521738,
+ 11844897.445652174,
+ 11432742.80978261,
+ 11020588.173913043,
+ 10608433.538043479,
+ 10196278.902173912,
+ 9784124.266304348,
+ 9371969.630434783,
+ 8959814.994565217,
+ 8547660.358695652,
+ 8135505.722826087,
+ 7723351.0869565215,
+ 7311196.451086956,
+ 6899041.815217392,
+ 6486887.179347826,
+ 6074732.543478261,
+ 5662577.907608695,
+ 5250423.271739131,
+ 4838268.635869565
],
"yaxis": "y"
},
@@ -1397,25 +1372,25 @@
],
"xaxis": "x",
"y": [
- 1397489.4381525521,
- 2229173.3779636365,
- 2771140.641600416,
- 3088054.4996551564,
- 3209342.9960465343,
- 3189311.1947113685,
- 3048732.3398239953,
- 2957118.9714465705,
- 2998770.990197091,
- 3162149.8323011436,
- 3555469.5569606656,
- 4026926.4816835476,
- 4424295.935403726,
- 4703973.8841415215,
- 4815927.996039472,
- 4756909.944740554,
- 4516354.0958521655,
- 4060978.5000600247,
- 3267478.1792489397
+ 10030371.987745736,
+ 10428650.95466243,
+ 10481079.0229648,
+ 10334829.373497885,
+ 10101451.430726487,
+ 9779301.8484671,
+ 9332518.999557137,
+ 8945728.183264535,
+ 8606805.346409503,
+ 8299717.323402288,
+ 7940102.339754164,
+ 7571653.4587681005,
+ 7094172.5811651815,
+ 6497751.156054732,
+ 5804031.812428825,
+ 4949092.780861613,
+ 3984160.096393986,
+ 2874503.8952907776,
+ 1650462.4604610256
],
"yaxis": "y"
},
@@ -1458,25 +1433,25 @@
],
"xaxis": "x",
"y": [
- 275471.4514564969,
- 525147.6390518384,
- 777762.9247261932,
- 947034.3454195022,
- 1576343.715606303,
- 1993360.17680749,
- 2105712.9035432744,
- 2636133.7744909795,
- 2418276.282347567,
- 2808823.1000137017,
- 2662871.4482758623,
- 2793367.038034254,
- 2447774.3527216264,
- 2443728.789188398,
- 2230923.218661795,
- 2506069.489048641,
- 2785339.1446174923,
- 2578752.672540763,
- 2266182.1076629367
+ 1286861.6641304349,
+ 2374732.8608695655,
+ 3259576.1608695653,
+ 4369523.769565217,
+ 12726122.43478261,
+ 6369122.239130435,
+ 12516260.560869563,
+ 7105905.247826086,
+ 9378343.482608695,
+ 9249420.56521739,
+ 7816548.561956522,
+ 7879498.678260869,
+ 7021816.3021739125,
+ 6887719.176086957,
+ 6295466.559782608,
+ 5571821.3999999985,
+ 4781336.205434782,
+ 2876876.4413043475,
+ 1977240.7739130442
],
"yaxis": "y"
},
@@ -1519,25 +1494,25 @@
],
"xaxis": "x",
"y": [
- 347695.6724753142,
- 541651.9665176176,
- 793908.404674841,
- 1467938.6305500015,
- 1823313.3400330453,
- 2320617.355629138,
- 2670884.145058224,
- 2963627.4553304156,
- 2565497.861722767,
- 2786684.844401457,
- 2501194.819308105,
- 2597643.1544135334,
- 2690474.726608,
- 2801237.2895192374,
- 2923139.645882121,
- 3010778.951479208,
- 2890941.220059649,
- 3421013.450508705,
- 2982167.326057204
+ 1030262.7873318979,
+ 2838847.5198398344,
+ 6576456.093627476,
+ 6799610.589435382,
+ 7025602.493850224,
+ 6784537.161760386,
+ 12318260.722365955,
+ 12529016.349454954,
+ 11600376.468864048,
+ 10733950.642877832,
+ 9811532.832926467,
+ 8858653.897544265,
+ 7916284.027393919,
+ 7017813.370916364,
+ 6002838.4184497,
+ 6954589.377981961,
+ 6743000.181936221,
+ 5720085.741960459,
+ 3874864.6706502736
],
"yaxis": "y"
}
@@ -1564,8 +1539,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 3026290.7234345744,
- "y1": 3026290.7234345744
+ "y0": 8547660.35869565,
+ "y1": 8547660.35869565
},
{
"line": {
@@ -1577,8 +1552,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 3483136.781896267,
- "y1": 3483136.781896267
+ "y0": 7616125.529046121,
+ "y1": 7616125.529046121
},
{
"line": {
@@ -1590,8 +1565,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 2041003.9249586905,
- "y1": 2041003.9249586905
+ "y0": 6302325.951830663,
+ "y1": 6302325.951830663
},
{
"line": {
@@ -1603,8 +1578,8 @@
"type": "line",
"x0": -0.5,
"x1": 18.5,
- "y0": 2321074.224222557,
- "y1": 2321074.224222557
+ "y0": 7428241.228903559,
+ "y1": 7428241.228903559
}
],
"template": {
@@ -2470,21 +2445,162 @@
"source": [
"# On the SCF Dataset\n",
"\n",
- "from typing import List, Type\n",
+ "from typing import List, Type, Optional, Union\n",
"\n",
+ "import io\n",
+ "import logging\n",
"import pandas as pd\n",
+ "from pydantic import validate_call\n",
+ "import requests\n",
+ "import zipfile\n",
"\n",
"from microimpute.comparisons import *\n",
- "from microimpute.config import RANDOM_STATE\n",
+ "from microimpute.config import RANDOM_STATE, VALIDATE_CONFIG, VALID_YEARS\n",
"from microimpute.models import *\n",
+ "from microimpute.utils.data import preprocess_data\n",
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
+ "logger = logging.getLogger(__name__)\n",
+ "\n",
"# 1. Prepare data\n",
- "X_train, X_test, PREDICTORS, IMPUTED_VARIABLES = prepare_scf_data(\n",
- " full_data=False, years=2019\n",
+ "@validate_call(config=VALIDATE_CONFIG)\n",
+ "def load_scf(\n",
+ " years: Optional[Union[int, List[int]]] = None,\n",
+ " columns: Optional[List[str]] = None,\n",
+ ") -> pd.DataFrame:\n",
+ " \"\"\"Load Survey of Consumer Finances data for specified years and columns.\n",
+ "\n",
+ " Args:\n",
+ " years: Year or list of years to load data for.\n",
+ " columns: List of column names to load.\n",
+ "\n",
+ " Returns:\n",
+ " DataFrame containing the requested data.\n",
+ "\n",
+ " Raises:\n",
+ " ValueError: If no Stata files are found in the downloaded zip\n",
+ " or invalid parameters\n",
+ " RuntimeError: If there's a network error or a problem processing\n",
+ " the downloaded data\n",
+ " \"\"\"\n",
+ " def scf_url(year: int) -> str:\n",
+ " \"\"\"Return the URL of the SCF summary microdata zip file for a year.\"\"\"\n",
+ "\n",
+ " if year not in VALID_YEARS:\n",
+ " logger.error(\n",
+ " f\"Invalid SCF year: {year}. Valid years are {VALID_YEARS}\"\n",
+ " )\n",
+ " raise\n",
+ "\n",
+ " url = f\"https://www.federalreserve.gov/econres/files/scfp{year}s.zip\"\n",
+ " return url\n",
+ "\n",
+ " logger.info(f\"Loading SCF data with years={years}\")\n",
+ "\n",
+ " try:\n",
+ " # Identify years for download\n",
+ " if years is None:\n",
+ " years = VALID_YEARS\n",
+ " logger.warning(f\"Using default years: {years}\")\n",
+ "\n",
+ " if isinstance(years, int):\n",
+ " years = [years]\n",
+ "\n",
+ " all_data: List[pd.DataFrame] = []\n",
+ "\n",
+ " for year in years:\n",
+ " logger.info(f\"Processing data for year {year}\")\n",
+ " try:\n",
+ " # Download zip file\n",
+ " logger.debug(f\"Downloading SCF data for year {year}\")\n",
+ " url = scf_url(year)\n",
+ " try:\n",
+ " response = requests.get(url, timeout=60)\n",
+ " response.raise_for_status() # Raise an error for bad responses\n",
+ " except requests.exceptions.RequestException as e:\n",
+ " logger.error(\n",
+ " f\"Network error downloading SCF data for year {year}: {str(e)}\"\n",
+ " )\n",
+ " raise\n",
+ "\n",
+ " # Process zip file\n",
+ " z = zipfile.ZipFile(io.BytesIO(response.content))\n",
+ " # Find the .dta file in the zip\n",
+ " dta_files: List[str] = [\n",
+ " f for f in z.namelist() if f.endswith(\".dta\")\n",
+ " ]\n",
+ " if not dta_files:\n",
+ " logger.error(\n",
+ " f\"No Stata files found in zip for year {year}\"\n",
+ " )\n",
+ " raise\n",
+ "\n",
+ " # Read the Stata file\n",
+ " try:\n",
+ " logger.debug(f\"Reading Stata file: {dta_files[0]}\")\n",
+ " with z.open(dta_files[0]) as f:\n",
+ " df = pd.read_stata(\n",
+ " io.BytesIO(f.read()), columns=columns\n",
+ " )\n",
+ " logger.debug(\n",
+ " f\"Read DataFrame with shape {df.shape}\"\n",
+ " )\n",
+ " except Exception as e:\n",
+ " logger.error(\n",
+ " f\"Error reading Stata file for year {year}: {str(e)}\"\n",
+ " )\n",
+ " raise\n",
+ "\n",
+ " # Add year column\n",
+ " df[\"year\"] = year\n",
+ " logger.info(\n",
+ " f\"Successfully processed data for year {year}, shape: {df.shape}\"\n",
+ " )\n",
+ " all_data.append(df)\n",
+ "\n",
+ " except Exception as e:\n",
+ " logger.error(f\"Error processing year {year}: {str(e)}\")\n",
+ " raise\n",
+ "\n",
+ " # Combine all years\n",
+ " logger.debug(f\"Combining data from {len(all_data)} years\")\n",
+ " if len(all_data) > 1:\n",
+ " result = pd.concat(all_data)\n",
+ " logger.info(\n",
+ " f\"Combined data from {len(years)} years, final shape: {result.shape}\"\n",
+ " )\n",
+ " return result\n",
+ " else:\n",
+ " logger.info(\n",
+ " f\"Returning data for single year, shape: {all_data[0].shape}\"\n",
+ " )\n",
+ " return all_data[0]\n",
+ "\n",
+ " except Exception as e:\n",
+ " logger.error(f\"Error in _load: {str(e)}\")\n",
+ " raise\n",
+ "\n",
+ "scf_data = load_scf(2022)\n",
+ "PREDICTORS: List[str] = [\n",
+ " \"hhsex\", # sex of head of household\n",
+ " \"age\", # age of respondent\n",
+ " \"married\", # marital status of respondent\n",
+ " # \"kids\", # number of children in household\n",
+ " \"race\", # race of respondent\n",
+ " \"income\", # total annual income of household\n",
+ " \"wageinc\", # income from wages and salaries\n",
+ " \"bussefarminc\", # income from business, self-employment or farm\n",
+ " \"intdivinc\", # income from interest and dividends\n",
+ " \"ssretinc\", # income from social security and retirement accounts\n",
+ " \"lf\", # labor force status\n",
+ " ]\n",
+ "IMPUTED_VARIABLES: List[str] = [\"networth\"]\n",
+ "\n",
+ "X_train, X_test = preprocess_data(\n",
+ " data=scf_data, full_data=False, normalize=False,\n",
")\n",
"\n",
"# Shrink down the data by sampling\n",
@@ -2524,11 +2640,9 @@
"source": [
"## Data preparation\n",
"\n",
- "The data preparation phase establishes the foundation for meaningful benchmarking comparisons. The `prepare_scf_data()` function specifically handles Survey of Consumer Finances data, though the framework accommodates any properly formatted dataset. This function downloads data from user-specified survey years, carefully selecting relevant predictor and target variables that capture the essential relationships for imputation. \n",
- "\n",
- "The function applies normalization techniques to the features, ensuring that variables with different scales don't unduly influence the imputation models. This preprocessing step is crucial, particularly if introducing additional for methods like nearest neighbor matching that rely on distance calculations. Finally, the function splits the data into training and testing sets, maintaining the statistical properties of both sets while creating an appropriate evaluation framework.\n",
+ "The data preparation phase establishes the foundation for meaningful benchmarking comparisons. The `load_scf()` function downloads data from user-specified survey years, carefully selecting relevant predictor and target variables that capture the essential relationships for imputation.\n",
"\n",
- "While the package provides this specialized function for SCF data, researchers can easily substitute their own data preparation pipeline as long as it produces properly formatted training and testing datasets that conform to the expected structure. For this, the `preprocess_data()` function enables basic normalization and defaults to train-test splitting on any dataset. If you would like to normalizing the data set without splitting it (for example in the event of performing cross-validation), set the `full_data` parameter to `False`. \n",
+ "The `preprocess_data()` applies normalization techniques to the features when normalize=True, ensuring that variables with different scales don't unduly influence the imputation models. This preprocessing step is crucial, particularly if introducing additional for methods like nearest neighbor matching that rely on distance calculations. Finally, the function splits the data into training and testing sets, maintaining the statistical properties of both sets while creating an appropriate evaluation framework. If you would like to normalizing the data set without splitting it (for example in the event of performing cross-validation), set the full_data parameter to False.\n",
"\n",
"```python\n",
"# Normalizing\n",
@@ -2607,7 +2721,7 @@
"\n",
"For particularly important decisions, enhance the reliability of your performance estimates through cross-validation techniques. Cross-validation provides a more stable estimate of model performance by averaging results across multiple train-test splits, reducing the impact of any particular data division. This approach is especially valuable when working with smaller datasets where a single train-test split might not be representative.\n",
"\n",
- "The package also supports detailed assessment of model behavior through train-test performance comparisons via the `plot_train_test_performance()` function. This visualization tool helps identify potential overfitting or underfitting issues by contrasting a model's performance on training data with its performance on held-out test data. Significant disparities between training and testing performance can reveal important limitations in a model's generalization capabilities.\n",
+ "The package also supports detailed assessment of model behavior through train-test performance comparisons via the `model_performance_results()` function. This visualization tool helps identify potential overfitting or underfitting issues by contrasting a model's performance on training data with its performance on held-out test data. Significant disparities between training and testing performance can reveal important limitations in a model's generalization capabilities.\n",
"\n",
"For specialized applications with particular interest in certain parts of the distribution, the framework accommodates custom quantile sets for targeted evaluation. Rather than using the default (random) quantiles, researchers can specify exactly which quantiles to evaluate, allowing focused assessment of performance in regions of particular interest. This flexibility enables tailored evaluations that align precisely with application-specific requirements and priorities.\n"
]
@@ -2615,7 +2729,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "pe",
+ "display_name": "pe3.13",
"language": "python",
"name": "python3"
},
@@ -2629,7 +2743,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.13.0"
}
},
"nbformat": 4,
diff --git a/docs/models/imputer/implement-new-model.md b/docs/models/imputer/implement-new-model.md
index a04f328..d862de4 100644
--- a/docs/models/imputer/implement-new-model.md
+++ b/docs/models/imputer/implement-new-model.md
@@ -136,7 +136,7 @@ You can test the functionality of your newly implemented `NewModel` imputer mode
```python
from sklearn.datasets import load_diabetes
-from microimpute.comparisons.data import preprocess_data
+from microimpute.utils.data import preprocess_data
# Load the Diabetes dataset
diabetes = load_diabetes()
diff --git a/docs/models/imputer/index.md b/docs/models/imputer/index.md
index 6ad46d6..f24e0db 100644
--- a/docs/models/imputer/index.md
+++ b/docs/models/imputer/index.md
@@ -8,4 +8,4 @@ The Imputer architecture provides numerous benefits to the overall system design
The design carefully enforces proper usage by ensuring no model can call `predict()` without first fitting the model to the data. This logical constraint helps prevent common errors and makes the API more intuitive to use. Additionally, the base implementation handles validation of parameters and input data, reducing code duplication across different model implementations and ensuring that all models perform appropriate validation checks.
-When using the different imputers in isolation, and not as part of wider pipeline functions like `autoimpute` preprocessing and postprocessing is supported by `preprocess_data` and `postprocess_imputations` to ensure imputation takes place on data in the right format and can handle imputation of numerical, boolean and categorical variables. For an example of how to integrate them see [matching-imputation.ipynb](../matching/matching-imputation.ipynb).
+When using the different imputers in isolation, and not as part of wider pipeline functions like `autoimpute` preprocessing is supported by `preprocess_data` which can help normalize the data and split it into train and test splits. For an example of how to integrate them see [matching-imputation.ipynb](../matching/matching-imputation.ipynb).
diff --git a/docs/models/matching/matching-imputation.ipynb b/docs/models/matching/matching-imputation.ipynb
index eadc43f..90a8778 100644
--- a/docs/models/matching/matching-imputation.ipynb
+++ b/docs/models/matching/matching-imputation.ipynb
@@ -20,14 +20,12 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
+ "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <668E1903-F0E7-30D5-BA27-15F8287F87F7> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
"Trying to import in ABI mode.\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"PWD\" redefined by R and overriding existing variable. Current: \"/\", R: \"/Users/movil1/Desktop/PYTHONJOBS/PolicyEngine/microimpute/docs/models/matching\"\n",
+ "/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable \"PWD\" redefined by R and overriding existing variable. Current: \"/\", R: \"/Users/movil1/Desktop/PYTHONJOBS/PolicyEngine/microimpute/docs/models/matching\"\n",
" warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpIjUHLp\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmptntMbp\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n"
+ "/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpQ8RhCP\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmp1cBQ0G\"\n",
+ " warnings.warn(\n"
]
}
],
@@ -54,7 +52,7 @@
"from microimpute.models import Matching\n",
"from microimpute.config import QUANTILES, RANDOM_STATE\n",
"from microimpute.visualizations.plotting import model_performance_results\n",
- "from microimpute.comparisons.data import preprocess_data"
+ "from microimpute.utils.data import preprocess_data"
]
},
{
@@ -108,7 +106,7 @@
" -0.002592 | \n",
" 0.019907 | \n",
" -0.017646 | \n",
- " True | \n",
+ " False | \n",
" 1 | \n",
" \n",
" \n",
@@ -136,7 +134,7 @@
" | -0.002592 | \n",
" 0.002861 | \n",
" -0.025930 | \n",
- " True | \n",
+ " False | \n",
" 3 | \n",
"
\n",
" \n",
@@ -150,7 +148,7 @@
" | 0.034309 | \n",
" 0.022688 | \n",
" -0.009362 | \n",
- " False | \n",
+ " True | \n",
" 4 | \n",
"
\n",
" \n",
@@ -192,7 +190,7 @@
" | -0.002592 | \n",
" 0.031193 | \n",
" 0.007207 | \n",
- " False | \n",
+ " True | \n",
" 438 | \n",
"
\n",
" \n",
@@ -234,7 +232,7 @@
" | 0.026560 | \n",
" 0.044529 | \n",
" -0.025930 | \n",
- " True | \n",
+ " False | \n",
" 441 | \n",
"
\n",
" \n",
@@ -248,7 +246,7 @@
" | -0.039493 | \n",
" -0.004222 | \n",
" 0.003064 | \n",
- " True | \n",
+ " False | \n",
" 442 | \n",
"
\n",
" \n",
@@ -258,17 +256,17 @@
],
"text/plain": [
" age sex bmi bp s1 ... s4 s5 s6 bool wgt\n",
- "0 0.038076 0.050680 0.061696 0.021872 -0.044223 ... -0.002592 0.019907 -0.017646 True 1\n",
+ "0 0.038076 0.050680 0.061696 0.021872 -0.044223 ... -0.002592 0.019907 -0.017646 False 1\n",
"1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 ... -0.039493 -0.068332 -0.092204 False 2\n",
- "2 0.085299 0.050680 0.044451 -0.005670 -0.045599 ... -0.002592 0.002861 -0.025930 True 3\n",
- "3 -0.089063 -0.044642 -0.011595 -0.036656 0.012191 ... 0.034309 0.022688 -0.009362 False 4\n",
+ "2 0.085299 0.050680 0.044451 -0.005670 -0.045599 ... -0.002592 0.002861 -0.025930 False 3\n",
+ "3 -0.089063 -0.044642 -0.011595 -0.036656 0.012191 ... 0.034309 0.022688 -0.009362 True 4\n",
"4 0.005383 -0.044642 -0.036385 0.021872 0.003935 ... -0.002592 -0.031988 -0.046641 False 5\n",
".. ... ... ... ... ... ... ... ... ... ... ...\n",
- "437 0.041708 0.050680 0.019662 0.059744 -0.005697 ... -0.002592 0.031193 0.007207 False 438\n",
+ "437 0.041708 0.050680 0.019662 0.059744 -0.005697 ... -0.002592 0.031193 0.007207 True 438\n",
"438 -0.005515 0.050680 -0.015906 -0.067642 0.049341 ... 0.034309 -0.018114 0.044485 False 439\n",
"439 0.041708 0.050680 -0.015906 0.017293 -0.037344 ... -0.011080 -0.046883 0.015491 False 440\n",
- "440 -0.045472 -0.044642 0.039062 0.001215 0.016318 ... 0.026560 0.044529 -0.025930 True 441\n",
- "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 ... -0.039493 -0.004222 0.003064 True 442\n",
+ "440 -0.045472 -0.044642 0.039062 0.001215 0.016318 ... 0.026560 0.044529 -0.025930 False 441\n",
+ "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 ... -0.039493 -0.004222 0.003064 False 442\n",
"\n",
"[442 rows x 12 columns]"
]
@@ -466,6 +464,7 @@
"# Split data into training and testing sets, preprocessing data types all in one (this function also supports normalization)\n",
"X_train, X_test = preprocess_data(\n",
" diabetes_df,\n",
+ " full_data=False,\n",
" test_size=0.2,\n",
" normalize=False,\n",
")\n",
@@ -625,7 +624,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Modeling these quantiles: [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]\n"
+ "Modeling these quantiles: [np.float64(0.05), np.float64(0.1), np.float64(0.15), np.float64(0.2), np.float64(0.25), np.float64(0.3), np.float64(0.35), np.float64(0.4), np.float64(0.45), np.float64(0.5), np.float64(0.55), np.float64(0.6), np.float64(0.65), np.float64(0.7), np.float64(0.75), np.float64(0.8), np.float64(0.85), np.float64(0.9), np.float64(0.95)]\n"
]
}
],
@@ -639,15 +638,7 @@
"cell_type": "code",
"execution_count": 7,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# Initialize the Matching imputer\n",
"matching_imputer = Matching()\n",
@@ -667,13 +658,6 @@
"execution_count": 8,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n"
- ]
- },
{
"data": {
"text/html": [
@@ -723,13 +707,13 @@
" 321 | \n",
" -0.013953 | \n",
" -0.002592 | \n",
- " False | \n",
+ " True | \n",
" \n",
" \n",
" | 73 | \n",
" -0.031840 | \n",
" -0.039493 | \n",
- " True | \n",
+ " False | \n",
"
\n",
" \n",
"\n",
@@ -740,8 +724,8 @@
"287 0.024574 -0.039493 False\n",
"211 0.030078 -0.039493 True\n",
"72 0.038334 -0.039493 False\n",
- "321 -0.013953 -0.002592 False\n",
- "73 -0.031840 -0.039493 True"
+ "321 -0.013953 -0.002592 True\n",
+ "73 -0.031840 -0.039493 False"
]
},
"execution_count": 8,
@@ -769,7 +753,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -793,13 +777,13 @@
false,
0.10300345740307394,
-0.002592261998183278,
- false,
+ true,
0.05484510736603471,
0.14132210941786577,
- false,
+ true,
0.03833367306762126,
0.03430885887772673,
- false,
+ true,
0.09887559882847057,
-0.002592261998183278,
true,
@@ -808,10 +792,10 @@
true,
0.10988322169407955,
0.03430885887772673,
- true,
+ false,
-0.0249601584096303,
-0.002592261998183278,
- true,
+ false,
0.03695772020942014,
-0.002592261998183278,
false,
@@ -820,19 +804,19 @@
true,
-0.06761469701386505,
-0.002592261998183278,
- true,
+ false,
-0.05523112129005496,
-0.0763945037500033,
false,
0.014942474478202204,
0.03430885887772673,
- true,
+ false,
-0.027712064126032544,
-0.03949338287409329,
true,
-0.07311850844666953,
-0.03949338287409329,
- true,
+ false,
0.03833367306762126,
0.03430885887772673,
false,
@@ -844,16 +828,16 @@
true,
0.013566521620001083,
0.03430885887772673,
- false,
+ true,
-0.012576582685820214,
-0.002592261998183278,
- false,
+ true,
0.045213437358626866,
-0.002592261998183278,
true,
-0.007072771253015731,
-0.03949338287409329,
- false,
+ true,
0.016318427336403322,
-0.002592261998183278,
false,
@@ -865,13 +849,13 @@
true,
-0.02220825269322806,
-0.002592261998183278,
- false,
+ true,
-0.051103262715451604,
0.03430885887772673,
- true,
+ false,
-0.0249601584096303,
-0.0763945037500033,
- true,
+ false,
0.0342058144930179,
-0.002592261998183278,
false,
@@ -880,28 +864,28 @@
false,
0.0025588987543921156,
-0.002592261998183278,
- true,
+ false,
0.0025588987543921156,
-0.002592261998183278,
true,
-0.016704441260423575,
0.03430885887772673,
- false,
+ true,
0.045213437358626866,
0.03615391492152222,
- false,
+ true,
0.07823630595545376,
-0.002592261998183278,
- true,
+ false,
-0.011200629827619093,
-0.002592261998183278,
- false,
+ true,
0.03145390877661565,
0.019917421736121838,
- false,
+ true,
0.024574144485610048,
0.03430885887772673,
- true,
+ false,
-0.001568959820211247,
-0.03949338287409329,
false,
@@ -910,28 +894,28 @@
false,
-0.00019300696201012598,
-0.05056371913686628,
- false,
+ true,
-0.06623874415566393,
-0.002592261998183278,
- true,
+ false,
-0.004320865536613489,
0.07120997975363674,
- true,
+ false,
0.04383748450042574,
-0.014400620678474476,
- true,
+ false,
0.03282986163481677,
-0.03949338287409329,
- false,
+ true,
-0.038719686991641515,
-0.03949338287409329,
true,
-0.04422349842444599,
-0.0763945037500033,
- true,
+ false,
-0.035967781275239266,
-0.05167075276314359,
- false,
+ true,
-0.007072771253015731,
-0.002592261998183278,
false,
@@ -943,40 +927,40 @@
true,
-0.007072771253015731,
0.07120997975363674,
- true,
+ false,
-0.008448724111216851,
-0.03949338287409329,
true,
0.08924392882106273,
0.10811110062954676,
- false,
+ true,
-0.0249601584096303,
-0.03949338287409329,
- true,
+ false,
0.03282986163481677,
-0.002592261998183278,
- false,
+ true,
-0.04422349842444599,
-0.002592261998183278,
- true,
+ false,
-0.0029449126784123676,
-0.03949338287409329,
- true,
+ false,
-0.033215875558837024,
-0.0763945037500033,
- false,
+ true,
0.08236416453005713,
0.07120997975363674,
- false,
+ true,
-0.0318399227006359,
0.0029429061332032365,
- true,
+ false,
-0.04972730985725048,
-0.03949338287409329,
- true,
+ false,
0.010814615903598841,
-0.03949338287409329,
- true,
+ false,
-0.005696818394814609,
0.03430885887772673,
false,
@@ -988,7 +972,7 @@
true,
-0.007072771253015731,
-0.002592261998183278,
- true,
+ false,
-0.06348683843926169,
-0.03949338287409329,
false,
@@ -997,19 +981,19 @@
true,
-0.019456346976825818,
0.03430885887772673,
- true,
+ false,
0.039709625925822375,
0.07120997975363674,
- true,
+ false,
0.045213437358626866,
0.07120997975363674,
- false,
+ true,
-0.04972730985725048,
0.01585829843977173,
false,
-0.026336111267831423,
-0.03949338287409329,
- true,
+ false,
0.03833367306762126,
0.10811110062954676,
true,
@@ -1018,22 +1002,22 @@
false,
0.016318427336403322,
0.02655962349378563,
- true,
+ false,
0.020446285911006685,
-0.002592261998183278,
- false,
+ true,
0.01219056876179996,
0.10811110062954676,
false,
-0.0029449126784123676,
-0.03949338287409329,
- false,
+ true,
-0.046975404140848234,
-0.03949338287409329,
false,
-0.0029449126784123676,
-0.047242618258034386,
- true,
+ false,
0.04658939021682799,
-0.03949338287409329,
false,
@@ -1045,13 +1029,13 @@
true,
-0.08962994274508297,
-0.0763945037500033,
- true,
+ false,
-0.05935897986465832,
-0.03949338287409329,
- false,
+ true,
-0.030463969842434782,
-0.002592261998183278,
- true
+ false
],
"y": [
0.024574144485610048,
@@ -1065,25 +1049,25 @@
false,
-0.013952535544021335,
-0.002592261998183278,
- false,
+ true,
-0.0318399227006359,
-0.03949338287409329,
- true,
+ false,
0.04246153164222462,
-0.0763945037500033,
- true,
+ false,
0.041085578784023497,
0.07120997975363674,
- false,
+ true,
-0.062110885581060565,
0.026928634702544724,
false,
-0.04284754556624487,
-0.002592261998183278,
- false,
+ true,
-0.005696818394814609,
-0.03949338287409329,
- true,
+ false,
-0.015328488402222454,
-0.002592261998183278,
false,
@@ -1095,16 +1079,16 @@
true,
-0.005696818394814609,
-0.002592261998183278,
- false,
+ true,
-0.007072771253015731,
-0.03949338287409329,
- true,
+ false,
0.001182945896190995,
-0.015507654304751785,
- false,
+ true,
0.001182945896190995,
0.03430885887772673,
- true,
+ false,
-0.009824676969417972,
0.03430885887772673,
true,
@@ -1119,31 +1103,31 @@
false,
-0.034591828417038145,
-0.0763945037500033,
- false,
+ true,
-0.015328488402222454,
-0.002592261998183278,
false,
-0.004320865536613489,
0.03430885887772673,
- false,
+ true,
-0.05523112129005496,
-0.03949338287409329,
true,
-0.04284754556624487,
-0.002592261998183278,
- true,
+ false,
0.024574144485610048,
0.15534453535071155,
- false,
+ true,
-0.035967781275239266,
0.07120997975363674,
true,
-0.037343734133440394,
-0.03949338287409329,
- true,
+ false,
-0.0318399227006359,
-0.03949338287409329,
- true,
+ false,
-0.004320865536613489,
-0.0011162171631468765,
false,
@@ -1152,16 +1136,16 @@
true,
0.06998058880624704,
0.07120997975363674,
- false,
+ true,
0.006686757328995478,
0.03430885887772673,
- false,
+ true,
-0.02358420555142918,
-0.03949338287409329,
- false,
+ true,
0.06034891879883919,
0.10811110062954676,
- true,
+ false,
0.0080627101871966,
-0.002592261998183278,
false,
@@ -1170,7 +1154,7 @@
false,
0.0025588987543921156,
-0.03949338287409329,
- true,
+ false,
-0.0579830270064572,
-0.03949338287409329,
true,
@@ -1185,13 +1169,13 @@
false,
-0.07036660273026729,
-0.002592261998183278,
- true,
+ false,
0.027326050202012293,
-0.03949338287409329,
true,
-0.05660707414825608,
-0.03949338287409329,
- true,
+ false,
-0.0579830270064572,
-0.03949338287409329,
true,
@@ -1206,16 +1190,16 @@
true,
0.024574144485610048,
0.05091436327188625,
- true,
+ false,
-0.060734932722859444,
-0.0763945037500033,
- false,
+ true,
-0.007072771253015731,
0.03430885887772673,
- true,
+ false,
0.00943866304539772,
-0.002592261998183278,
- false,
+ true,
0.039709625925822375,
0.10811110062954676,
true,
@@ -1224,19 +1208,19 @@
false,
-0.04422349842444599,
-0.03949338287409329,
- true,
+ false,
0.001182945896190995,
0.03430885887772673,
- true,
+ false,
-0.05660707414825608,
-0.03949338287409329,
- true,
+ false,
0.08374011738825825,
-0.03949338287409329,
- true,
+ false,
-0.009824676969417972,
-0.002592261998183278,
- true,
+ false,
-0.05523112129005496,
-0.03949338287409329,
true,
@@ -1257,13 +1241,13 @@
true,
0.05071724879143135,
0.03430885887772673,
- false,
+ true,
-0.0579830270064572,
-0.03949338287409329,
true,
-0.02083229983502694,
0.07120997975363674,
- false,
+ true,
-0.037343734133440394,
-0.002592261998183278,
false,
@@ -1272,52 +1256,52 @@
false,
0.03558176735121902,
-0.0763945037500033,
- true,
+ false,
0.001182945896190995,
-0.015507654304751785,
- false,
+ true,
-0.0318399227006359,
-0.03949338287409329,
true,
-0.033215875558837024,
-0.002592261998183278,
- false,
+ true,
0.017694380194604446,
0.03430885887772673,
false,
-0.016704441260423575,
-0.002592261998183278,
- true,
+ false,
-0.04284754556624487,
-0.002592261998183278,
- true,
+ false,
0.04246153164222462,
-0.0763945037500033,
- true,
+ false,
-0.06623874415566393,
-0.03949338287409329,
- false,
+ true,
0.010814615903598841,
-0.03949338287409329,
true,
0.08374011738825825,
-0.03949338287409329,
- true,
+ false,
-0.04422349842444599,
-0.03949338287409329,
- true,
+ false,
-0.0029449126784123676,
-0.002592261998183278,
- false,
+ true,
0.03145390877661565,
-0.03949338287409329,
true,
-0.06623874415566393,
-0.03949338287409329,
- false,
+ true,
0.001182945896190995,
-0.007020396503292483,
- true,
+ false,
0.01219056876179996,
-0.03949338287409329,
false
@@ -2170,7 +2154,7 @@
}
},
"title": {
- "text": "Comparison of Actual vs. Imputed Values using Matching"
+ "text": "Comparison of actual vs. imputed values using Matching"
},
"width": 750,
"xaxis": {
@@ -2181,7 +2165,7 @@
"scaleanchor": "y",
"scaleratio": 1,
"title": {
- "text": "Actual Values"
+ "text": "Actual values"
}
},
"yaxis": {
@@ -2190,7 +2174,7 @@
1.0563390334958256
],
"title": {
- "text": "Imputed Values"
+ "text": "Imputed values"
}
}
}
@@ -2324,7 +2308,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -2448,7 +2432,7 @@
"[5 rows x 20 columns]"
]
},
- "execution_count": 11,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -2479,7 +2463,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -3781,7 +3765,7 @@
}
},
"title": {
- "text": "Matching Imputation Prediction Intervals"
+ "text": "Matching imputation prediction intervals"
},
"width": 750,
"xaxis": {
@@ -3789,7 +3773,7 @@
"gridwidth": 1,
"showgrid": true,
"title": {
- "text": "Data Record Index"
+ "text": "Data record index"
}
},
"yaxis": {
@@ -3937,7 +3921,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -3958,7 +3942,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -3966,88 +3950,9 @@
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/opt/homebrew/Cellar/r/4.5.0/lib/R/lib/libRblas.dylib' (no such file)\")\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/opt/homebrew/Cellar/r/4.5.0/lib/R/lib/libRblas.dylib' (no such file)\")\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/opt/homebrew/Cellar/r/4.5.0/lib/R/lib/libRblas.dylib' (no such file)\")\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/opt/homebrew/Cellar/r/4.5.0/lib/R/lib/libRblas.dylib' (no such file)\")\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/opt/homebrew/Cellar/r/4.5.0/lib/R/lib/libRblas.dylib' (no such file)\")\n",
- "Trying to import in ABI mode.\n",
- "Trying to import in ABI mode.\n",
- "Trying to import in ABI mode.\n",
- "Trying to import in ABI mode.\n",
- "Trying to import in ABI mode.\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpba7Rkm\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpnvHr2r\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpba7Rkm\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpYPOjOX\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpba7Rkm\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpMcHnyN\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpba7Rkm\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpPTk9Wv\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpba7Rkm\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpJeultz\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpMcHnyN\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpN142sg\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpPTk9Wv\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpXfkfUb\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpnvHr2r\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpnFs5HB\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpYPOjOX\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpUn2V5P\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpJeultz\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpd3OfR6\"\n",
- " warnings.warn(\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n",
- "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 7.5s remaining: 11.3s\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 7.6s remaining: 5.0s\n",
- "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 7.6s finished\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:623: UserWarning: is not a Python type (it may be an instance of an object), Pydantic will allow any object with no validation since we cannot even enforce that the input is an instance of the given type. To get rid of this error wrap the type with `pydantic.SkipValidation`.\n",
- " warn(\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/opt/homebrew/Cellar/r/4.5.0/lib/R/lib/libRblas.dylib' (no such file)\")\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/opt/homebrew/Cellar/r/4.5.0/lib/R/lib/libRblas.dylib' (no such file)\")\n",
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/opt/homebrew/Cellar/r/4.5.0/lib/R/lib/libRblas.dylib' (no such file)\")\n",
- "Trying to import in ABI mode.\n",
- "Trying to import in ABI mode.\n",
- "Trying to import in ABI mode.\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpba7Rkm\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpT7iokK\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpba7Rkm\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpaRRmzh\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpba7Rkm\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpY7AQBW\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpY7AQBW\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpjP7jok\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpaRRmzh\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//Rtmpoj8O4v\"\n",
- " warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpT7iokK\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpqQatPx\"\n",
- " warnings.warn(\n"
+ "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 11.3s remaining: 17.0s\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 11.4s remaining: 7.6s\n",
+ "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 11.5s finished\n"
]
},
{
@@ -4093,7 +3998,7 @@
" 0.000000 | \n",
" 0.000000 | \n",
" ... | \n",
- " 0.00000 | \n",
+ " 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
@@ -4101,17 +4006,17 @@
" \n",
" \n",
" | test | \n",
- " 0.024143 | \n",
- " 0.024089 | \n",
- " 0.024035 | \n",
- " 0.023982 | \n",
- " 0.023928 | \n",
+ " 0.023949 | \n",
+ " 0.023909 | \n",
+ " 0.023868 | \n",
+ " 0.023828 | \n",
+ " 0.023787 | \n",
" ... | \n",
- " 0.02339 | \n",
- " 0.023336 | \n",
- " 0.023282 | \n",
- " 0.023229 | \n",
- " 0.023175 | \n",
+ " 0.023382 | \n",
+ " 0.023342 | \n",
+ " 0.023302 | \n",
+ " 0.023261 | \n",
+ " 0.023221 | \n",
"
\n",
" \n",
"\n",
@@ -4119,14 +4024,14 @@
""
],
"text/plain": [
- " 0.05 0.10 0.15 0.20 0.25 ... 0.75 0.80 0.85 0.90 0.95\n",
- "train 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.00000 0.000000 0.000000 0.000000 0.000000\n",
- "test 0.024143 0.024089 0.024035 0.023982 0.023928 ... 0.02339 0.023336 0.023282 0.023229 0.023175\n",
+ " 0.05 0.10 0.15 0.20 0.25 ... 0.75 0.80 0.85 0.90 0.95\n",
+ "train 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000\n",
+ "test 0.023949 0.023909 0.023868 0.023828 0.023787 ... 0.023382 0.023342 0.023302 0.023261 0.023221\n",
"\n",
"[2 rows x 19 columns]"
]
},
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -4143,7 +4048,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -4230,25 +4135,25 @@
"0.95"
],
"y": [
- 0.024143030757925923,
- 0.024089235955485365,
- 0.0240354411530448,
- 0.023981646350604242,
- 0.023927851548163674,
- 0.023874056745723116,
- 0.02382026194328255,
- 0.02376646714084199,
- 0.023712672338401426,
- 0.02365887753596086,
- 0.0236050827335203,
- 0.023551287931079738,
- 0.023497493128639173,
- 0.023443698326198616,
- 0.02338990352375805,
- 0.02333610872131749,
- 0.023282313918876925,
- 0.023228519116436364,
- 0.023174724313995802
+ 0.023949237081135385,
+ 0.02390875418287496,
+ 0.02386827128461453,
+ 0.023827788386354097,
+ 0.023787305488093664,
+ 0.023746822589833235,
+ 0.0237063396915728,
+ 0.023665856793312373,
+ 0.023625373895051937,
+ 0.023584890996791508,
+ 0.023544408098531075,
+ 0.023503925200270646,
+ 0.023463442302010213,
+ 0.023422959403749784,
+ 0.02338247650548935,
+ 0.023341993607228922,
+ 0.02330151070896849,
+ 0.023261027810708056,
+ 0.023220544912447627
]
}
],
@@ -5086,7 +4991,7 @@
}
},
"title": {
- "text": "Matching Cross-Validation Performance"
+ "text": "Matching cross-validation performance"
},
"width": 750,
"xaxis": {
@@ -5141,7 +5046,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -5162,17 +5067,9 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 16,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# To set specific hyperparameters pass them when fitting the model\n",
"fitted_matching_imputer = matching_imputer.fit(\n",
@@ -5185,7 +5082,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -5206,16 +5103,9 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 18,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found 1 numeric columns with unique values < 10, treating as categorical: ['sex']. Converting to dummy variables.\n"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
@@ -5239,7 +5129,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "pe",
+ "display_name": "pe3.13",
"language": "python",
"name": "python3"
},
@@ -5253,7 +5143,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.13.0"
}
},
"nbformat": 4,
diff --git a/docs/models/ols/ols-imputation.ipynb b/docs/models/ols/ols-imputation.ipynb
index 0b193d1..fe15753 100644
--- a/docs/models/ols/ols-imputation.ipynb
+++ b/docs/models/ols/ols-imputation.ipynb
@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -50,7 +50,7 @@
"pd.set_option(\"display.expand_frame_repr\", False)\n",
"\n",
"# Import MicroImpute tools\n",
- "from microimpute.comparisons.data import preprocess_data\n",
+ "from microimpute.utils.data import preprocess_data\n",
"from microimpute.evaluations import *\n",
"from microimpute.models import OLS\n",
"from microimpute.config import QUANTILES\n",
diff --git a/docs/models/qrf/qrf-imputation.ipynb b/docs/models/qrf/qrf-imputation.ipynb
index d26cfb2..554b0ce 100644
--- a/docs/models/qrf/qrf-imputation.ipynb
+++ b/docs/models/qrf/qrf-imputation.ipynb
@@ -3,7 +3,20 @@
{
"cell_type": "markdown",
"metadata": {},
- "source": "# Quantile Regression Forest (QRF) imputation\n\nThis notebook demonstrates how to use MicroImpute's QRF imputer to impute values using Quantile Regression Forests. QRF extends traditional random forests to predict the entire conditional distribution of a target variable.\n\nThe QRF model supports iterative imputation with a single object and workflow. Pass a list of `imputed_variables` with all variables you want to impute, and the model imputes them sequentially."
+ "source": [
+ "# Quantile Regression Forest (QRF) imputation\n",
+ "\n",
+ "This notebook demonstrates how to use MicroImpute's QRF imputer to impute values using Quantile Regression Forests. QRF extends traditional random forests to predict the entire conditional distribution of a target variable.\n",
+ "\n",
+ "The QRF model supports sequential imputation with a single object and workflow. Pass a list of `imputed_variables` with all variables you want to impute, and the model imputes them sequentially. This means that previously imputed variables will serve as predictors for subsequent variables, capturing complex dependencies between the imputed variables.\n",
+ "\n",
+ "### How sequential imputation works\n",
+ "\n",
+ "1. **Variable 1**: Uses only the original predictors\n",
+ "2. **Variable 2**: Uses original predictors + Variable 1's imputed values \n",
+ "3. **Variable 3**: Uses original predictors + Variables 1 & 2's imputed values\n",
+ "4. And so on..."
+ ]
},
{
"cell_type": "markdown",
@@ -21,7 +34,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
+ "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <668E1903-F0E7-30D5-BA27-15F8287F87F7> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
"Trying to import in ABI mode.\n"
]
}
@@ -43,7 +56,7 @@
"pd.set_option(\"display.expand_frame_repr\", False)\n",
"\n",
"# Import MicroImpute tools\n",
- "from microimpute.comparisons.data import preprocess_data\n",
+ "from microimpute.utils.data import preprocess_data\n",
"from microimpute.evaluations import *\n",
"from microimpute.models import QRF\n",
"from microimpute.config import QUANTILES\n",
@@ -485,7 +498,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Modeling these quantiles: [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]\n"
+ "Modeling these quantiles: [np.float64(0.05), np.float64(0.1), np.float64(0.15), np.float64(0.2), np.float64(0.25), np.float64(0.3), np.float64(0.35), np.float64(0.4), np.float64(0.45), np.float64(0.5), np.float64(0.55), np.float64(0.6), np.float64(0.65), np.float64(0.7), np.float64(0.75), np.float64(0.8), np.float64(0.85), np.float64(0.9), np.float64(0.95)]\n"
]
}
],
@@ -548,41 +561,41 @@
" \n",
" \n",
" \n",
- " | 0 | \n",
+ " 287 | \n",
" -0.015328 | \n",
" -0.039493 | \n",
"
\n",
" \n",
- " | 1 | \n",
+ " 211 | \n",
" 0.039710 | \n",
- " -0.002592 | \n",
+ " 0.034309 | \n",
"
\n",
" \n",
- " | 2 | \n",
+ " 72 | \n",
" 0.069981 | \n",
- " 0.034309 | \n",
+ " 0.071210 | \n",
"
\n",
" \n",
- " | 3 | \n",
+ " 321 | \n",
" 0.046589 | \n",
- " 0.034309 | \n",
+ " 0.050914 | \n",
"
\n",
" \n",
- " | 4 | \n",
+ " 73 | \n",
" 0.031454 | \n",
- " 0.034309 | \n",
+ " 0.056081 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " s1 s4\n",
- "0 -0.015328 -0.039493\n",
- "1 0.039710 -0.002592\n",
- "2 0.069981 0.034309\n",
- "3 0.046589 0.034309\n",
- "4 0.031454 0.034309"
+ " s1 s4\n",
+ "287 -0.015328 -0.039493\n",
+ "211 0.039710 0.034309\n",
+ "72 0.069981 0.071210\n",
+ "321 0.046589 0.050914\n",
+ "73 0.031454 0.056081"
]
},
"execution_count": 8,
@@ -601,65 +614,12 @@
},
{
"cell_type": "markdown",
- "source": "### Benefits of Sequential Imputation\n\n1. **Captures dependencies**: Related variables benefit from using earlier imputations\n2. **Improves accuracy**: Often produces more accurate imputations when variables correlate\n3. **Preserves relationships**: Better maintains the joint distribution of imputed variables\n\n### When to Use Sequential vs Parallel\n\n- **Use sequential** when:\n - Variables correlate or have dependencies\n - You want to preserve relationships between imputed variables\n - The order of imputation makes logical sense (e.g., impute income before tax)\n\n- **Use parallel** when:\n - Variables are independent\n - You need faster computation (parallel can be distributed)\n - The order of imputation doesn't matter\n\n### Implementation Details\n\nSequential imputation uses the `_get_sequential_predictors` helper function to build the predictor set for each variable:\n\n```python\ndef _get_sequential_predictors(predictors, imputed_variables, current_variable_index):\n \"\"\"Get the predictor set for sequential imputation.\"\"\"\n return predictors + imputed_variables[:current_variable_index]\n```\n\nThis ensures that each variable uses all previously imputed variables as additional predictors.",
- "metadata": {}
- },
- {
- "cell_type": "code",
+ "metadata": {},
"source": [
- "# Sequential imputation (default behavior when passing multiple variables)\n",
- "sequential_imputer = QRF()\n",
- "sequential_fitted = sequential_imputer.fit(\n",
- " X_train,\n",
- " predictors,\n",
- " [\"s1\", \"s2\", \"s3\"], # Multiple variables\n",
- " n_estimators=50,\n",
- ")\n",
- "\n",
- "# Get sequential predictions\n",
- "sequential_preds = sequential_fitted.predict(\n",
- " X_test_missing.head(5), quantiles=[0.5]\n",
- ")\n",
- "print(\"Sequential imputation predictions:\")\n",
- "print(sequential_preds[0.5])\n",
- "\n",
- "# Parallel imputation (fitting each variable separately)\n",
- "parallel_predictions = pd.DataFrame()\n",
- "\n",
- "for var in [\"s1\", \"s2\", \"s3\"]:\n",
- " single_imputer = QRF()\n",
- " single_fitted = single_imputer.fit(\n",
- " X_train,\n",
- " predictors,\n",
- " [var], # Only one variable at a time\n",
- " n_estimators=50,\n",
- " )\n",
+ "## Evaluating the imputation results\n",
"\n",
- " single_pred = single_fitted.predict(\n",
- " X_test_missing.head(5), quantiles=[0.5]\n",
- " )\n",
- " parallel_predictions[var] = single_pred[0.5][var]\n",
- "\n",
- "print(\"\\nParallel imputation predictions:\")\n",
- "print(parallel_predictions)\n",
- "\n",
- "# Compare the results\n",
- "print(\"\\nDifference (Sequential - Parallel):\")\n",
- "print(sequential_preds[0.5] - parallel_predictions)"
- ],
- "metadata": {},
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": "## Sequential Imputation\n\nWhen imputing multiple variables, QRF uses a **sequential imputation** approach. Previously imputed variables serve as predictors for subsequent variables, capturing complex dependencies between the imputed variables.\n\n### How Sequential Imputation Works\n\n1. **Variable 1**: Uses only the original predictors\n2. **Variable 2**: Uses original predictors + Variable 1's imputed values \n3. **Variable 3**: Uses original predictors + Variables 1 & 2's imputed values\n4. And so on...\n\n### Example: Sequential vs Parallel Imputation\n\nLet's demonstrate the difference between sequential and parallel imputation:",
- "metadata": {}
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": "## Evaluating the imputation results\n\nNow let's compare the imputed values with the actual values to evaluate the performance of our imputer. To understand QRF's power to capture variability across quantiles, let's find and plot the prediction closest to the true value across quantiles for each data point."
+ "Now let's compare the imputed values with the actual values to evaluate the performance of our imputer. To understand QRF's power to capture variability across quantiles, let's find and plot the prediction closest to the true value across quantiles for each data point."
+ ]
},
{
"cell_type": "code",
@@ -869,13 +829,13 @@
"xaxis": "x",
"y": [
0.06310082451524143,
- -0.002592261998183278,
+ 0.03430885887772673,
-0.027712064126032544,
-0.03949338287409329,
0.08786797596286161,
-0.002592261998183278,
0.05759701308243695,
- 0.07120997975363674,
+ 0.10811110062954676,
0.03145390877661565,
0.03430885887772673,
0.04246153164222462,
@@ -889,7 +849,7 @@
0.03695772020942014,
-0.002592261998183278,
0.04246153164222462,
- -0.03949338287409329,
+ -0.002592261998183278,
-0.05385516843185383,
-0.002592261998183278,
-0.0579830270064572,
@@ -911,7 +871,7 @@
-0.02358420555142918,
-0.002592261998183278,
0.05209320164963247,
- -0.021411833644897377,
+ -0.002592261998183278,
-0.007072771253015751,
-0.03949338287409329,
0.020446285911006685,
@@ -929,9 +889,9 @@
0.03833367306762126,
-0.002592261998183278,
0.02319819162740893,
- -0.002592261998183278,
+ 0.05091436327188625,
-0.00019300696201012598,
- -0.002592261998183278,
+ 0.03430885887772673,
0.0011829458961909658,
-0.002592261998183278,
-0.009824676969417983,
@@ -945,7 +905,7 @@
0.020446285911006685,
0.03430885887772673,
0.019070333052805567,
- -0.002592261998183278,
+ 0.03430885887772673,
0.001182945896190995,
-0.03949338287409329,
0.001182945896190995,
@@ -959,7 +919,7 @@
0.03558176735121902,
-0.002592261998183278,
0.020446285911006685,
- -0.03949338287409329,
+ -0.0763945037500033,
-0.038719686991641515,
-0.03949338287409329,
-0.037343734133440394,
@@ -969,7 +929,7 @@
-0.009824676969417983,
-0.002592261998183278,
-0.04835135699904936,
- -0.021411833644897377,
+ -0.03949338287409329,
-0.037343734133440394,
-0.03949338287409329,
-0.0029449126784123775,
@@ -977,7 +937,7 @@
-0.009824676969417972,
-0.03949338287409329,
0.05484510736603471,
- 0.07120997975363674,
+ 0.10811110062954676,
-0.016704441260423575,
-0.03949338287409329,
-0.005696818394814609,
@@ -989,7 +949,7 @@
-0.037343734133440394,
-0.0763945037500033,
0.04796534307502911,
- -0.002592261998183278,
+ 0.07120997975363674,
-0.037343734133440394,
-0.002592261998183278,
-0.037343734133440394,
@@ -997,11 +957,11 @@
0.001182945896190995,
-0.03949338287409329,
-0.008448724111216851,
- -0.007020396503292483,
+ 0.03430885887772673,
0.06447677737344255,
-0.002592261998183278,
0.06034891879883919,
- 0.07120997975363674,
+ 0.05017634085436802,
-0.004320865536613489,
-0.002592261998183278,
-0.06623874415566393,
@@ -1011,23 +971,23 @@
-0.016704441260423575,
0.03430885887772673,
0.030077955918414535,
- 0.07120997975363674,
+ 0.03430885887772673,
0.05071724879143135,
0.07120997975363674,
-0.05660707414825608,
- -0.0018542395806650938,
+ 0.03430885887772673,
-0.0318399227006359,
-0.03949338287409329,
0.027326050202012293,
- 0.03430885887772673,
+ 0.07120997975363674,
0.08786797596286161,
- 0.07120997975363652,
+ 0.05091436327188625,
0.012190568761799941,
- 0.003311917341962329,
+ -0.033958214742706834,
0.020446285911006685,
-0.002592261998183278,
0.013566521620001064,
- -0.002592261998183278,
+ 0.03430885887772673,
0.00806271018719654,
-0.03949338287409329,
-0.046975404140848234,
@@ -1059,11 +1019,11 @@
"type": "scatter",
"x": [
-0.12678066991651324,
- 0.14132210941786577
+ 0.18523444326019867
],
"y": [
-0.12678066991651324,
- 0.14132210941786577
+ 0.18523444326019867
]
}
],
@@ -3619,7 +3579,11 @@
{
"cell_type": "markdown",
"metadata": {},
- "source": "## Assessing the method's performance\n\nTo verify our model doesn't overfit and ensure robust results, we can perform cross-validation and visualize the results."
+ "source": [
+ "## Assessing the method's performance\n",
+ "\n",
+ "To verify our model doesn't overfit and ensure robust results, we can perform cross-validation and visualize the results."
+ ]
},
{
"cell_type": "code",
@@ -3630,16 +3594,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 2.4s remaining: 3.7s\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 2.5s remaining: 1.6s\n",
- "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 2.5s finished\n"
+ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
+ "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 3.4s remaining: 5.1s\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 3.4s remaining: 2.3s\n",
+ "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 3.4s finished\n"
]
},
{
@@ -3679,31 +3637,31 @@
" \n",
" \n",
" | train | \n",
- " 0.001632 | \n",
- " 0.003260 | \n",
- " 0.004382 | \n",
- " 0.005816 | \n",
- " 0.006768 | \n",
+ " 0.001967 | \n",
+ " 0.003999 | \n",
+ " 0.005333 | \n",
+ " 0.007202 | \n",
+ " 0.008227 | \n",
" ... | \n",
- " 0.004193 | \n",
- " 0.003749 | \n",
- " 0.003189 | \n",
- " 0.002631 | \n",
- " 0.001613 | \n",
+ " 0.005756 | \n",
+ " 0.005267 | \n",
+ " 0.004522 | \n",
+ " 0.003761 | \n",
+ " 0.002310 | \n",
"
\n",
" \n",
" | test | \n",
- " 0.005352 | \n",
- " 0.008342 | \n",
- " 0.011274 | \n",
- " 0.013143 | \n",
- " 0.016183 | \n",
+ " 0.004520 | \n",
+ " 0.007870 | \n",
+ " 0.011132 | \n",
+ " 0.013861 | \n",
+ " 0.016690 | \n",
" ... | \n",
- " 0.017408 | \n",
- " 0.015590 | \n",
- " 0.013659 | \n",
- " 0.010796 | \n",
- " 0.007217 | \n",
+ " 0.018722 | \n",
+ " 0.016499 | \n",
+ " 0.013710 | \n",
+ " 0.010794 | \n",
+ " 0.006778 | \n",
"
\n",
" \n",
"\n",
@@ -3712,8 +3670,8 @@
],
"text/plain": [
" 0.05 0.10 0.15 0.20 0.25 ... 0.75 0.80 0.85 0.90 0.95\n",
- "train 0.001632 0.003260 0.004382 0.005816 0.006768 ... 0.004193 0.003749 0.003189 0.002631 0.001613\n",
- "test 0.005352 0.008342 0.011274 0.013143 0.016183 ... 0.017408 0.015590 0.013659 0.010796 0.007217\n",
+ "train 0.001967 0.003999 0.005333 0.007202 0.008227 ... 0.005756 0.005267 0.004522 0.003761 0.002310\n",
+ "test 0.004520 0.007870 0.011132 0.013861 0.016690 ... 0.018722 0.016499 0.013710 0.010794 0.006778\n",
"\n",
"[2 rows x 19 columns]"
]
@@ -3737,7 +3695,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -3775,25 +3733,25 @@
"0.95"
],
"y": [
- 0.0016322876282474183,
- 0.0032598945270124816,
- 0.0043824801378580915,
- 0.005815932308708196,
- 0.00676840055993066,
- 0.006741343067459824,
- 0.006393620796671657,
- 0.007068155206803815,
- 0.006712925916996848,
- 0.006561256549453533,
- 0.005538741426004324,
- 0.005264346234632853,
- 0.004508080850825122,
- 0.00403600112659596,
- 0.00419347622527966,
- 0.003748593867643589,
- 0.0031885489143531993,
- 0.0026313199597614613,
- 0.0016134470119617231
+ 0.0019670433713474544,
+ 0.003999197655156622,
+ 0.005333139306771688,
+ 0.007201748319103865,
+ 0.00822653405002437,
+ 0.008447665136067041,
+ 0.00814043105374781,
+ 0.008881960493892416,
+ 0.008437250297322941,
+ 0.008383766682265571,
+ 0.007259383947842546,
+ 0.0071726184738474304,
+ 0.006197312504980799,
+ 0.005609462003193234,
+ 0.005755736479811034,
+ 0.005266949046775604,
+ 0.00452158138273317,
+ 0.003761365065012276,
+ 0.0023099629328434757
]
},
{
@@ -3824,25 +3782,25 @@
"0.95"
],
"y": [
- 0.005352171226539223,
- 0.008342133722903221,
- 0.011273888664037245,
- 0.01314320339190125,
- 0.01618337544835127,
- 0.018554833559477395,
- 0.0196085243119515,
- 0.020654369858720927,
- 0.022360954681851557,
- 0.0220981470406656,
- 0.021789107592053002,
- 0.021260985296343494,
- 0.020212650758525342,
- 0.019000963803278897,
- 0.017408190126187738,
- 0.015589599833021823,
- 0.013658699070594785,
- 0.01079551091418708,
- 0.007216935247235432
+ 0.004520365920135072,
+ 0.007870428843228475,
+ 0.01113193972961984,
+ 0.013860904750387554,
+ 0.016689810586868782,
+ 0.019171101595596866,
+ 0.020931660127358354,
+ 0.021842069561388972,
+ 0.023747004741023488,
+ 0.02316096952448388,
+ 0.02282294983336111,
+ 0.022316332853161047,
+ 0.021388065478740385,
+ 0.020244041100969282,
+ 0.01872178310065714,
+ 0.01649923978979649,
+ 0.013710297740183203,
+ 0.010793916043219704,
+ 0.006778420409803051
]
}
],
@@ -4680,7 +4638,7 @@
}
},
"title": {
- "text": "QRF Cross-validation performance"
+ "text": "QRF cross-validation performance"
},
"width": 750,
"xaxis": {
@@ -4720,7 +4678,11 @@
{
"cell_type": "markdown",
"metadata": {},
- "source": "## Tuning the QRF model\n\nThe QRF imputer supports various parameters that you can adjust to improve performance. To set specific values you know improve performance for your dataset, see below. Additionally, automatic hyperparameter tuning specific to the target dataset is available by setting the parameter `tune_hyperparameters` to `True`."
+ "source": [
+ "## Tuning the QRF model\n",
+ "\n",
+ "The QRF imputer supports various parameters that you can adjust to improve performance. To set specific values you know improve performance for your dataset, see below. Additionally, automatic hyperparameter tuning specific to the target dataset is available by setting the parameter `tune_hyperparameters` to `True`."
+ ]
},
{
"cell_type": "code",
@@ -4776,7 +4738,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "{'n_estimators': 92, 'min_samples_split': 15, 'min_samples_leaf': 3, 'max_features': 0.41424062149385665, 'bootstrap': True}\n"
+ "{'n_estimators': 219, 'min_samples_split': 5, 'min_samples_leaf': 9, 'max_features': 0.3376633298161329, 'bootstrap': True}\n"
]
}
],
@@ -4795,7 +4757,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "pe",
+ "display_name": "pe3.13",
"language": "python",
"name": "python3"
},
@@ -4809,9 +4771,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
-}
\ No newline at end of file
+}
diff --git a/docs/models/quantreg/quantreg-imputation.ipynb b/docs/models/quantreg/quantreg-imputation.ipynb
index c7b1d24..a5fddcf 100644
--- a/docs/models/quantreg/quantreg-imputation.ipynb
+++ b/docs/models/quantreg/quantreg-imputation.ipynb
@@ -27,11 +27,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <38886600-97A2-37BA-9F86-5263C9A3CF6D> /Users/movil1/envs/pe/lib/python3.11/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
+ "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <668E1903-F0E7-30D5-BA27-15F8287F87F7> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n",
"Trying to import in ABI mode.\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"PWD\" redefined by R and overriding existing variable. Current: \"/\", R: \"/Users/movil1/Desktop/PYTHONJOBS/PolicyEngine/microimpute/docs/models/quantreg\"\n",
+ "/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable \"PWD\" redefined by R and overriding existing variable. Current: \"/\", R: \"/Users/movil1/Desktop/PYTHONJOBS/PolicyEngine/microimpute/docs/models/quantreg\"\n",
" warnings.warn(\n",
- "/Users/movil1/envs/pe/lib/python3.11/site-packages/rpy2/rinterface/__init__.py:1185: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmppjOS8x\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpaZGEXJ\"\n",
+ "/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/rpy2/rinterface/__init__.py:1211: UserWarning: Environment variable \"R_SESSION_TMPDIR\" redefined by R and overriding existing variable. Current: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpvGTNaK\", R: \"/var/folders/yg/xdp70k_n4qj9ph1_0lm435c00000gp/T//RtmpZ2HAUU\"\n",
" warnings.warn(\n"
]
}
@@ -51,7 +51,7 @@
"pd.set_option(\"display.expand_frame_repr\", False)\n",
"\n",
"# Import MicroImpute tools\n",
- "from microimpute.comparisons.data import preprocess_data\n",
+ "from microimpute.utils.data import preprocess_data\n",
"from microimpute.evaluations import *\n",
"from microimpute.models import QuantReg\n",
"from microimpute.config import QUANTILES\n",
@@ -493,7 +493,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Modeling these quantiles: [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]\n"
+ "Modeling these quantiles: [np.float64(0.05), np.float64(0.1), np.float64(0.15), np.float64(0.2), np.float64(0.25), np.float64(0.3), np.float64(0.35), np.float64(0.4), np.float64(0.45), np.float64(0.5), np.float64(0.55), np.float64(0.6), np.float64(0.65), np.float64(0.7), np.float64(0.75), np.float64(0.8), np.float64(0.85), np.float64(0.9), np.float64(0.95)]\n"
]
}
],
@@ -555,27 +555,27 @@
" \n",
" | 287 | \n",
" 0.005433 | \n",
- " -0.016971 | \n",
+ " -0.016880 | \n",
"
\n",
" \n",
" | 211 | \n",
" 0.029121 | \n",
- " 0.002011 | \n",
+ " 0.002169 | \n",
"
\n",
" \n",
" | 72 | \n",
" 0.008247 | \n",
- " 0.012233 | \n",
+ " 0.012453 | \n",
"
\n",
" \n",
" | 321 | \n",
" 0.041591 | \n",
- " 0.006420 | \n",
+ " 0.006506 | \n",
"
\n",
" \n",
" | 73 | \n",
" -0.005648 | \n",
- " 0.001657 | \n",
+ " 0.001787 | \n",
"
\n",
" \n",
"\n",
@@ -583,11 +583,11 @@
],
"text/plain": [
" s1 s4\n",
- "287 0.005433 -0.016971\n",
- "211 0.029121 0.002011\n",
- "72 0.008247 0.012233\n",
- "321 0.041591 0.006420\n",
- "73 -0.005648 0.001657"
+ "287 0.005433 -0.016880\n",
+ "211 0.029121 0.002169\n",
+ "72 0.008247 0.012453\n",
+ "321 0.041591 0.006506\n",
+ "73 -0.005648 0.001787"
]
},
"execution_count": 8,
@@ -820,184 +820,184 @@
],
"xaxis": "x",
"y": [
- 0.08014287312680966,
- 0.03714649322909738,
- -0.026031660118498542,
- -0.04419388252136774,
- 0.09101448103178356,
- -0.0037397391866475296,
- 0.0548424202341721,
- 0.11826947324914402,
- 0.043667170248303216,
- 0.03236322750195658,
- 0.05947393085004187,
- -0.005766866352488934,
- 0.030816567306961723,
- 0.03544781498532656,
- 0.10668826264913452,
- 0.03673488295358713,
- -0.02997727997825208,
- -0.001172284754078809,
- 0.0421526910144827,
- -0.0020428217513670174,
- 0.04340403812560857,
- -0.02464140292550899,
- -0.07011149971643708,
- -0.0006709918804685732,
- -0.055786262963924145,
- -0.07796135775451905,
- 0.015576453203716114,
- 0.03517122107452982,
- -0.024743406630458025,
- -0.03805415173684641,
- -0.07259407040874181,
- -0.03840622837918534,
- 0.044101955692491444,
- 0.03382018457860895,
- 0.01890113115138764,
- 0.07033983877354757,
- -0.05736030794522752,
- -0.02947084792047011,
- 0.015880100228170643,
- 0.03196601875855348,
- -0.010779699976034505,
- -0.0032634057101995585,
- 0.04513439333830424,
- -0.003923939642998609,
- -0.0047263377854795605,
- -0.03980332752671312,
- 0.014840596881445175,
- 0.0003765539539652245,
- -0.013392987786084123,
- -0.035679435087377295,
- -0.029262629822172412,
- -0.004785461066430709,
- -0.022649019467823002,
- -0.005720324109430425,
- -0.04833045001104161,
- 0.032855304686263236,
- -0.02316257367133703,
- -0.07271741236867253,
- 0.03119953424549745,
- -0.00337410992833476,
- 0.01117040529246721,
- 0.019065749451247074,
- 0.001096593711312547,
- 0.0032764776241989745,
- -0.0007842682110089849,
- -0.0039152303656053294,
- -0.018175605741099096,
- 0.03188240383660154,
- 0.04589149969653419,
- 0.037680794786623825,
- 0.06381621551931232,
- -0.001300099500522376,
- -0.011581513455535248,
- -0.00356416384429303,
- 0.034817930797391275,
- 0.02147395520172911,
- 0.023771293901034028,
- 0.03828697235220315,
- -0.00426516603747955,
- -0.04043128049541657,
- 0.001133505063255234,
- -0.03988486514186909,
- 0.0011644884776124964,
- -0.051702018967330225,
- -0.06500683147104241,
- -0.002032228622198609,
- 0.0005104131195645305,
- 0.06427388343826385,
- 0.04725239836343738,
- -0.015297600677867899,
- 0.03663163021940125,
- -0.04143857996849862,
- -0.04101940705516686,
- -0.03869561116191949,
- -0.04229588426870764,
- -0.07291029934716144,
- -0.033572063560771026,
- -0.05135325196688825,
- -0.005753218358740356,
- -0.0028314338440194437,
- -0.05050664993692761,
- -0.04388589399336798,
- -0.022702186982402443,
- -0.04079343726026801,
- -0.0070652821976598485,
- 0.07004192563517589,
- -0.010178744858492876,
- -0.039931248455872606,
- 0.06951359663413569,
- 0.08941104316065188,
- -0.027397943204620853,
- -0.04023029424978346,
- 0.029514084636179775,
- 0.002853161545148106,
- -0.04641152875589703,
- -0.003951276700172136,
- -0.0040322940914847855,
- -0.04256950666865577,
- -0.03323628684807561,
- -0.0766037412567671,
- 0.08862334481236925,
- 0.0840567760133204,
- -0.029998060971523754,
- 0.0050852404280895625,
- -0.0523305597794523,
- -0.021572308051443417,
- 0.00978717733147738,
- -0.04161631057317193,
- -0.004830509676027382,
- 0.005164590939911355,
- 0.05477888882093925,
- -0.0016513067746095605,
- 0.07013162881141863,
- 0.076093639641334,
- -0.01135041667019683,
- -0.002387555531129553,
- -0.06169893575816622,
- -0.039005546406281406,
- -0.015346049444420284,
- 0.03262093563070047,
- -0.018091888374899244,
- 0.0340600437543946,
- 0.0413090073474649,
- 0.07832176156023411,
- 0.04250757902769648,
- 0.0617689595134583,
- -0.05002261146957063,
- 0.01200146480931186,
- -0.027678360943264344,
- -0.04165174085129374,
- 0.03906885865974004,
- 0.06545052953410013,
- 0.07368543011927131,
- 0.033109186111253124,
- 0.02006464257300308,
- 0.02755038598329229,
- 0.019531772490917615,
- -0.002485091247951189,
- 0.016374866616169884,
- 0.04386231179672989,
- -0.00040018349030098853,
- -0.03769154897188289,
- -0.04812692255364454,
- -0.03643251181358571,
- -0.006470190751119471,
- -0.04911173816890036,
- 0.029230850771749126,
- -0.03752490612322709,
- -0.0056869804123184375,
- -0.039284033635513436,
- -0.02999947997302002,
- -0.07259002364459859,
- -0.07848887421963123,
- -0.0786139412043755,
- -0.06170814317672509,
- -0.037509448288062026,
- -0.031820521866524024,
- -0.004699438744260415
+ 0.08014287312679225,
+ 0.037146493229097445,
+ -0.02603166011849872,
+ -0.04419388252136648,
+ 0.09101448103178689,
+ -0.003739739186647749,
+ 0.05484242023417882,
+ 0.11826947324914416,
+ 0.043667170248303466,
+ 0.03236322750195383,
+ 0.05947393085002755,
+ -0.005766866352497645,
+ 0.030816567306962285,
+ 0.03544781498534445,
+ 0.10668826264912369,
+ 0.03673488295358744,
+ -0.02997727997822197,
+ -0.0011722847540802288,
+ 0.04215269101439463,
+ -0.0020428217513680288,
+ 0.04340403812559551,
+ -0.024641402925508465,
+ -0.07011149971642418,
+ -0.0006709918804689117,
+ -0.05578626296392473,
+ -0.07796135775452492,
+ 0.015576453203715956,
+ 0.03517122107452937,
+ -0.02474340663045097,
+ -0.03805415173684171,
+ -0.07259407040873372,
+ -0.03840622837918552,
+ 0.04410195569249157,
+ 0.03382018457861236,
+ 0.018901131151387883,
+ 0.07033983877354717,
+ -0.057360307945296236,
+ -0.029470847920470403,
+ 0.015880100228158178,
+ 0.032188339311192554,
+ -0.010779699976034895,
+ -0.0032634057101994913,
+ 0.045134393338303234,
+ -0.003923939642999089,
+ -0.004726337785479798,
+ -0.039833314633538,
+ 0.014840596881445275,
+ 0.0003765539539653095,
+ -0.013392987786084177,
+ -0.0356794350873725,
+ -0.02926262988141643,
+ -0.004785461066430653,
+ -0.022649019532539852,
+ -0.005720324109497198,
+ -0.048330450011041734,
+ 0.03285530468625955,
+ -0.023162573671341155,
+ -0.07271741236866766,
+ 0.031199534245519456,
+ -0.0033741099283342434,
+ 0.011170405292467208,
+ 0.019500166699634985,
+ 0.0010965937113134872,
+ 0.0032764776242019842,
+ -0.0007842682110097126,
+ -0.003915230365607036,
+ -0.018175605740980157,
+ 0.03188240383660177,
+ 0.04589149969653465,
+ 0.0376807947866241,
+ 0.06381621551924108,
+ -0.0013000995005240752,
+ -0.011581513455535344,
+ -0.003426159609522095,
+ 0.03481793079739181,
+ 0.021473955201729303,
+ 0.023771293901034187,
+ 0.03828697235220336,
+ -0.0042651660374771125,
+ -0.04043128049541707,
+ 0.0011335050632551194,
+ -0.039884865141869216,
+ 0.0011644884776125172,
+ -0.05170201896733027,
+ -0.06500683147099663,
+ -0.002032228622196166,
+ 0.0005104131195647621,
+ 0.06427388343826383,
+ 0.04725239836343652,
+ -0.01529760067788053,
+ 0.03663163021911147,
+ -0.04143857996850098,
+ -0.041019407055166865,
+ -0.03863891565495734,
+ -0.04229588426869208,
+ -0.0729102993471632,
+ -0.033572063560751805,
+ -0.05149346461547897,
+ -0.005753218358743144,
+ -0.0028314338440196853,
+ -0.050506649936834055,
+ -0.04388589399336761,
+ -0.022702186982412553,
+ -0.040793437260267046,
+ -0.0070652821976595614,
+ 0.07004192563517486,
+ -0.010178744858492636,
+ -0.03993124845587448,
+ 0.06951359663415056,
+ 0.08941104316065179,
+ -0.027397943204625227,
+ -0.040230294249783305,
+ 0.029514084635860954,
+ 0.0028531615451487533,
+ -0.0464115287558972,
+ -0.003951276700204259,
+ -0.004032294091485318,
+ -0.04256950666865785,
+ -0.03323628684807554,
+ -0.07660374125676815,
+ 0.08862334481233934,
+ 0.08405677601332023,
+ -0.02999806097152329,
+ 0.005085240428089591,
+ -0.052330559807971896,
+ -0.021572308051443556,
+ 0.009787177331477846,
+ -0.04161631057284951,
+ -0.004830509676031602,
+ 0.0051645909399113165,
+ 0.054778888820944546,
+ -0.0016513067746098602,
+ 0.07013162881142856,
+ 0.07609363964133396,
+ -0.011350416670077967,
+ -0.002387555531126056,
+ -0.06169893575816659,
+ -0.03896187279419636,
+ -0.01534604944442809,
+ 0.032620935630707296,
+ -0.018091888374898852,
+ 0.03406004375439377,
+ 0.04130900734749639,
+ 0.07832176156023328,
+ 0.04250757902769903,
+ 0.06176895951345832,
+ -0.05002261146958561,
+ 0.012001464809312412,
+ -0.027678360943272848,
+ -0.04165174085129545,
+ 0.03906885865974049,
+ 0.06545052953410002,
+ 0.07368543011925997,
+ 0.03310918611125322,
+ 0.02006464257300284,
+ 0.027550385983291542,
+ 0.019531772490917483,
+ -0.002485091247951132,
+ 0.016374866616169874,
+ 0.043862311796729704,
+ -0.00040018349030142286,
+ -0.03769154897188579,
+ -0.048126922553643064,
+ -0.03643251181358872,
+ -0.006470190751129062,
+ -0.04911173816890494,
+ 0.029230850771744046,
+ -0.03752490612322689,
+ -0.0056869804123184705,
+ -0.03928403363551332,
+ -0.029999479973019803,
+ -0.07259002364459839,
+ -0.07848887421962725,
+ -0.07861394120437759,
+ -0.0617081431385601,
+ -0.03750944828806203,
+ -0.03182052186652391,
+ -0.0046994387442603515
],
"yaxis": "y"
},
@@ -2145,8 +2145,8 @@
0
],
"y": [
- -0.04518524299380677,
- 0.06988282126146635
+ -0.04518524299380705,
+ 0.06988282126134315
]
},
{
@@ -2163,8 +2163,8 @@
1
],
"y": [
- -0.026031660118498542,
- 0.10471629771223552
+ -0.02603166011849872,
+ 0.10471629771224936
]
},
{
@@ -2181,8 +2181,8 @@
2
],
"y": [
- -0.056055505792245136,
- 0.07587211346742362
+ -0.056055505792245386,
+ 0.07587211346751242
]
},
{
@@ -2199,8 +2199,8 @@
3
],
"y": [
- -0.013884103418189884,
- 0.11355448959636372
+ -0.013884103418189957,
+ 0.1135544895963472
]
},
{
@@ -2217,8 +2217,8 @@
4
],
"y": [
- -0.06437114670847809,
- 0.05217792701456639
+ -0.06437114670847834,
+ 0.052177927014512196
]
},
{
@@ -2235,8 +2235,8 @@
5
],
"y": [
- -0.05494656629446064,
- 0.04972303381680831
+ -0.054946566294460944,
+ 0.04972303381658967
]
},
{
@@ -2253,8 +2253,8 @@
6
],
"y": [
- -0.04213141363241548,
- 0.12503735559420137
+ -0.0421314136324156,
+ 0.1250373555945116
]
},
{
@@ -2271,8 +2271,8 @@
7
],
"y": [
- -0.058228464390390464,
- 0.06797437342122079
+ -0.058228464390390644,
+ 0.06797437342121643
]
},
{
@@ -2289,8 +2289,8 @@
8
],
"y": [
- -0.09897658305825777,
- -0.013934604796718727
+ -0.09897658305825816,
+ -0.013934604797056514
]
},
{
@@ -2307,8 +2307,8 @@
9
],
"y": [
- -0.07467755232189739,
- 0.0421526910144827
+ -0.07467755232189764,
+ 0.04215269101439463
]
},
{
@@ -2325,8 +2325,8 @@
0
],
"y": [
- -0.016404090820769195,
- 0.027024786872119577
+ -0.016404090820768717,
+ 0.027024786872119726
]
},
{
@@ -2343,8 +2343,8 @@
1
],
"y": [
- 0.005940290336302163,
- 0.0516328311145064
+ 0.0059402903363031355,
+ 0.05163283111450674
]
},
{
@@ -2361,8 +2361,8 @@
2
],
"y": [
- -0.01579395905900039,
- 0.028110108139132506
+ -0.015793959058881177,
+ 0.028110108139132686
]
},
{
@@ -2379,8 +2379,8 @@
3
],
"y": [
- 0.01727669309268168,
- 0.06151832455135668
+ 0.017276693092682677,
+ 0.06151832455135696
]
},
{
@@ -2397,8 +2397,8 @@
4
],
"y": [
- -0.02794852279852104,
- 0.01390715919249955
+ -0.027948522798402216,
+ 0.013907159192499487
]
},
{
@@ -2415,8 +2415,8 @@
5
],
"y": [
- -0.028654420329541134,
- 0.013545783272544578
+ -0.02865442032954095,
+ 0.013545783272544566
]
},
{
@@ -2433,8 +2433,8 @@
6
],
"y": [
- -0.0016305808732168438,
- 0.05838285026547863
+ -0.0016305808730966331,
+ 0.058382850265478604
]
},
{
@@ -2451,8 +2451,8 @@
7
],
"y": [
- -0.022615213244675268,
- 0.02461760308444715
+ -0.02261521324455614,
+ 0.024617603084446948
]
},
{
@@ -2469,8 +2469,8 @@
8
],
"y": [
- -0.0696366188833811,
- -0.031738932686820105
+ -0.06963661888326324,
+ -0.031738932686820605
]
},
{
@@ -2487,8 +2487,8 @@
9
],
"y": [
- -0.04111041471939729,
- 0.00530402936890708
+ -0.041110414719278496,
+ 0.005304029368906754
]
},
{
@@ -2545,16 +2545,16 @@
9
],
"y": [
- 0.005433071406361662,
- 0.029120585569587135,
- 0.008247326099856494,
- 0.04159065178640773,
- -0.005647923851165088,
- -0.008201999229279923,
- 0.007374388603541528,
- -0.006076490925649943,
- -0.05107311829794033,
- -0.02620387273488733
+ 0.005433071406361607,
+ 0.029120585569587003,
+ 0.008247326099856492,
+ 0.0415906517864076,
+ -0.005647923851165009,
+ -0.008201999229279925,
+ 0.00737438860354152,
+ -0.006076490925649854,
+ -0.05107311829794009,
+ -0.02620387273488719
]
},
{
@@ -3572,7 +3572,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "##Â Assesing the method's performance\n",
+ "## Assesing the method's performance\n",
"\n",
"To check whether our model is overfitting and ensure robust results we can perform cross-validation and visualize the results."
]
@@ -3586,10 +3586,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
- "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 2.4s remaining: 3.6s\n",
- "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 2.4s remaining: 1.6s\n",
- "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 2.6s finished\n"
+ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 4.8s remaining: 7.2s\n",
+ "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 4.9s remaining: 3.3s\n",
+ "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 5.1s finished\n"
]
},
{
@@ -3646,13 +3652,13 @@
" 0.003702 | \n",
" 0.006428 | \n",
" 0.008920 | \n",
- " 0.010948 | \n",
+ " 0.010950 | \n",
" 0.012559 | \n",
" ... | \n",
" 0.014497 | \n",
" 0.013223 | \n",
" 0.011248 | \n",
- " 0.008737 | \n",
+ " 0.008736 | \n",
" 0.005352 | \n",
" \n",
" \n",
@@ -3663,7 +3669,7 @@
"text/plain": [
" 0.05 0.10 0.15 0.20 0.25 ... 0.75 0.80 0.85 0.90 0.95\n",
"train 0.003559 0.006266 0.008625 0.010612 0.012261 ... 0.014217 0.012836 0.010951 0.008441 0.005035\n",
- "test 0.003702 0.006428 0.008920 0.010948 0.012559 ... 0.014497 0.013223 0.011248 0.008737 0.005352\n",
+ "test 0.003702 0.006428 0.008920 0.010950 0.012559 ... 0.014497 0.013223 0.011248 0.008736 0.005352\n",
"\n",
"[2 rows x 19 columns]"
]
@@ -3727,25 +3733,25 @@
"0.95"
],
"y": [
- 0.003558561337780733,
- 0.006265855621881848,
- 0.008625419002211416,
- 0.010611926929967935,
- 0.012260635572630749,
- 0.013608913676947828,
- 0.014687816023102568,
- 0.015480342389821197,
- 0.016035153061010856,
- 0.01636737549657335,
+ 0.003558561337796822,
+ 0.0062658556218819425,
+ 0.008625419002211315,
+ 0.010611927646580647,
+ 0.012260635572631316,
+ 0.013608913676947837,
+ 0.014687816023102571,
+ 0.0154803423898212,
+ 0.016035153061010853,
+ 0.01636737614305893,
0.016484091061958146,
- 0.01634202361231802,
- 0.015917613063373793,
- 0.015215424354852338,
- 0.014217213287838603,
- 0.012835893477185223,
- 0.010950667994380126,
- 0.008441312015956522,
- 0.005034747517506452
+ 0.01634202381594042,
+ 0.015917613624349118,
+ 0.015215424354852342,
+ 0.01421721328783859,
+ 0.012835893477184435,
+ 0.010950667994380015,
+ 0.008441313864545713,
+ 0.005034747517506451
]
},
{
@@ -3776,25 +3782,25 @@
"0.95"
],
"y": [
- 0.0037022360192326595,
- 0.006427987072928047,
- 0.008919835958633724,
- 0.010948272176203791,
- 0.012558890052304993,
- 0.013976140969748594,
- 0.015110463655576833,
- 0.01584773610600515,
- 0.01628869930290136,
- 0.01663373310348861,
- 0.016767153425770797,
- 0.01662945912605408,
- 0.01621382760598824,
- 0.015504261943136149,
- 0.01449700415107979,
- 0.013222510063111408,
- 0.011247535781780584,
- 0.00873703551498941,
- 0.005351881561834753
+ 0.0037022360178084295,
+ 0.006427987072919422,
+ 0.008919835958642538,
+ 0.010950435483101617,
+ 0.01255889005228547,
+ 0.013976140969748727,
+ 0.015110463655577283,
+ 0.015847736106005268,
+ 0.01628869930290131,
+ 0.016634252566334336,
+ 0.016767153425770804,
+ 0.01662916798202697,
+ 0.016215370930187933,
+ 0.015504261943137768,
+ 0.014497004150961904,
+ 0.013222510062559815,
+ 0.011247535781760675,
+ 0.00873620925188529,
+ 0.005351881561834736
]
}
],
@@ -4672,7 +4678,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "pe",
+ "display_name": "pe3.13",
"language": "python",
"name": "python3"
},
@@ -4686,7 +4692,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.13.0"
}
},
"nbformat": 4,
diff --git a/docs/myst.yml b/docs/myst.yml
new file mode 100644
index 0000000..c0e500a
--- /dev/null
+++ b/docs/myst.yml
@@ -0,0 +1,64 @@
+version: 1
+project:
+ title: MicroImpute documentation
+ authors:
+ - name: PolicyEngine
+ copyright: "2025"
+ github: https://github.com/policyengine/microimpute
+ repository:
+ url: https://github.com/policyengine/microimpute
+ branch: main
+ path: docs
+ toc:
+ - file: index
+ - title: Models
+ children:
+ - file: models/imputer/index
+ children:
+ - file: models/imputer/implement-new-model
+ - file: models/matching/index
+ children:
+ - file: models/matching/matching-imputation
+ - file: models/ols/index
+ children:
+ - file: models/ols/ols-imputation
+ - file: models/qrf/index
+ children:
+ - file: models/qrf/qrf-imputation
+ - file: models/quantreg/index
+ children:
+ - file: models/quantreg/quantreg-imputation
+ - title: Imputation and benchmarking
+ children:
+ - file: imputation-benchmarking/index
+ children:
+ - file: imputation-benchmarking/benchmarking-methods
+ - file: imputation-benchmarking/imputing-across-surveys
+ - title: AutoImpute
+ children:
+ - file: autoimpute/index
+ children:
+ - file: autoimpute/autoimpute
+ - title: SCF to CPS example
+ children:
+ - file: examples/scf_to_cps/index
+ children:
+ - file: examples/scf_to_cps/imputing-from-scf-to-cps
+site:
+ options:
+ logo: logo.png
+ template: book-theme
+ extensions:
+ - sphinx.ext.autodoc
+ - sphinx.ext.viewcode
+ - sphinx.ext.napoleon
+ - sphinx.ext.mathjax
+ config:
+ html_js_files:
+ - https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js
+ html_theme: furo
+ pygments_style: default
+ html_css_files:
+ - style.css
+ execute:
+ execute_notebooks: force
diff --git a/microimpute/__init__.py b/microimpute/__init__.py
index cc13e0d..238a90e 100644
--- a/microimpute/__init__.py
+++ b/microimpute/__init__.py
@@ -3,10 +3,9 @@
A package for benchmarking different imputation methods using microdata.
"""
-__version__ = "0.1.5"
+__version__ = "1.1.2"
-# Import data handling functions
-from microimpute.comparisons.data import prepare_scf_data, preprocess_data
+from microimpute.comparisons.autoimpute import autoimpute
from microimpute.comparisons.imputations import get_imputations
# Import comparison utilities
@@ -30,6 +29,9 @@
# Import main models and utilities
from microimpute.models import OLS, QRF, Imputer, ImputerResults, QuantReg
+# Import data handling functions
+from microimpute.utils.data import preprocess_data
+
try:
from microimpute.models.matching import Matching
except ImportError:
diff --git a/microimpute/comparisons/__init__.py b/microimpute/comparisons/__init__.py
index a069abb..958c806 100644
--- a/microimpute/comparisons/__init__.py
+++ b/microimpute/comparisons/__init__.py
@@ -6,9 +6,6 @@
# Import automated imputation utilities
from .autoimpute import autoimpute
-# Import data handling functions
-from .data import prepare_scf_data, preprocess_data, scf_url
-
# Import imputation utilities
from .imputations import get_imputations
diff --git a/microimpute/comparisons/autoimpute.py b/microimpute/comparisons/autoimpute.py
index d6d5364..3e7cc74 100644
--- a/microimpute/comparisons/autoimpute.py
+++ b/microimpute/comparisons/autoimpute.py
@@ -14,7 +14,6 @@
from tqdm.auto import tqdm
from microimpute.comparisons import *
-from microimpute.comparisons.data import preprocess_data
from microimpute.config import (
QUANTILES,
RANDOM_STATE,
@@ -23,6 +22,7 @@
)
from microimpute.evaluations import cross_validate_model
from microimpute.models import *
+from microimpute.utils.data import preprocess_data
log = logging.getLogger(__name__)
diff --git a/microimpute/comparisons/data.py b/microimpute/comparisons/data.py
deleted file mode 100644
index 6d5d0bb..0000000
--- a/microimpute/comparisons/data.py
+++ /dev/null
@@ -1,418 +0,0 @@
-"""Data preparation utilities for imputation benchmarking.
-
-This module provides functions for acquiring, preprocessing, and splitting data for imputation benchmarking.
-It includes utilities for downloading Survey of Consumer Finances
-(SCF) data, normalizing features, and creating train-test splits with consistent parameters.
-"""
-
-import io
-import logging
-import zipfile
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import pandas as pd
-import requests
-from pydantic import validate_call
-from sklearn.model_selection import train_test_split
-from tqdm import tqdm
-
-from microimpute.config import (
- RANDOM_STATE,
- TEST_SIZE,
- TRAIN_SIZE,
- VALID_YEARS,
- VALIDATE_CONFIG,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@validate_call(config=VALIDATE_CONFIG)
-def scf_url(year: int) -> str:
- """Return the URL of the SCF summary microdata zip file for a year.
-
- Args:
- year: Year of SCF summary microdata to retrieve.
-
- Returns:
- URL of summary microdata zip file for the given year.
-
- Raises:
- ValueError: If the year is not in VALID_YEARS.
- """
- logger.debug(f"Generating SCF URL for year {year}")
-
- if year not in VALID_YEARS:
- logger.error(
- f"Invalid SCF year: {year}. Valid years are {VALID_YEARS}"
- )
- raise ValueError(
- f"The SCF is not available for {year}. Valid years are {VALID_YEARS}"
- )
-
- url = f"https://www.federalreserve.gov/econres/files/scfp{year}s.zip"
- logger.debug(f"Generated URL: {url}")
- return url
-
-
-@validate_call(config=VALIDATE_CONFIG)
-def _load(
- years: Optional[Union[int, List[int]]] = None,
- columns: Optional[List[str]] = None,
-) -> pd.DataFrame:
- """Load Survey of Consumer Finances data for specified years and columns.
-
- Args:
- years: Year or list of years to load data for.
- columns: List of column names to load.
-
- Returns:
- DataFrame containing the requested data.
-
- Raises:
- ValueError: If no Stata files are found in the downloaded zip
- or invalid parameters
- RuntimeError: If there's a network error or a problem processing
- the downloaded data
- """
-
- logger.info(f"Loading SCF data with years={years}")
-
- try:
- # Identify years for download
- if years is None:
- years = VALID_YEARS
- logger.warning(f"Using default years: {years}")
-
- if isinstance(years, int):
- years = [years]
-
- # Validate all years are valid
- invalid_years = [year for year in years if year not in VALID_YEARS]
- if invalid_years:
- logger.error(f"Invalid years specified: {invalid_years}")
- raise ValueError(
- f"Invalid years: {invalid_years}. Valid years are {VALID_YEARS}"
- )
-
- all_data: List[pd.DataFrame] = []
-
- for year in tqdm(years):
- logger.info(f"Processing data for year {year}")
- try:
- # Download zip file
- logger.debug(f"Downloading SCF data for year {year}")
- url = scf_url(year)
- try:
- response = requests.get(url, timeout=60)
- response.raise_for_status() # Raise an error for bad responses
- except requests.exceptions.RequestException as e:
- logger.error(
- f"Network error downloading SCF data for year {year}: {str(e)}"
- )
- raise RuntimeError(
- f"Failed to download SCF data for year {year}"
- ) from e
-
- # Process zip file
- try:
- logger.debug("Creating zipfile from downloaded content")
- z = zipfile.ZipFile(io.BytesIO(response.content))
-
- # Find the .dta file in the zip
- dta_files: List[str] = [
- f for f in z.namelist() if f.endswith(".dta")
- ]
- if not dta_files:
- logger.error(
- f"No Stata files found in zip for year {year}"
- )
- raise ValueError(
- f"No Stata files found in zip for year {year}"
- )
-
- logger.debug(f"Found Stata files: {dta_files}")
-
- # Read the Stata file
- try:
- logger.debug(f"Reading Stata file: {dta_files[0]}")
- with z.open(dta_files[0]) as f:
- df = pd.read_stata(
- io.BytesIO(f.read()), columns=columns
- )
- logger.debug(
- f"Read DataFrame with shape {df.shape}"
- )
-
- # Ensure 'wgt' is included
- if (
- columns is not None
- and "wgt" not in df.columns
- and "wgt" not in columns
- ):
- logger.debug("Re-reading with 'wgt' column added")
- # Re-read to include weights
- with z.open(dta_files[0]) as f:
- cols_with_weight: List[str] = list(
- set(columns) | {"wgt"}
- )
- df = pd.read_stata(
- io.BytesIO(f.read()),
- columns=cols_with_weight,
- )
- logger.debug(
- f"Re-read DataFrame with shape {df.shape}"
- )
- except Exception as e:
- logger.error(
- f"Error reading Stata file for year {year}: {str(e)}"
- )
- raise RuntimeError(
- f"Failed to process Stata file for year {year}"
- ) from e
-
- except zipfile.BadZipFile as e:
- logger.error(f"Bad zip file for year {year}: {str(e)}")
- raise RuntimeError(
- f"Downloaded zip file is corrupt for year {year}"
- ) from e
-
- # Add year column
- df["year"] = year
- logger.info(
- f"Successfully processed data for year {year}, shape: {df.shape}"
- )
- all_data.append(df)
-
- except Exception as e:
- logger.error(f"Error processing year {year}: {str(e)}")
- raise
-
- # Combine all years
- logger.debug(f"Combining data from {len(all_data)} years")
- if len(all_data) > 1:
- result = pd.concat(all_data)
- logger.info(
- f"Combined data from {len(years)} years, final shape: {result.shape}"
- )
- return result
- else:
- logger.info(
- f"Returning data for single year, shape: {all_data[0].shape}"
- )
- return all_data[0]
-
- except Exception as e:
- logger.error(f"Error in _load: {str(e)}")
- raise
-
-
-@validate_call(config=VALIDATE_CONFIG)
-def prepare_scf_data(
- full_data: bool = False, years: Optional[Union[int, List[int]]] = None
-) -> Union[
- Tuple[pd.DataFrame, List[str], List[str], dict], # when full_data=True
- Tuple[
- pd.DataFrame, pd.DataFrame, List[str], List[str], dict
- ], # when full_data=False
-]:
- """Preprocess the Survey of Consumer Finances data for model training and testing.
-
- Args:
- full_data: Whether to return the complete dataset without splitting.
- years: Year or list of years to load data for.
-
- Returns:
- Different tuple formats depending on the value of full_data:
- - If full_data=True: (data, predictor_columns, imputed_columns, dummy_info)
- - If full_data=False: (train_data, test_data,
- predictor_columns, imputed_columns, dummy_info)
-
- Where dummy_info is a dictionary with information about dummy variables created from string columns.
-
- Raises:
- ValueError: If required columns are missing from the data
- RuntimeError: If data processing fails
- """
- logger.info(
- f"Preparing SCF data with full_data={full_data}, years={years}"
- )
-
- try:
- # Load the raw data
- logger.debug("Loading SCF data")
- data = _load(years=years)
-
- # Define columns needed for analysis
- # predictors shared with cps data
- PREDICTORS: List[str] = [
- "hhsex", # sex of head of household
- "age", # age of respondent
- "married", # marital status of respondent
- # "kids", # number of children in household
- "race", # race of respondent
- "income", # total annual income of household
- "wageinc", # income from wages and salaries
- "bussefarminc", # income from business, self-employment or farm
- "intdivinc", # income from interest and dividends
- "ssretinc", # income from social security and retirement accounts
- "lf", # labor force status
- ]
-
- IMPUTED_VARIABLES: List[str] = ["networth"]
-
- # Validate that all required columns exist in the data
- missing_columns = [
- col
- for col in PREDICTORS + IMPUTED_VARIABLES
- if col not in data.columns
- ]
- if missing_columns:
- logger.error(
- f"Required columns missing from SCF data: {missing_columns}"
- )
- raise ValueError(
- f"Required columns missing from SCF data: {missing_columns}"
- )
-
- logger.debug(
- f"Selecting {len(PREDICTORS)} predictors and {len(IMPUTED_VARIABLES)} imputation variables"
- )
- data = data[PREDICTORS + IMPUTED_VARIABLES]
- logger.debug(f"Data shape after column selection: {data.shape}")
-
- if full_data:
- logger.info("Processing full dataset without splitting")
- data = preprocess_data(data, full_data=True)
- logger.info(
- f"Returning full processed dataset with shape {data.shape}"
- )
- return data, PREDICTORS, IMPUTED_VARIABLES
- else:
- logger.info("Splitting data into train and test sets")
- X_train, X_test = preprocess_data(data)
- logger.info(
- f"Train set shape: {X_train.shape}, Test set shape: {X_test.shape}"
- )
- return X_train, X_test, PREDICTORS, IMPUTED_VARIABLES
-
- except Exception as e:
- logger.error(f"Error in prepare_scf_data: {str(e)}")
- raise RuntimeError(f"Failed to prepare SCF data: {str(e)}") from e
-
-
-@validate_call(config=VALIDATE_CONFIG)
-def preprocess_data(
- data: pd.DataFrame,
- full_data: Optional[bool] = False,
- train_size: Optional[float] = TRAIN_SIZE,
- test_size: Optional[float] = TEST_SIZE,
- random_state: Optional[int] = RANDOM_STATE,
- normalize: Optional[bool] = False,
-) -> Union[
- Tuple[pd.DataFrame, dict], # when full_data=True
- Tuple[pd.DataFrame, pd.DataFrame, dict], # when full_data=False
-]:
- """Preprocess the data for model training and testing.
-
- Args:
- data: DataFrame containing the data to preprocess.
- full_data: Whether to return the complete dataset without splitting.
- train_size: Proportion of the dataset to include in the train split.
- test_size: Proportion of the dataset to include in the test split.
- random_state: Random seed for reproducibility.
- normalize: Whether to normalize the data.
-
- Returns:
- Different tuple formats depending on the value of full_data:
- - If full_data=True: (data, dummy_info)
- - If full_data=False: (X_train, X_test, dummy_info)
-
- Where dummy_info is a dictionary mapping original columns to their resulting dummy columns
-
- Raises:
- ValueError: If data is empty or invalid
- RuntimeError: If data preprocessing fails
- """
-
- logger.debug(
- f"Preprocessing data with shape {data.shape}, full_data={full_data}"
- )
-
- if data.empty:
- raise ValueError("Data must not be None or empty")
- # Check for missing values
- missing_count = data.isna().sum().sum()
- if missing_count > 0:
- logger.warning(f"Data contains {missing_count} missing values")
-
- if normalize:
- logger.debug("Normalizing data")
- try:
- mean = data.mean(axis=0)
- std = data.std(axis=0)
-
- # Check for constant columns (std=0)
- constant_cols = std[std == 0].index.tolist()
- if constant_cols:
- logger.warning(
- f"Found constant columns (std=0): {constant_cols}"
- )
- # Handle constant columns by setting std to 1 to avoid division by zero
- for col in constant_cols:
- std[col] = 1
-
- # Apply normalization
- data = (data - mean) / std
- logger.debug("Data normalized successfully")
-
- # Store normalization parameters
- normalization_params = {
- col: {"mean": mean[col], "std": std[col]}
- for col in data.columns
- }
-
- logger.debug(f"Normalization parameters: {normalization_params}")
-
- except Exception as e:
- logger.error(f"Error during data normalization: {str(e)}")
- raise RuntimeError("Failed to normalize data") from e
-
- if full_data and normalize:
- logger.info("Returning full preprocessed dataset")
- return (
- data,
- normalization_params,
- )
- elif full_data:
- logger.info("Returning full preprocessed dataset")
- return data
- else:
- logger.debug(
- f"Splitting data with train_size={train_size}, test_size={test_size}"
- )
- try:
- X_train, X_test = train_test_split(
- data,
- test_size=test_size,
- train_size=train_size,
- random_state=random_state,
- )
- logger.info(
- f"Data split into train ({X_train.shape}) and test ({X_test.shape}) sets"
- )
- if normalize:
- return (
- X_train,
- X_test,
- normalization_params,
- )
- else:
- return (
- X_train,
- X_test,
- )
-
- except Exception as e:
- logger.error(f"Error in processing data: {str(e)}")
- raise
diff --git a/microimpute/models/qrf.py b/microimpute/models/qrf.py
index 7b18752..631faa8 100644
--- a/microimpute/models/qrf.py
+++ b/microimpute/models/qrf.py
@@ -179,9 +179,7 @@ def _predict(
for i, variable in enumerate(self.imputed_variables):
var_start_time = time.time()
- if (
- not quantiles
- ): # Only log per-variable when not processing multiple quantiles
+ if not quantiles:
self.logger.info(
f"[{i+1}/{len(self.imputed_variables)}] Predicting for '{variable}'"
)
@@ -224,11 +222,12 @@ def _predict(
f" ✓ {variable} predicted in {var_time:.2f}s ({len(imputed_values)} samples)"
)
+ self.logger.info(
+ f"QRF predictions completed for {variable} imputed variable"
+ )
+
imputations[q] = imputed_df
- self.logger.info(
- f"QRF predictions completed for {len(X_test)} samples"
- )
return imputations
except Exception as e:
@@ -377,9 +376,6 @@ def _fit(
self.logger.info(
f" Features: {len(current_predictors)} predictors"
)
- self.logger.info(
- f" Training data shape: {X_train[current_predictors + [variable]].shape}"
- )
self.logger.info(
f" Memory usage: {self._get_memory_usage_info()}"
)
@@ -508,15 +504,11 @@ def _fit(
self.logger.info(
f" Features: {len(current_predictors)} predictors"
)
- self.logger.info(
- f" Training data shape: {X_train[current_predictors + [variable]].shape}"
- )
self.logger.info(
f" Memory usage: {self._get_memory_usage_info()}"
)
# Create and fit model
- # Note: X_train is already preprocessed by base class
model = _QRFModel(seed=self.seed, logger=self.logger)
try:
@@ -610,9 +602,6 @@ def _fit_variable_batch(
self.logger.info(
f" Features: {len(current_predictors)} predictors"
)
- self.logger.info(
- f" Training data shape: {X_train[current_predictors + [variable]].shape}"
- )
self.logger.info(
f" Memory usage: {self._get_memory_usage_info()}"
)
diff --git a/microimpute/utils/__init__.py b/microimpute/utils/__init__.py
new file mode 100644
index 0000000..2f694c7
--- /dev/null
+++ b/microimpute/utils/__init__.py
@@ -0,0 +1,7 @@
+"""
+This module contains utilities that support microimpute processes.
+"""
+
+from .data import preprocess_data
+from .logging_utils import configure_logging
+from .statmatch_hotdeck import nnd_hotdeck_using_rpy2
diff --git a/microimpute/utils/data.py b/microimpute/utils/data.py
new file mode 100644
index 0000000..08694e0
--- /dev/null
+++ b/microimpute/utils/data.py
@@ -0,0 +1,139 @@
+"""Data preparation utilities for imputation benchmarking.
+
+This module provides functions for acquiring, preprocessing, and splitting data for imputation benchmarking.
+It includes utilities for downloading Survey of Consumer Finances
+(SCF) data, normalizing features, and creating train-test splits with consistent parameters.
+"""
+
+import logging
+from typing import Optional, Tuple, Union
+
+import pandas as pd
+from pydantic import validate_call
+from sklearn.model_selection import train_test_split
+
+from microimpute.config import (
+ RANDOM_STATE,
+ TEST_SIZE,
+ TRAIN_SIZE,
+ VALIDATE_CONFIG,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@validate_call(config=VALIDATE_CONFIG)
+def preprocess_data(
+ data: pd.DataFrame,
+ full_data: Optional[bool] = False,
+ train_size: Optional[float] = TRAIN_SIZE,
+ test_size: Optional[float] = TEST_SIZE,
+ random_state: Optional[int] = RANDOM_STATE,
+ normalize: Optional[bool] = False,
+) -> Union[
+ Tuple[pd.DataFrame, dict], # when full_data=True
+ Tuple[pd.DataFrame, pd.DataFrame, dict], # when full_data=False
+]:
+ """Preprocess the data for model training and testing.
+
+ Args:
+ data: DataFrame containing the data to preprocess.
+ full_data: Whether to return the complete dataset without splitting.
+ train_size: Proportion of the dataset to include in the train split.
+ test_size: Proportion of the dataset to include in the test split.
+ random_state: Random seed for reproducibility.
+ normalize: Whether to normalize the data.
+
+ Returns:
+ Different tuple formats depending on the value of full_data:
+ - If full_data=True: (data, dummy_info)
+ - If full_data=False: (X_train, X_test, dummy_info)
+
+ Where dummy_info is a dictionary mapping original columns to their resulting dummy columns
+
+ Raises:
+ ValueError: If data is empty or invalid
+ RuntimeError: If data preprocessing fails
+ """
+
+ logger.debug(
+ f"Preprocessing data with shape {data.shape}, full_data={full_data}"
+ )
+
+ if data.empty:
+ raise ValueError("Data must not be None or empty")
+ # Check for missing values
+ missing_count = data.isna().sum().sum()
+ if missing_count > 0:
+ logger.warning(f"Data contains {missing_count} missing values")
+
+ if normalize:
+ logger.debug("Normalizing data")
+ try:
+ mean = data.mean(axis=0)
+ std = data.std(axis=0)
+
+ # Check for constant columns (std=0)
+ constant_cols = std[std == 0].index.tolist()
+ if constant_cols:
+ logger.warning(
+ f"Found constant columns (std=0): {constant_cols}"
+ )
+ # Handle constant columns by setting std to 1 to avoid division by zero
+ for col in constant_cols:
+ std[col] = 1
+
+ # Apply normalization
+ data = (data - mean) / std
+ logger.debug("Data normalized successfully")
+
+ # Store normalization parameters
+ normalization_params = {
+ col: {"mean": mean[col], "std": std[col]}
+ for col in data.columns
+ }
+
+ logger.debug(f"Normalization parameters: {normalization_params}")
+
+ except Exception as e:
+ logger.error(f"Error during data normalization: {str(e)}")
+ raise RuntimeError("Failed to normalize data") from e
+
+ if full_data and normalize:
+ logger.info("Returning full preprocessed dataset")
+ return (
+ data,
+ normalization_params,
+ )
+ elif full_data:
+ logger.info("Returning full preprocessed dataset")
+ return data
+ else:
+ logger.debug(
+ f"Splitting data with train_size={train_size}, test_size={test_size}"
+ )
+ try:
+ X_train, X_test = train_test_split(
+ data,
+ test_size=test_size,
+ train_size=train_size,
+ random_state=random_state,
+ )
+ logger.info(
+ f"Data split into train ({X_train.shape}) and test ({X_test.shape}) sets"
+ )
+ if normalize:
+ return (
+ X_train,
+ X_test,
+ normalization_params,
+ )
+ else:
+ return (
+ X_train,
+ X_test,
+ )
+
+ except Exception as e:
+ logger.error(f"Error in processing data: {str(e)}")
+ raise
diff --git a/myst.yml b/myst.yml
new file mode 100644
index 0000000..160145e
--- /dev/null
+++ b/myst.yml
@@ -0,0 +1,64 @@
+version: 1
+project:
+ title: MicroImpute documentation
+ authors:
+ - name: PolicyEngine
+ copyright: "2025"
+ github: https://github.com/policyengine/microimpute
+ repository:
+ url: https://github.com/policyengine/microimpute
+ branch: main
+ path: docs
+ toc:
+ - file: docs/index
+ - title: Models
+ children:
+ - file: docs/models/imputer/index
+ children:
+ - file: docs/models/imputer/implement-new-model
+ - file: docs/models/matching/index
+ children:
+ - file: docs/models/matching/matching-imputation
+ - file: docs/models/ols/index
+ children:
+ - file: docs/models/ols/ols-imputation
+ - file: docs/models/qrf/index
+ children:
+ - file: docs/models/qrf/qrf-imputation
+ - file: docs/models/quantreg/index
+ children:
+ - file: docs/models/quantreg/quantreg-imputation
+ - title: Imputation and benchmarking
+ children:
+ - file: docs/imputation-benchmarking/index
+ children:
+ - file: docs/imputation-benchmarking/benchmarking-methods
+ - file: docs/imputation-benchmarking/imputing-across-surveys
+ - title: AutoImpute
+ children:
+ - file: docs/autoimpute/index
+ children:
+ - file: docs/autoimpute/autoimpute
+ - title: SCF to CPS example
+ children:
+ - file: docs/examples/scf_to_cps/index
+ children:
+ - file: docs/examples/scf_to_cps/imputing-from-scf-to-cps
+site:
+ options:
+ logo: docs/logo.png
+ template: book-theme
+ extensions:
+ - sphinx.ext.autodoc
+ - sphinx.ext.viewcode
+ - sphinx.ext.napoleon
+ - sphinx.ext.mathjax
+ config:
+ html_js_files:
+ - https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js
+ html_theme: furo
+ pygments_style: default
+ html_css_files:
+ - _static/style.css
+ execute:
+ execute_notebooks: force
diff --git a/paper/imputing-from-scf-to-cps.ipynb b/paper/imputing-from-scf-to-cps.ipynb
index fc723f4..614f755 100644
--- a/paper/imputing-from-scf-to-cps.ipynb
+++ b/paper/imputing-from-scf-to-cps.ipynb
@@ -16,7 +16,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "1c672d6a",
"metadata": {},
"outputs": [
@@ -58,6 +58,7 @@
"from microimpute.comparisons import *\n",
"from microimpute.visualizations import *\n",
"from microimpute.evaluations import *\n",
+ "from microimpute.utils.data import preprocess_data\n",
"\n",
"logger = logging.getLogger(__name__)"
]
diff --git a/pyproject.toml b/pyproject.toml
index 62eb60f..84349f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ dependencies = [
"scipy>=1.16.0,<1.17.0",
"requests>=2.32.0,<3.0.0",
"tqdm>=4.65.0,<5.0.0",
- "statsmodels>=0.14.0,<0.16.0",
+ "statsmodels>=0.14.5,<0.16.0",
"quantile-forest>=1.4.0,<1.5.0",
"pydantic>=2.8.0,<3.0.0",
"optuna>=4.3.0,<5.0.0",
@@ -48,11 +48,11 @@ matching = [
]
docs = [
- "jupyter-book>=2.0.0b2", # JupyterBook 2.0 (beta)
+ "jupyter-book",
"furo>=2024.0.0", # Sphinx theme for documentation
- "ipywidgets>=8.0.0,<9.0.0", # For notebook interactivity
- "plotly>=5.24.0,<6.0.0", # For visualization in notebooks
- "h5py>=3.1.0,<4.0.0", # For data file support
+ "ipywidgets>=8.0.0,<9.0.0",
+ "plotly>=5.24.0,<6.0.0",
+ "h5py>=3.1.0,<4.0.0",
]
images = [
diff --git a/tests/test_models/test_imputers.py b/tests/test_models/test_imputers.py
index 9e80c27..3b4a550 100644
--- a/tests/test_models/test_imputers.py
+++ b/tests/test_models/test_imputers.py
@@ -12,7 +12,7 @@
import pytest
from sklearn.datasets import load_diabetes
-from microimpute.comparisons.data import preprocess_data
+from microimpute.utils.data import preprocess_data
from microimpute.config import QUANTILES
from microimpute.models import *
diff --git a/tests/test_models/test_matching.py b/tests/test_models/test_matching.py
index f6847b0..0a79e89 100644
--- a/tests/test_models/test_matching.py
+++ b/tests/test_models/test_matching.py
@@ -7,7 +7,7 @@
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
-from microimpute.comparisons.data import preprocess_data
+from microimpute.utils.data import preprocess_data
from microimpute.config import QUANTILES
from microimpute.evaluations import *
diff --git a/tests/test_models/test_ols.py b/tests/test_models/test_ols.py
index 048a2b8..e787c64 100644
--- a/tests/test_models/test_ols.py
+++ b/tests/test_models/test_ols.py
@@ -6,7 +6,7 @@
import pandas as pd
from sklearn.datasets import load_diabetes
-from microimpute.comparisons.data import preprocess_data
+from microimpute.utils.data import preprocess_data
from microimpute.config import QUANTILES
from microimpute.evaluations import *
from microimpute.models.ols import OLS
diff --git a/tests/test_models/test_qrf.py b/tests/test_models/test_qrf.py
index 078380f..17ac05a 100644
--- a/tests/test_models/test_qrf.py
+++ b/tests/test_models/test_qrf.py
@@ -7,7 +7,7 @@
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
-from microimpute.comparisons.data import preprocess_data
+from microimpute.utils.data import preprocess_data
from microimpute.config import QUANTILES
from microimpute.evaluations import *
from microimpute.models.qrf import QRF
diff --git a/tests/test_models/test_quantreg.py b/tests/test_models/test_quantreg.py
index 9a2b3c1..2241e3a 100644
--- a/tests/test_models/test_quantreg.py
+++ b/tests/test_models/test_quantreg.py
@@ -5,7 +5,7 @@
import pandas as pd
from sklearn.datasets import load_diabetes
-from microimpute.comparisons.data import preprocess_data
+from microimpute.utils.data import preprocess_data
from microimpute.config import QUANTILES, RANDOM_STATE
from microimpute.evaluations import *
from microimpute.models.quantreg import QuantReg
diff --git a/tests/test_quantile_comparison.py b/tests/test_quantile_comparison.py
index a87f98b..6d759d4 100644
--- a/tests/test_quantile_comparison.py
+++ b/tests/test_quantile_comparison.py
@@ -10,14 +10,18 @@
from typing import List, Type
+import io
import pandas as pd
+import requests
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
+import zipfile
from microimpute.comparisons import *
-from microimpute.config import RANDOM_STATE
+from microimpute.config import RANDOM_STATE, VALID_YEARS
from microimpute.models import *
from microimpute.visualizations.plotting import *
+from microimpute.utils.data import preprocess_data
def test_quantile_comparison_diabetes() -> None:
@@ -64,8 +68,26 @@ def test_quantile_comparison_diabetes() -> None:
def test_quantile_comparison_scf() -> None:
"""Test the end-to-end quantile loss comparison workflow on the scf data set."""
- X_train, X_test, PREDICTORS, IMPUTED_VARIABLES = prepare_scf_data(
- full_data=False, years=2019
+ scf_data = load_scf(2022)
+ PREDICTORS: List[str] = [
+ "hhsex", # sex of head of household
+ "age", # age of respondent
+ "married", # marital status of respondent
+ # "kids", # number of children in household
+ "race", # race of respondent
+ "income", # total annual income of household
+ "wageinc", # income from wages and salaries
+ "bussefarminc", # income from business, self-employment or farm
+ "intdivinc", # income from interest and dividends
+ "ssretinc", # income from social security and retirement accounts
+ "lf", # labor force status
+ ]
+ IMPUTED_VARIABLES: List[str] = ["networth"]
+
+ X_train, X_test = preprocess_data(
+ data=scf_data,
+ full_data=False,
+ normalize=False,
)
# Shrink down the data by sampling
@@ -97,3 +119,122 @@ def test_quantile_comparison_scf() -> None:
show_mean=True,
save_path="scf_model_comparison.jpg",
)
+
+
+@validate_call(config=VALIDATE_CONFIG)
+def load_scf(
+ years: Optional[Union[int, List[int]]] = None,
+ columns: Optional[List[str]] = None,
+) -> pd.DataFrame:
+ """Load Survey of Consumer Finances data for specified years and columns.
+
+ Args:
+ years: Year or list of years to load data for.
+ columns: List of column names to load.
+
+ Returns:
+ DataFrame containing the requested data.
+
+ Raises:
+ ValueError: If no Stata files are found in the downloaded zip
+ or invalid parameters
+ RuntimeError: If there's a network error or a problem processing
+ the downloaded data
+ """
+
+ def scf_url(year: int) -> str:
+ """Return the URL of the SCF summary microdata zip file for a year."""
+ logger.debug(f"Generating SCF URL for year {year}")
+
+ if year not in VALID_YEARS:
+ logger.error(
+ f"Invalid SCF year: {year}. Valid years are {VALID_YEARS}"
+ )
+ raise
+
+ url = f"https://www.federalreserve.gov/econres/files/scfp{year}s.zip"
+ logger.debug(f"Generated URL: {url}")
+ return url
+
+ logger.info(f"Loading SCF data with years={years}")
+
+ try:
+ # Identify years for download
+ if years is None:
+ years = VALID_YEARS
+ logger.warning(f"Using default years: {years}")
+
+ if isinstance(years, int):
+ years = [years]
+
+ all_data: List[pd.DataFrame] = []
+
+ for year in years:
+ logger.info(f"Processing data for year {year}")
+ try:
+ # Download zip file
+ logger.debug(f"Downloading SCF data for year {year}")
+ url = scf_url(year)
+ try:
+ response = requests.get(url, timeout=60)
+ response.raise_for_status() # Raise an error for bad responses
+ except requests.exceptions.RequestException as e:
+ logger.error(
+ f"Network error downloading SCF data for year {year}: {str(e)}"
+ )
+ raise
+
+ # Process zip file
+ z = zipfile.ZipFile(io.BytesIO(response.content))
+ # Find the .dta file in the zip
+ dta_files: List[str] = [
+ f for f in z.namelist() if f.endswith(".dta")
+ ]
+ if not dta_files:
+ logger.error(
+ f"No Stata files found in zip for year {year}"
+ )
+ raise
+
+ # Read the Stata file
+ try:
+ logger.debug(f"Reading Stata file: {dta_files[0]}")
+ with z.open(dta_files[0]) as f:
+ df = pd.read_stata(
+ io.BytesIO(f.read()), columns=columns
+ )
+ logger.debug(f"Read DataFrame with shape {df.shape}")
+ except Exception as e:
+ logger.error(
+ f"Error reading Stata file for year {year}: {str(e)}"
+ )
+ raise
+
+ # Add year column
+ df["year"] = year
+ logger.info(
+ f"Successfully processed data for year {year}, shape: {df.shape}"
+ )
+ all_data.append(df)
+
+ except Exception as e:
+ logger.error(f"Error processing year {year}: {str(e)}")
+ raise
+
+ # Combine all years
+ logger.debug(f"Combining data from {len(all_data)} years")
+ if len(all_data) > 1:
+ result = pd.concat(all_data)
+ logger.info(
+ f"Combined data from {len(years)} years, final shape: {result.shape}"
+ )
+ return result
+ else:
+ logger.info(
+ f"Returning data for single year, shape: {all_data[0].shape}"
+ )
+ return all_data[0]
+
+ except Exception as e:
+ logger.error(f"Error in _load: {str(e)}")
+ raise