From 55d6d8b6b4373cba3eaba572bb4ec7bcffccd15f Mon Sep 17 00:00:00 2001 From: chemistatgoogle <128421022+chemistatgoogle@users.noreply.github.com> Date: Mon, 20 Mar 2023 12:44:02 -0700 Subject: [PATCH] Update for changes in TensorFlow 2.6 tf.keras.Sequential.predict and tf.keras.Sequential.predict_classes has been updated. --- .../testing-debugging-classification.ipynb | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ml/testing-debugging/testing-debugging-classification.ipynb b/ml/testing-debugging/testing-debugging-classification.ipynb index c4faefc..bb94e4c 100644 --- a/ml/testing-debugging/testing-debugging-classification.ipynb +++ b/ml/testing-debugging/testing-debugging-classification.ipynb @@ -979,7 +979,8 @@ }, "source": [ "from sklearn.metrics import classification_report\n", - "mnistPred = model.predict_classes(x = mnistData)\n", + "mnistPred = model.predict(x = mnistData)\n", + "mnistPred_classes=np.argmax(mnistPred,axis=1)\n" "print(classification_report(mnistLabels, mnistPred))" ], "execution_count": 0, @@ -1047,7 +1048,8 @@ " \n", " def testStd(self):\n", " y = model.predict(mnistData)\n", - " yStd = np.std(y)\n", + " yClasses = np.argmax(y,axis=1)\n" + " yStd = np.std(yClasses)\n", " yStdActual = np.std(mnistLabels)\n", " deltaStd = 0.05\n", " errorMsg = 'Std. dev. of predicted values ' + str(yStd) + \\\n", @@ -1057,7 +1059,8 @@ "\n", " def testMean(self):\n", " y = model.predict(mnistData)\n", - " yMean = np.mean(y)\n", + " yClasses = np.argmax(y,axis=1)\n" + " yMean = np.mean(yClasses)\n", " yMeanActual = np.mean(mnistLabels)\n", " deltaMean = 0.05\n", " errorMsg = 'Mean of predicted values ' + str(yMean) + \\\n",