
Commit 677dd14

rm _checkparam usage

1 parent 905767b commit 677dd14

4 files changed: +25 -33 lines changed


mindocr/optim/adamw.py

Lines changed: 6 additions & 8 deletions
@@ -4,8 +4,6 @@

 import mindspore as ms
 from mindspore import ops
-from mindspore._checkparam import Rel
-from mindspore._checkparam import Validator as validator
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
@@ -15,12 +13,12 @@

 def _check_param_value(beta1, beta2, eps, prim_name):
     """Check the type of inputs."""
-    validator.check_value_type("beta1", beta1, [float], prim_name)
-    validator.check_value_type("beta2", beta2, [float], prim_name)
-    validator.check_value_type("eps", eps, [float], prim_name)
-    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
-    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
-    validator.check_positive_float(eps, "eps", prim_name)
+    assert isinstance(beta1, float), f"For '{prim_name}', the type of 'beta1' must be 'float', but got type '{type(beta1).__name__}'."
+    assert isinstance(beta2, float), f"For '{prim_name}', the type of 'beta2' must be 'float', but got type '{type(beta2).__name__}'."
+    assert isinstance(eps, float), f"For '{prim_name}', the type of 'eps' must be 'float', but got type '{type(eps).__name__}'."
+    assert 0.0 < beta1 < 1.0, f"For '{prim_name}', the range of 'beta1' must be (0.0, 1.0), but got {beta1}."
+    assert 0.0 < beta2 < 1.0, f"For '{prim_name}', the range of 'beta2' must be (0.0, 1.0), but got {beta2}."
+    assert eps > 0, f"For '{prim_name}', the 'eps' must be positive, but got {eps}."


 _grad_scale = ops.MultitypeFuncGraph("grad_scale")
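
The rewritten checks fail fast with a plain AssertionError instead of the typed exceptions raised by the removed mindspore._checkparam Validator helpers. Below is a minimal sketch of the expected behaviour; the optimizer name "AdamW" and the hyper-parameter values are illustrative only and are not taken from this commit.

from mindocr.optim.adamw import _check_param_value

# Valid hyper-parameters: every assert passes and the call simply returns.
_check_param_value(beta1=0.9, beta2=0.999, eps=1e-8, prim_name="AdamW")

# A beta1 outside (0.0, 1.0) now trips the range assert.
try:
    _check_param_value(beta1=1.5, beta2=0.999, eps=1e-8, prim_name="AdamW")
except AssertionError as err:
    print(err)  # For 'AdamW', the range of 'beta1' must be (0.0, 1.0), but got 1.5

Note that, as with any assert statement, these checks are skipped entirely when Python runs with the -O flag.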

mindocr/optim/adan.py

Lines changed: 9 additions & 11 deletions
@@ -15,8 +15,6 @@
 """adan"""
 from __future__ import absolute_import

-from mindspore._checkparam import Rel
-from mindspore._checkparam import Validator as validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import ms_function
 from mindspore.common.tensor import Tensor
@@ -127,14 +125,15 @@ def _update_run_op(
     return op_cast(next_param, F.dtype(param))


-def _check_param_value(beta1, beta2, eps, prim_name):
+def _check_param_value(beta1, beta2, eps, use_locking, prim_name):
     """Check the type of inputs."""
-    validator.check_value_type("beta1", beta1, [float], prim_name)
-    validator.check_value_type("beta2", beta2, [float], prim_name)
-    validator.check_value_type("eps", eps, [float], prim_name)
-    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
-    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
-    validator.check_positive_float(eps, "eps", prim_name)
+    assert isinstance(beta1, float), f"For '{prim_name}', the type of 'beta1' must be 'float', but got type '{type(beta1).__name__}'."
+    assert isinstance(beta2, float), f"For '{prim_name}', the type of 'beta2' must be 'float', but got type '{type(beta2).__name__}'."
+    assert isinstance(eps, float), f"For '{prim_name}', the type of 'eps' must be 'float', but got type '{type(eps).__name__}'."
+    assert 0.0 < beta1 < 1.0, f"For '{prim_name}', the range of 'beta1' must be (0.0, 1.0), but got {beta1}."
+    assert 0.0 < beta2 < 1.0, f"For '{prim_name}', the range of 'beta2' must be (0.0, 1.0), but got {beta2}."
+    assert eps > 0, f"For '{prim_name}', the 'eps' must be positive, but got {eps}."
+    assert isinstance(use_locking, bool), f"For '{prim_name}', the type of 'use_locking' must be 'bool', but got type '{type(use_locking).__name__}'."


 class Adan(Optimizer):
@@ -161,8 +160,7 @@ def __init__(
             learning_rate, params, weight_decay=weight_decay, loss_scale=loss_scale
         ) # Optimized inherit weight decay is bloaked. weight decay is computed in this py.

-        _check_param_value(beta1, beta2, eps, self.cls_name)
-        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
+        _check_param_value(beta1, beta2, eps, use_locking, self.cls_name)

         self.beta1 = Tensor(beta1, mstype.float32)
         self.beta2 = Tensor(beta2, mstype.float32)
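
In adan.py the use_locking type check, previously a separate validator.check_value_type call in Adan.__init__, is folded into _check_param_value, which now takes use_locking as an extra argument. A rough sketch of the new call is shown below; the hyper-parameter values are illustrative only and it assumes the helper is importable from mindocr.optim.adan.

from mindocr.optim.adan import _check_param_value

# All hyper-parameters valid, including a proper bool for use_locking.
_check_param_value(beta1=0.98, beta2=0.92, eps=1e-8, use_locking=False, prim_name="Adan")

# A non-bool use_locking now fails inside the same helper.
try:
    _check_param_value(beta1=0.98, beta2=0.92, eps=1e-8, use_locking="yes", prim_name="Adan")
except AssertionError as err:
    print(err)  # For 'Adan', the type of 'use_locking' must be 'bool', but got type 'str'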

mindocr/optim/lion.py

Lines changed: 4 additions & 6 deletions
@@ -2,8 +2,6 @@

 import mindspore as ms
 from mindspore import ops
-from mindspore._checkparam import Rel
-from mindspore._checkparam import Validator as validator
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
@@ -13,10 +11,10 @@

 def _check_param_value(beta1, beta2, prim_name):
     """Check the type of inputs."""
-    validator.check_value_type("beta1", beta1, [float], prim_name)
-    validator.check_value_type("beta2", beta2, [float], prim_name)
-    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
-    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
+    assert isinstance(beta1, float), f"For '{prim_name}', the type of 'beta1' must be 'float', but got type '{type(beta1).__name__}'."
+    assert isinstance(beta2, float), f"For '{prim_name}', the type of 'beta2' must be 'float', but got type '{type(beta2).__name__}'."
+    assert 0.0 < beta1 < 1.0, f"For '{prim_name}', the range of 'beta1' must be (0.0, 1.0), but got {beta1}."
+    assert 0.0 < beta2 < 1.0, f"For '{prim_name}', the range of 'beta2' must be (0.0, 1.0), but got {beta2}."


 _grad_scale = ops.MultitypeFuncGraph("grad_scale")

mindocr/optim/nadam.py

Lines changed: 6 additions & 8 deletions
@@ -3,8 +3,6 @@

 import mindspore as ms
 from mindspore import ops
-from mindspore._checkparam import Rel
-from mindspore._checkparam import Validator as validator
 from mindspore.common.api import ms_function
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
@@ -15,12 +13,12 @@

 def _check_param_value(beta1, beta2, eps, prim_name):
     """Check the type of inputs."""
-    validator.check_value_type("beta1", beta1, [float], prim_name)
-    validator.check_value_type("beta2", beta2, [float], prim_name)
-    validator.check_value_type("eps", eps, [float], prim_name)
-    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
-    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
-    validator.check_positive_float(eps, "eps", prim_name)
+    assert isinstance(beta1, float), f"For '{prim_name}', the type of 'beta1' must be 'float', but got type '{type(beta1).__name__}'."
+    assert isinstance(beta2, float), f"For '{prim_name}', the type of 'beta2' must be 'float', but got type '{type(beta2).__name__}'."
+    assert isinstance(eps, float), f"For '{prim_name}', the type of 'eps' must be 'float', but got type '{type(eps).__name__}'."
+    assert 0.0 < beta1 < 1.0, f"For '{prim_name}', the range of 'beta1' must be (0.0, 1.0), but got {beta1}."
+    assert 0.0 < beta2 < 1.0, f"For '{prim_name}', the range of 'beta2' must be (0.0, 1.0), but got {beta2}."
+    assert eps > 0, f"For '{prim_name}', the 'eps' must be positive, but got {eps}."


 _scaler_one = Tensor(1, ms.float32)
