Skip to content

Commit 2e7e35c

Browse files
authored
fix one class svm (#98)
1 parent f742231 commit 2e7e35c

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

sqlflow_models/one_class_svm.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,38 @@
2020

2121
MODEL_PATH = "one_class_svm_model"
2222

23+
ENABLE_EAGER_EXECUTION = False
24+
25+
try:
26+
tf.enable_eager_execution()
27+
ENABLE_EAGER_EXECUTION = True
28+
except Exception:
29+
try:
30+
tf.compat.v1.enable_eager_execution()
31+
ENABLE_EAGER_EXECUTION = True
32+
except Exception:
33+
ENABLE_EAGER_EXECUTION = False
34+
35+
if ENABLE_EAGER_EXECUTION:
36+
print('eager execution mode is enabled')
37+
else:
38+
print('eager execution mode is disabled')
39+
40+
41+
def dataset_reader(dataset):
42+
if ENABLE_EAGER_EXECUTION:
43+
for features in dataset:
44+
yield features
45+
else:
46+
iter = dataset.make_one_shot_iterator()
47+
one_element = iter.get_next()
48+
with tf.Session() as sess:
49+
try:
50+
while True:
51+
yield sess.run(one_element)
52+
except tf.errors.OutOfRangeError:
53+
pass
54+
2355

2456
class OneClassSVM(tf.keras.Model):
2557
def __init__(self,
@@ -52,13 +84,15 @@ def __init__(self,
5284
def concat_features(self, features):
5385
assert isinstance(features, dict)
5486
each_feature = []
55-
for _, v in features.items():
56-
each_feature.append(v.numpy())
87+
for k, v in features.items():
88+
if ENABLE_EAGER_EXECUTION:
89+
v = v.numpy()
90+
each_feature.append(v)
5791
return np.concatenate(each_feature, axis=1)
5892

5993
def sqlflow_train_loop(self, dataset):
6094
X = []
61-
for features in dataset:
95+
for features in dataset_reader(dataset):
6296
X.append(self.concat_features(features))
6397
X = np.concatenate(X)
6498

tests/test_one_class_svm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import tensorflow as tf
2121
from sqlflow_models import OneClassSVM
22+
from sqlflow_models.one_class_svm import dataset_reader
2223

2324

2425
class TestOneClassSVM(unittest.TestCase):
@@ -51,7 +52,7 @@ def test_main(self):
5152
svm.sqlflow_train_loop(train_dataset)
5253

5354
predict_dataset = self.create_dataset()
54-
for features in predict_dataset:
55+
for features in dataset_reader(predict_dataset):
5556
pred = svm.sqlflow_predict_one(features)
5657
pred = np.array(pred)
5758
self.assertEqual(pred.shape, (1, 1))

0 commit comments

Comments
 (0)