Commit 2edacf9

Merge pull request #15 from rxavier/main
Bump tf and add plotting
2 parents 8d2b407 + 8ed136d commit 2edacf9

12 files changed: +261 -45 lines changed

README.md

Lines changed: 18 additions & 2 deletions
@@ -2,7 +2,9 @@
 
 ## Overview
 
-Embedding Encoder is a scikit-learn-compliant transformer that converts categorical variables to numeric vector representations. This is achieved by creating a small multilayer perceptron architecture in which each categorical variable is passed through an embedding layer, for which weights are extracted and turned into DataFrame columns.
+Embedding Encoder is a scikit-learn-compliant transformer that converts categorical variables into numeric vector representations. This is achieved by creating a small multilayer perceptron architecture in which each categorical variable is passed through an embedding layer, for which weights are extracted and turned into DataFrame columns.
+
+While the idea is not new (it was popularized after [the team that landed in the 3rd place of the Rossmann Kaggle competition used it](https://www.kaggle.com/c/rossmann-store-sales/discussion/17974)), and although Python implementations have surfaced over the years, we are not aware of any library that integrates this functionality into scikit-learn.
 
 ## Installation and dependencies
 
@@ -88,7 +90,7 @@ from sklearn.preprocessing import StandardScaler
 from sklearn.impute import SimpleImputer
 
 from embedding_encoder import EmbeddingEncoder
-from embedding_encoder.compose import ColumnTransformerWithNames
+from embedding_encoder.utils import ColumnTransformerWithNames
 
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
 
@@ -106,6 +108,20 @@ pipe.fit(X_train, y_train)
 
 Like scikit transformers, Embedding Encoder also has a `inverse_transform` method that recomposes the original input.
 
+## Plotting embeddings
+
+The idea behind embeddings is that categories that are conceptually similar should have similar vector representations. For example, "December" and "January" should be close to each other when the target variable is ice cream sales.
+
+This can be analyzed with the `plot_embeddings` function.
+
+```python
+from embedding_encoder import EmbeddingEncoder
+
+ee = EmbeddingEncoder(task="classification")
+ee.fit(X=X, y=y)
+plot_embeddings(ee, variable="", )
+```
+
 ## Advanced usage
 
 Embedding Encoder gives some control over the neural network. In particular, its constructor allows setting how deep and large the network should be (by modifying `layers_units`), as well as the dropout rate between dense layers. Epochs and batch size can also be modified.
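
The new README section says that conceptually similar categories should end up with similar vectors. One way to check that numerically, alongside the plot, is cosine similarity between embedding rows. The sketch below is only an illustration: the month vectors are made-up numbers rather than output from a fitted `EmbeddingEncoder`, and `cosine_similarity` is a hypothetical helper, not part of the library.

```python
import math

# Made-up 3-dimensional embeddings for a hypothetical "month" variable.
# These values are illustrative only; they do not come from a fitted encoder.
toy_embeddings = {
    "December": [0.9, 0.1, -0.3],
    "January": [0.8, 0.2, -0.2],
    "July": [-0.7, 0.5, 0.6],
}


def cosine_similarity(a, b):
    """Return the cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Similar categories should score close to 1, dissimilar ones lower.
winter = cosine_similarity(toy_embeddings["December"], toy_embeddings["January"])
summer = cosine_similarity(toy_embeddings["December"], toy_embeddings["July"])
print(f"December vs January: {winter:.3f}")
print(f"December vs July: {summer:.3f}")
```

With real embeddings, the same comparison could be run on the rows that `plot_embeddings` projects to 2D, which avoids reading distances off a lossy projection.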

docs/source/api.rst

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ EmbeddingEncoder class
 
 Utilities
 ---------------
-.. automodule:: embedding_encoder.compose
+.. automodule:: embedding_encoder.utils.compose
     :members:
     :undoc-members:
     :show-inheritance:

embedding_encoder/core.py

Lines changed: 59 additions & 1 deletion
@@ -7,7 +7,7 @@
 import numpy as np
 from sklearn.base import BaseEstimator, TransformerMixin
 
-from embedding_encoder.custom_ordinal import OrdinalEncoderStart1
+from embedding_encoder.utils import OrdinalEncoderStart1
 
 
 class EmbeddingEncoder(BaseEstimator, TransformerMixin):
@@ -465,3 +465,61 @@ def get_feature_names_out(self, input_features=None):
 
     def get_feature_names(self, input_features=None):
         return self._columns_out
+
+    def plot_embeddings(self, variable: str, model: str = "pca"):
+        """Plot embeddings for a variable by passing a fitted EmbeddingEncoder and reducing to 2D.
+
+        Parameters
+        ----------
+        variable :
+            Variable to plot. Please note that scikit-learn's Pipeline might strip column names.
+        model : str, optional
+            Dimensionality reduction model. Either "tsne" or "pca". Default "pca".
+
+        Returns
+        -------
+        matplotlib.axes._subplots.AxesSubplot
+            Seaborn scatterplot (Matplotlib axes)
+
+        Raises
+        ------
+        ValueError
+            If selected variable has less than 3 unique values.
+        ValueError
+            If selected model is not "tsne" or "pca".
+        ImportError
+            If seaborn is not installed.
+        """
+        if self._embeddings_mapping[variable].shape[0] < 3:
+            raise ValueError("Nothing to plot when variable has less than 3 unique values.")
+        dimensions = 2
+        if model not in ["tsne", "pca"]:
+            raise ValueError("model must be either 'tsne' or 'pca'.")
+        try:
+            import seaborn as sns
+            sns.set(rc={"figure.figsize": (8, 6), "figure.dpi": 100})
+        except ImportError:
+            raise ImportError("Plotting requires seaborn.")
+        if model == "tsne":
+            from sklearn.manifold import TSNE
+
+            model = TSNE(init="pca", n_components=dimensions, learning_rate="auto")
+        else:
+            from sklearn.decomposition import PCA
+
+            model = PCA(n_components=dimensions)
+
+        embeddings = self._embeddings_mapping[variable]
+        variable_position = self._categorical_vars.index(variable)
+        original_classes = self._ordinal_encoder.categories_[variable_position]
+        original_index = ["OOV"] + list(original_classes)
+
+        reduced = model.fit_transform(embeddings)
+        reduced = pd.DataFrame(
+            reduced,
+            index=original_index,
+            columns=[f"Component {i}" for i in range(dimensions)],
+        ).rename_axis("Classes").reset_index()
+        plot = sns.scatterplot(data=reduced, x="Component 0", y="Component 1", hue="Classes", s=100)
+        plot.set_title(f"{model.__class__.__name__} embeddings projection for variable '{variable}'")
+        return plot
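
The reduction step inside `plot_embeddings` can be tried on its own, without a fitted encoder or seaborn. The sketch below mirrors that step on a random matrix. The month labels and the random values are placeholders, not output of the library. Treating row 0 as the out-of-vocabulary slot follows the `["OOV"] + list(original_classes)` line in the diff.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA

# Placeholder data: four hypothetical categories with 3-dimensional embeddings,
# plus one extra leading row standing in for the out-of-vocabulary (OOV) slot.
rng = np.random.default_rng(0)
classes = ["Jan", "Feb", "Mar", "Apr"]
embeddings = rng.normal(size=(len(classes) + 1, 3))
labels = ["OOV"] + classes

# Project the embedding matrix down to two components, as the method does.
dimensions = 2
reduced = PCA(n_components=dimensions).fit_transform(embeddings)

# Build the long-form frame that the scatterplot call consumes:
# one row per class, with the class label moved out of the index into a column.
reduced = (
    pd.DataFrame(
        reduced,
        index=labels,
        columns=[f"Component {i}" for i in range(dimensions)],
    )
    .rename_axis("Classes")
    .reset_index()
)
print(reduced)
```

The resulting frame has one `Classes` column and two component columns, which is the shape the `sns.scatterplot(..., hue="Classes")` call in the diff expects.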

embedding_encoder/examples/titanic.ipynb

Lines changed: 51 additions & 11 deletions
Large diffs are not rendered by default.
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from embedding_encoder.utils.plot import plot_embeddings
+from embedding_encoder.utils.compose import ColumnTransformerWithNames
+from embedding_encoder.utils.custom_ordinal import OrdinalEncoderStart1
+
+__all__ = ["plot_embeddings", "ColumnTransformerWithNames", "OrdinalEncoderStart1"]

embedding_encoder/utils/plot.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import pandas as pd
2+
3+
4+
def plot_embeddings(embedding_encoder, variable: str, model: str = "pca"):
5+
"""Plot embeddings for a variable by passing a fitted EmbeddingEncoder and reducing to 2D.
6+
7+
Parameters
8+
----------
9+
embedding_encoder : EmbeddingEncoder
10+
Fitted transformer.
11+
variable :
12+
Variable to plot. Please note that scikit-learn's Pipeline might strip column names.
13+
model : str, optional
14+
Dimensionality reduction model. Either "tsne" or "pca". Default "pca".
15+
16+
Returns
17+
-------
18+
matplotlib.axes._subplots.AxesSubplot
19+
Seaborn scatterplot (Matplotlib axes)
20+
21+
Raises
22+
------
23+
ValueError
24+
If selected variable has less than 3 unique values.
25+
ValueError
26+
If selected model is not "tsne" or "pca".
27+
ImportError
28+
If seaborn is not installed.
29+
"""
30+
if embedding_encoder._embeddings_mapping[variable].shape[0] < 3:
31+
raise ValueError("Nothing to plot when variable has less than 3 unique values.")
32+
dimensions = 2
33+
if model not in ["tsne", "pca"]:
34+
raise ValueError("model must be either 'tsne' or 'pca'.")
35+
try:
36+
import seaborn as sns
37+
sns.set(rc={"figure.figsize": (8, 6), "figure.dpi": 100})
38+
sns.set_palette("viridis")
39+
except ImportError:
40+
raise ImportError("Plotting requires seaborn.")
41+
if model == "tsne":
42+
from sklearn.manifold import TSNE
43+
44+
model = TSNE(init="pca", n_components=dimensions, learning_rate="auto")
45+
else:
46+
from sklearn.decomposition import PCA
47+
48+
model = PCA(n_components=dimensions)
49+
50+
embeddings = embedding_encoder._embeddings_mapping[variable]
51+
variable_position = embedding_encoder._categorical_vars.index(variable)
52+
original_classes = embedding_encoder._ordinal_encoder.categories_[variable_position]
53+
original_index = ["OOV"] + list(original_classes)
54+
55+
reduced = model.fit_transform(embeddings)
56+
reduced = pd.DataFrame(
57+
reduced,
58+
index=original_index,
59+
columns=[f"Component {i}" for i in range(dimensions)],
60+
).rename_axis("Classes").reset_index()
61+
plot = sns.scatterplot(data=reduced, x="Component 0", y="Component 1", hue="Classes", s=100)
62+
plot.set_title(f"{model.__class__.__name__} embeddings projection for variable '{variable}'")
63+
return plot

requirements-dev.txt

Lines changed: 10 additions & 7 deletions
@@ -28,7 +28,7 @@ certifi==2021.10.8
     # via
     #   -c requirements.txt
     #   requests
-charset-normalizer==2.0.10
+charset-normalizer==2.0.12
     # via
     #   -c requirements.txt
     #   requests
@@ -78,7 +78,7 @@ idna==3.3
     #   requests
 imagesize==1.3.0
     # via sphinx
-importlib-metadata==4.10.1
+importlib-metadata==4.11.1
     # via
     #   -c requirements.txt
     #   click
@@ -110,6 +110,7 @@ natsort==8.0.2
     # via domdf-python-tools
 packaging==21.3
     # via
+    #   -c requirements.txt
     #   deprecation-alias
     #   pytest
     #   sphinx
@@ -126,8 +127,10 @@ pygments==2.10.0
     #   sphinx
     #   sphinx-prompt
     #   sphinx-tabs
-pyparsing==3.0.6
-    # via packaging
+pyparsing==3.0.7
+    # via
+    #   -c requirements.txt
+    #   packaging
 pytest==6.2.5
     # via -r requirements-dev.in
 pytz==2021.3
@@ -142,9 +145,9 @@ requests==2.27.1
     #   apeye
     #   cachecontrol
     #   sphinx
-ruamel.yaml==0.17.20
+ruamel-yaml==0.17.21
     # via sphinx-toolbox
-ruamel.yaml.clib==0.2.6
+ruamel-yaml-clib==0.2.6
     # via ruamel.yaml
 six==1.16.0
     # via
@@ -201,7 +204,7 @@ tornado==6.1
     # via livereload
 typed-ast==1.5.1
     # via black
-typing-extensions==4.0.1
+typing-extensions==4.1.1
     # via
     #   -c requirements.txt
     #   black

requirements.in

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+scikit-learn
+pandas
+tensorflow>=2.8.0
+seaborn
