diff --git a/ml/testing-debugging/testing-debugging-classification.ipynb b/ml/testing-debugging/testing-debugging-classification.ipynb index c4faefc..bb94e4c 100644 --- a/ml/testing-debugging/testing-debugging-classification.ipynb +++ b/ml/testing-debugging/testing-debugging-classification.ipynb @@ -979,7 +979,8 @@ }, "source": [ "from sklearn.metrics import classification_report\n", - "mnistPred = model.predict_classes(x = mnistData)\n", + "mnistPred = model.predict(x = mnistData)\n", + "mnistPred_classes=np.argmax(mnistPred,axis=1)\n" "print(classification_report(mnistLabels, mnistPred))" ], "execution_count": 0, @@ -1047,7 +1048,8 @@ " \n", " def testStd(self):\n", " y = model.predict(mnistData)\n", - " yStd = np.std(y)\n", + " yClasses = np.argmax(y,axis=1)\n" + " yStd = np.std(yClasses)\n", " yStdActual = np.std(mnistLabels)\n", " deltaStd = 0.05\n", " errorMsg = 'Std. dev. of predicted values ' + str(yStd) + \\\n", @@ -1057,7 +1059,8 @@ "\n", " def testMean(self):\n", " y = model.predict(mnistData)\n", - " yMean = np.mean(y)\n", + " yClasses = np.argmax(y,axis=1)\n" + " yMean = np.mean(yClasses)\n", " yMeanActual = np.mean(mnistLabels)\n", " deltaMean = 0.05\n", " errorMsg = 'Mean of predicted values ' + str(yMean) + \\\n",