diff --git a/docs/examples/basic_ranking.ipynb b/docs/examples/basic_ranking.ipynb index f37750dd..c43df29a 100644 --- a/docs/examples/basic_ranking.ipynb +++ b/docs/examples/basic_ranking.ipynb @@ -39,20 +39,20 @@ "source": [ "# Recommending movies: ranking\n", "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/recommenders/examples/basic_ranking\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/recommenders/blob/main/docs/examples/basic_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/recommenders/blob/main/docs/examples/basic_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/recommenders/docs/examples/basic_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", - " \u003c/td\u003e\n", - "\u003c/table\u003e\n", + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
\n", "\n" ] }, @@ -360,11 +360,11 @@ " metrics=[tf.keras.metrics.RootMeanSquaredError()]\n", " )\n", "\n", - " def call(self, features: Dict[str, tf.Tensor]) -\u003e tf.Tensor:\n", + " def call(self, features: Dict[str, tf.Tensor]) -> tf.Tensor:\n", " return self.ranking_model(\n", " (features[\"user_id\"], features[\"movie_title\"]))\n", "\n", - " def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -\u003e tf.Tensor:\n", + " def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:\n", " labels = features.pop(\"user_rating\")\n", " \n", " rating_predictions = self(features)\n", @@ -395,7 +395,7 @@ "outputs": [], "source": [ "model = MovielensModel()\n", - "model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))" + "callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)" ] }, { @@ -436,7 +436,8 @@ }, "outputs": [], "source": [ - "model.fit(cached_train, epochs=3)" + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))\n", + "model.fit(cached_train, epochs=100, validation_data=cached_test, callbacks=[callback])" ] }, { @@ -477,6 +478,49 @@ "The lower the RMSE metric, the more accurate our model is at predicting ratings." ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "x0p5KVl-3w1w" + }, + "source": [ + "### Test predictions distribution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e_GB-NMw3syS" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "sns.set_style('darkgrid')\n", + "\n", + "res_data = pd.DataFrame()\n", + "\n", + "# get predictions from cached test \n", + "res_data['predictions'] = model.predict(cached_test)[:, 0]\n", + "\n", + "# get user rating from test dataset\n", + "test_labels = []\n", + "for r in cached_test:\n", + " test_labels.append((r['user_rating']).numpy())\n", + "\n", + "res_data['test_labels'] = np.concatenate(test_labels)\n", + "\n", + "# plot everythin as kde\n", + "plt.figure(figsize=(10,6), dpi=100)\n", + "#sns.kdeplot(data=res_data, fill=True, bw_adjust=0.9, alpha=0.6, linewidth=0, legend=False)\n", + "sns.histplot(data=res_data, legend=False, alpha=0.8, bins=20)\n", + "plt.legend([\"Predictions\", \"Test Labels\"][::-1], title=\"Legend\", fontsize=12, title_fontsize=16)\n", + "plt.title('Predictions vs. test labels', fontsize=20);" + ] + }, { "cell_type": "markdown", "metadata": { @@ -635,14 +679,13 @@ ], "metadata": { "colab": { - "collapsed_sections": [], "name": "basic_ranking.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.9.12 ('base')", "language": "python", "name": "python3" }, @@ -656,7 +699,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.9.12" + }, + "vscode": { + "interpreter": { + "hash": "4f523f7c76dd18e7ed336217f32f6f704c23c323644912475b9d3570cf04b060" + } } }, "nbformat": 4,