
Commit 91c0846

pkan2 authored and Squadrick committed
Implement Conditional Gradient Optimizer (#469)
1 parent 370cc2a commit 91c0846

File tree

5 files changed (+748, −0 lines)


tensorflow_addons/optimizers/BUILD

Lines changed: 14 additions & 0 deletions
@@ -6,6 +6,7 @@ py_library(
     name = "optimizers",
     srcs = [
         "__init__.py",
+        "conditional_gradient.py",
         "lazy_adam.py",
         "lookahead.py",
         "moving_average.py",
@@ -18,6 +19,19 @@ py_library(
     ],
 )

+py_test(
+    name = "conditional_gradient_test",
+    size = "small",
+    srcs = [
+        "conditional_gradient_test.py",
+    ],
+    main = "conditional_gradient_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":optimizers",
+    ],
+)
+
 py_test(
     name = "lazy_adam_test",
     size = "small",

tensorflow_addons/optimizers/README.md

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 ## Maintainers
 | Submodule | Maintainers | Contact Info |
 |:---------- |:------------- |:--------------|
+| conditional_gradient | Pengyu Kan, Vishnu Lokhande | pkan2@wisc.edu, lokhande@cs.wisc.edu |
 | lazy_adam | Saishruthi Swaminathan | saishruthi.tn@gmail.com |
 | lookahead | Zhao Hanguang | cyberzhg@gmail.com |
 | moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com |
@@ -13,6 +14,7 @@
 ## Components
 | Submodule | Optimizer | Reference |
 |:--------- |:---------- |:---------|
+| conditional_gradient | ConditionalGradient | https://arxiv.org/pdf/1803.06453.pdf |
 | lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
 | lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 |
 | moving_average | MovingAverage | |

tensorflow_addons/optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function

+from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient
 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
 from tensorflow_addons.optimizers.lookahead import Lookahead
 from tensorflow_addons.optimizers.moving_average import MovingAverage
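
With this re-export in place, the optimizer should be importable at the package level alongside the existing ones; a minimal sketch (assuming `tensorflow_addons` is installed with this commit included):

# Package-level import enabled by the __init__.py change above.
from tensorflow_addons.optimizers import ConditionalGradient

opt = ConditionalGradient(learning_rate=0.1, lambda_=0.01)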
tensorflow_addons/optimizers/conditional_gradient.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conditional Gradient method for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class ConditionalGradient(tf.keras.optimizers.Optimizer):
    """Optimizer that implements the Conditional Gradient method.

    This optimizer is aimed at constrained optimization problems: the
    conditional gradient (Frank-Wolfe) update is a convex combination of
    the current variable and a point of the constraint set, so no
    projection step is needed. Currently only a Frobenius-norm
    constraint is supported.
    See https://arxiv.org/pdf/1803.06453.pdf

    The update rule is:

    ```
    variable -= (1 - learning_rate)
        * (variable + lambda_ * gradient / frobenius_norm(gradient))
    ```

    Note that `lambda_` here refers to the constraint "lambda" in the paper.
    """

    def __init__(self,
                 learning_rate,
                 lambda_,
                 use_locking=False,
                 name='ConditionalGradient',
                 **kwargs):
        """Construct a new Conditional Gradient optimizer.

        Args:
            learning_rate: A `Tensor` or a floating point value.
                The learning rate.
            lambda_: A `Tensor` or a floating point value. The constraint.
            use_locking: If `True`, use locks for update operations.
            name: Optional name prefix for the operations created when
                applying gradients. Defaults to 'ConditionalGradient'.
        """
        super(ConditionalGradient, self).__init__(name=name, **kwargs)
        self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
        self._set_hyper('lambda_', lambda_)
        self._set_hyper('use_locking', use_locking)

    def get_config(self):
        config = {
            'learning_rate': self._serialize_hyperparameter('learning_rate'),
            'lambda_': self._serialize_hyperparameter('lambda_'),
            'use_locking': self._serialize_hyperparameter('use_locking')
        }
        base_config = super(ConditionalGradient, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def _create_slots(self, var_list):
        for v in var_list:
            self.add_slot(v, 'conditional_gradient')

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(ConditionalGradient, self)._prepare_local(
            var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)]['learning_rate'] = tf.identity(
            self._get_hyper('learning_rate', var_dtype))
        apply_state[(var_device, var_dtype)]['lambda_'] = tf.identity(
            self._get_hyper('lambda_', var_dtype))

    def _resource_apply_dense(self, grad, var, apply_state=None):
        def frobenius_norm(m):
            return tf.math.reduce_sum(m**2)**0.5

        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))
        norm = tf.convert_to_tensor(
            frobenius_norm(grad), name='norm', dtype=var.dtype.base_dtype)
        lr = coefficients['learning_rate']
        lambda_ = coefficients['lambda_']
        # Conditional gradient update:
        #   var <- lr * var - (1 - lr) * lambda_ * grad / ||grad||_F
        var_update_tensor = (
            tf.math.multiply(var, lr) - (1 - lr) * lambda_ * grad / norm)
        var_update_kwargs = {
            'resource': var.handle,
            'value': var_update_tensor,
        }
        var_update_op = tf.raw_ops.AssignVariableOp(**var_update_kwargs)
        return tf.group(var_update_op)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        def frobenius_norm(m):
            return tf.reduce_sum(m**2)**0.5

        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))
        norm = tf.convert_to_tensor(
            frobenius_norm(grad), name='norm', dtype=var.dtype.base_dtype)
        lr = coefficients['learning_rate']
        lambda_ = coefficients['lambda_']
        # Same update as the dense case, applied only to the rows of `var`
        # selected by `indices`.
        var_slice = tf.gather(var, indices)
        var_update_value = (
            tf.math.multiply(var_slice, lr) - (1 - lr) * lambda_ * grad / norm)
        var_update_kwargs = {
            'resource': var.handle,
            'indices': indices,
            'updates': var_update_value
        }
        var_update_op = tf.raw_ops.ResourceScatterUpdate(**var_update_kwargs)
        return tf.group(var_update_op)
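
As an illustrative sanity check (not part of the commit), the sketch below applies a single optimizer step to one variable and compares the result against the update rule documented in the class docstring. The values and tolerances are arbitrary, and it assumes TensorFlow 2.x eager execution with this module importable.

import numpy as np
import tensorflow as tf

# Direct import of the module added in this commit.
from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient

lr, lam = 0.1, 0.01
w0 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
w = tf.Variable(w0)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(w ** 2)          # d(loss)/dw = 2 * w
grad = tape.gradient(loss, w)

ConditionalGradient(learning_rate=lr, lambda_=lam).apply_gradients([(grad, w)])

# Expected result from the documented rule:
#   w_new = lr * w - (1 - lr) * lambda_ * grad / frobenius_norm(grad)
g = 2.0 * w0
expected = lr * w0 - (1 - lr) * lam * g / np.linalg.norm(g)
np.testing.assert_allclose(w.numpy(), expected, rtol=1e-5, atol=1e-6)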
