|
20 | 20 |
|
21 | 21 | MODEL_PATH = "one_class_svm_model" |
22 | 22 |
|
| 23 | +ENABLE_EAGER_EXECUTION = False |
| 24 | + |
| 25 | +try: |
| 26 | + tf.enable_eager_execution() |
| 27 | + ENABLE_EAGER_EXECUTION = True |
| 28 | +except Exception: |
| 29 | + try: |
| 30 | + tf.compat.v1.enable_eager_execution() |
| 31 | + ENABLE_EAGER_EXECUTION = True |
| 32 | + except Exception: |
| 33 | + ENABLE_EAGER_EXECUTION = False |
| 34 | + |
| 35 | +if ENABLE_EAGER_EXECUTION: |
| 36 | + print('eager execution mode is enabled') |
| 37 | +else: |
| 38 | + print('eager execution mode is disabled') |
| 39 | + |
| 40 | + |
| 41 | +def dataset_reader(dataset): |
| 42 | + if ENABLE_EAGER_EXECUTION: |
| 43 | + for features in dataset: |
| 44 | + yield features |
| 45 | + else: |
| 46 | + iter = dataset.make_one_shot_iterator() |
| 47 | + one_element = iter.get_next() |
| 48 | + with tf.Session() as sess: |
| 49 | + try: |
| 50 | + while True: |
| 51 | + yield sess.run(one_element) |
| 52 | + except tf.errors.OutOfRangeError: |
| 53 | + pass |
| 54 | + |
23 | 55 |
|
24 | 56 | class OneClassSVM(tf.keras.Model): |
25 | 57 | def __init__(self, |
@@ -52,13 +84,15 @@ def __init__(self, |
52 | 84 | def concat_features(self, features): |
53 | 85 | assert isinstance(features, dict) |
54 | 86 | each_feature = [] |
55 | | - for _, v in features.items(): |
56 | | - each_feature.append(v.numpy()) |
| 87 | + for k, v in features.items(): |
| 88 | + if ENABLE_EAGER_EXECUTION: |
| 89 | + v = v.numpy() |
| 90 | + each_feature.append(v) |
57 | 91 | return np.concatenate(each_feature, axis=1) |
58 | 92 |
|
59 | 93 | def sqlflow_train_loop(self, dataset): |
60 | 94 | X = [] |
61 | | - for features in dataset: |
| 95 | + for features in dataset_reader(dataset): |
62 | 96 | X.append(self.concat_features(features)) |
63 | 97 | X = np.concatenate(X) |
64 | 98 |
|
|
0 commit comments