1515"""adan"""
1616from __future__ import absolute_import
1717
18- from mindspore ._checkparam import Rel
19- from mindspore ._checkparam import Validator as validator
2018from mindspore .common import dtype as mstype
2119from mindspore .common .api import ms_function
2220from mindspore .common .tensor import Tensor
@@ -127,14 +125,15 @@ def _update_run_op(
127125 return op_cast (next_param , F .dtype (param ))
128126
129127
130- def _check_param_value (beta1 , beta2 , eps , prim_name ):
128+ def _check_param_value (beta1 , beta2 , eps , use_locking , prim_name ):
131129 """Check the type of inputs."""
132- validator .check_value_type ("beta1" , beta1 , [float ], prim_name )
133- validator .check_value_type ("beta2" , beta2 , [float ], prim_name )
134- validator .check_value_type ("eps" , eps , [float ], prim_name )
135- validator .check_float_range (beta1 , 0.0 , 1.0 , Rel .INC_NEITHER , "beta1" , prim_name )
136- validator .check_float_range (beta2 , 0.0 , 1.0 , Rel .INC_NEITHER , "beta2" , prim_name )
137- validator .check_positive_float (eps , "eps" , prim_name )
130+ assert isinstance (beta1 , float ), f"For '{ prim_name } ', the type of 'beta1' must be 'float', but got type '{ type (beta1 ).__name__ } '."
131+ assert isinstance (beta2 , float ), f"For '{ prim_name } ', the type of 'beta2' must be 'float', but got type '{ type (beta2 ).__name__ } '."
132+ assert isinstance (eps , float ), f"For '{ prim_name } ', the type of 'eps' must be 'float', but got type '{ type (eps ).__name__ } '."
133+ assert 0.0 < beta1 < 1.0 , f"For '{ prim_name } ', the range of 'beta1' must be (0.0, 1.0), but got { beta1 } ."
134+ assert 0.0 < beta2 < 1.0 , f"For '{ prim_name } ', the range of 'beta2' must be (0.0, 1.0), but got { beta2 } ."
135+ assert eps > 0 , f"For '{ prim_name } ', the 'eps' must be positive, but got { eps } ."
136+ assert isinstance (use_locking , bool ), f"For '{ prim_name } ', the type of 'use_locking' must be 'bool', but got type '{ type (use_locking ).__name__ } '."
138137
139138
140139class Adan (Optimizer ):
@@ -161,8 +160,7 @@ def __init__(
161160 learning_rate , params , weight_decay = weight_decay , loss_scale = loss_scale
162161 ) # Optimized inherit weight decay is bloaked. weight decay is computed in this py.
163162
164- _check_param_value (beta1 , beta2 , eps , self .cls_name )
165- validator .check_value_type ("use_locking" , use_locking , [bool ], self .cls_name )
163+ _check_param_value (beta1 , beta2 , eps , use_locking , self .cls_name )
166164
167165 self .beta1 = Tensor (beta1 , mstype .float32 )
168166 self .beta2 = Tensor (beta2 , mstype .float32 )
0 commit comments