diff --git a/docs/examples/basic_ranking.ipynb b/docs/examples/basic_ranking.ipynb
index f37750dd..c43df29a 100644
--- a/docs/examples/basic_ranking.ipynb
+++ b/docs/examples/basic_ranking.ipynb
@@ -39,20 +39,20 @@
"source": [
"# Recommending movies: ranking\n",
"\n",
- "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/recommenders/examples/basic_ranking\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/recommenders/blob/main/docs/examples/basic_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/recommenders/blob/main/docs/examples/basic_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/recommenders/docs/examples/basic_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- "\u003c/table\u003e\n",
+ "
\n",
"\n"
]
},
@@ -360,11 +360,11 @@
" metrics=[tf.keras.metrics.RootMeanSquaredError()]\n",
" )\n",
"\n",
- " def call(self, features: Dict[str, tf.Tensor]) -\u003e tf.Tensor:\n",
+ " def call(self, features: Dict[str, tf.Tensor]) -> tf.Tensor:\n",
" return self.ranking_model(\n",
" (features[\"user_id\"], features[\"movie_title\"]))\n",
"\n",
- " def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -\u003e tf.Tensor:\n",
+ " def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:\n",
" labels = features.pop(\"user_rating\")\n",
" \n",
" rating_predictions = self(features)\n",
@@ -395,7 +395,7 @@
"outputs": [],
"source": [
"model = MovielensModel()\n",
- "model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))"
+ "callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)"
]
},
{
@@ -436,7 +436,8 @@
},
"outputs": [],
"source": [
- "model.fit(cached_train, epochs=3)"
+ "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))\n",
+ "model.fit(cached_train, epochs=100, validation_data=cached_test, callbacks=[callback])"
]
},
{
@@ -477,6 +478,49 @@
"The lower the RMSE metric, the more accurate our model is at predicting ratings."
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "x0p5KVl-3w1w"
+ },
+ "source": [
+ "### Test predictions distribution"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e_GB-NMw3syS"
+ },
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "\n",
+ "sns.set_style('darkgrid')\n",
+ "\n",
+ "res_data = pd.DataFrame()\n",
+ "\n",
+ "# get predictions from cached test \n",
+ "res_data['predictions'] = model.predict(cached_test)[:, 0]\n",
+ "\n",
+ "# get user rating from test dataset\n",
+ "test_labels = []\n",
+ "for r in cached_test:\n",
+ " test_labels.append((r['user_rating']).numpy())\n",
+ "\n",
+ "res_data['test_labels'] = np.concatenate(test_labels)\n",
+ "\n",
+ "# plot everythin as kde\n",
+ "plt.figure(figsize=(10,6), dpi=100)\n",
+ "#sns.kdeplot(data=res_data, fill=True, bw_adjust=0.9, alpha=0.6, linewidth=0, legend=False)\n",
+ "sns.histplot(data=res_data, legend=False, alpha=0.8, bins=20)\n",
+ "plt.legend([\"Predictions\", \"Test Labels\"][::-1], title=\"Legend\", fontsize=12, title_fontsize=16)\n",
+ "plt.title('Predictions vs. test labels', fontsize=20);"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -635,14 +679,13 @@
],
"metadata": {
"colab": {
- "collapsed_sections": [],
"name": "basic_ranking.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3.9.12 ('base')",
"language": "python",
"name": "python3"
},
@@ -656,7 +699,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.3"
+ "version": "3.9.12"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "4f523f7c76dd18e7ed336217f32f6f704c23c323644912475b9d3570cf04b060"
+ }
}
},
"nbformat": 4,