Skip to content

Commit f742231

Browse files
authored
Add OneClassSVM model (#97)
* add one_class_svm * fix shape check * expand params
1 parent 36a2f21 commit f742231

File tree

3 files changed

+135
-0
lines changed

3 files changed

+135
-0
lines changed

sqlflow_models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .native_keras import RawDNNClassifier
1212
from .custom_model_example import CustomClassifier
1313
from .gcn import GCN
14+
from .one_class_svm import OneClassSVM
1415
try:
1516
# NOTE: statsmodels have version conflict on PAI
1617
from .arima_with_stl_decomposition import ARIMAWithSTLDecomposition

sqlflow_models/one_class_svm.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import os
15+
import pickle
16+
17+
import numpy as np
18+
import tensorflow as tf
19+
from sklearn.svm import OneClassSVM as SklearnOneClassSVM
20+
21+
# Relative path of the pickled scikit-learn estimator. OneClassSVM loads the
# estimator from this file at construction time when the file exists, and
# overwrites the file after every training run.
MODEL_PATH = "one_class_svm_model"
22+
23+
24+
class OneClassSVM(tf.keras.Model):
    """Keras-model-shaped wrapper around scikit-learn's ``OneClassSVM``.

    The wrapped estimator is stored in ``self.svm``. Training and prediction
    are driven through ``sqlflow_train_loop`` and ``sqlflow_predict_one``
    rather than through the usual Keras ``fit``/``predict`` methods.

    The fitted estimator is persisted with ``pickle`` to ``MODEL_PATH``
    (a path relative to the current working directory).
    """

    def __init__(self,
                 feature_columns=None,
                 kernel='rbf',
                 degree=3,
                 gamma='scale',
                 coef0=0.0,
                 tol=0.001,
                 nu=0.5,
                 shrinking=True,
                 cache_size=200,
                 verbose=False,
                 max_iter=-1):
        """Load a previously saved estimator, or build a fresh one.

        Args:
            feature_columns: Accepted for interface compatibility; not used
                by this class.
            kernel, degree, gamma, coef0, tol, nu, shrinking, cache_size,
            verbose, max_iter: Forwarded unchanged to scikit-learn's
                ``OneClassSVM`` constructor. They are ignored when a saved
                estimator already exists at ``MODEL_PATH``.
        """
        # Keras requires a subclassed Model to initialise its base class
        # before any attribute is assigned on the instance.
        super(OneClassSVM, self).__init__()

        if os.path.exists(MODEL_PATH):
            # NOTE(security): unpickling executes arbitrary code embedded in
            # the file. MODEL_PATH must only ever point at a file written by
            # sqlflow_train_loop in a trusted working directory.
            with open(MODEL_PATH, "rb") as f:
                self.svm = pickle.load(f)
        else:
            self.svm = SklearnOneClassSVM(kernel=kernel,
                                          degree=degree,
                                          gamma=gamma,
                                          coef0=coef0,
                                          tol=tol,
                                          nu=nu,
                                          shrinking=shrinking,
                                          cache_size=cache_size,
                                          verbose=verbose,
                                          max_iter=max_iter)

    def concat_features(self, features):
        """Join every feature tensor of one batch into a single NumPy array.

        Args:
            features: dict whose values expose ``.numpy()`` (eager tensors).

        Returns:
            The per-feature arrays concatenated along axis 1, in the dict's
            iteration order.
        """
        assert isinstance(features, dict)
        # NOTE(review): column order follows dict iteration order, so training
        # and prediction must supply features in the same order - confirm
        # against the caller.
        return np.concatenate([v.numpy() for v in features.values()], axis=1)

    def sqlflow_train_loop(self, dataset):
        """Fit the estimator on the whole dataset and pickle it to MODEL_PATH.

        Args:
            dataset: iterable yielding feature dicts accepted by
                ``concat_features``.

        Raises:
            ValueError: if ``dataset`` yields no batches.
        """
        batches = [self.concat_features(features) for features in dataset]
        if not batches:
            # np.concatenate would raise ValueError on an empty list anyway;
            # raise the same type with a message that names the real cause.
            raise ValueError("cannot train OneClassSVM on an empty dataset")
        X = np.concatenate(batches)

        self.svm.fit(X)
        # protocol=2 is kept as in the original so the saved file format
        # does not change.
        with open(MODEL_PATH, "wb") as f:
            pickle.dump(self.svm, f, protocol=2)

    def sqlflow_predict_one(self, features):
        """Predict for a single batch of features.

        Args:
            features: feature dict accepted by ``concat_features``.

        Returns:
            A one-element list holding the array returned by the wrapped
            estimator's ``predict``.
        """
        features = self.concat_features(features)
        pred = self.svm.predict(features)
        return [pred]

tests/test_one_class_svm.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import os
15+
import shutil
16+
import tempfile
17+
import unittest
18+
19+
import numpy as np
20+
import tensorflow as tf
21+
from sqlflow_models import OneClassSVM
22+
23+
24+
class TestOneClassSVM(unittest.TestCase):
    """Smoke test: train OneClassSVM on random data, then predict per batch."""

    def setUp(self):
        # Run each test from a throwaway directory; the previous working
        # directory is remembered so tearDown can restore it.
        scratch = tempfile.mkdtemp()
        self.tmp_dir = scratch
        self.old_cwd = os.getcwd()
        os.chdir(scratch)

    def tearDown(self):
        # Leave the scratch directory before deleting it.
        os.chdir(self.old_cwd)
        shutil.rmtree(self.tmp_dir)

    def create_dataset(self):
        """Build a dataset of 10 elements, each a dict of two [1, 1] features."""
        def sample_pairs():
            for _ in range(10):
                yield (np.random.random(size=[1, 1]),
                       np.random.random(size=[1, 1]))

        pairs = tf.data.Dataset.from_generator(
            sample_pairs,
            output_types=(tf.dtypes.float32, tf.dtypes.float32))
        # Re-key the (first, second) tuples as the named-feature dict the
        # model consumes.
        return pairs.map(lambda first, second: {"x1": first, "x2": second})

    def test_main(self):
        model = OneClassSVM()
        model.sqlflow_train_loop(self.create_dataset())

        for batch in self.create_dataset():
            outcome = np.array(model.sqlflow_predict_one(batch))
            self.assertEqual(outcome.shape, (1, 1))
            label = outcome[0][0]
            # The test accepts exactly two labels: 1 and -1.
            self.assertTrue(label == 1 or label == -1)


if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)