@@ -87,7 +87,9 @@ def init_default_env():
8787 VariableAttribute .Value : lambda model , v : model .get_variable_info (v , "Value" ),
8888 VariableAttribute .LowerBound : lambda model , v : model .get_variable_info (v , "LB" ),
8989 VariableAttribute .UpperBound : lambda model , v : model .get_variable_info (v , "UB" ),
90- VariableAttribute .PrimalStart : lambda model , v : model .mip_start_values .get (v , None ),
90+ VariableAttribute .PrimalStart : lambda model , v : model .variable_start_values .get (
91+ v , None
92+ ),
9193 VariableAttribute .Domain : lambda model , v : model .get_variable_type (v ),
9294 VariableAttribute .Name : lambda model , v : model .get_variable_name (v ),
9395 VariableAttribute .IISLowerBound : lambda model , v : model ._get_variable_lowerbound_IIS (
@@ -102,8 +104,7 @@ def init_default_env():
102104
103105
104106def set_variable_start (model , v , val ):
105- model .mip_start_values [v ] = val
106- model .nl_start_values [v ] = val
107+ model .variable_start_values [v ] = val
107108
108109
109110variable_attribute_set_func_map = {
@@ -380,7 +381,7 @@ def __init__(self, env=None):
380381
381382 # We must keep a reference to the environment to prevent it from being garbage collected
382383 self ._env = env
383- self .mip_start_values : Dict [VariableIndex , float ] = dict ()
384+ self .variable_start_values : Dict [VariableIndex , float ] = dict ()
384385 self .nl_start_values : Dict [VariableIndex , float ] = dict ()
385386
386387 def add_variables (self , * args , ** kwargs ):
@@ -449,7 +450,7 @@ def _is_mip(self):
449450 def _has_nl (self ):
450451 nlconstrs = self .get_raw_attribute_int ("NLConstrs" )
451452 hasnlobj = self .get_raw_attribute_int ("HasNLObj" )
452- return nlconstrs > 0 and hasnlobj > 0
453+ return nlconstrs > 0 or hasnlobj > 0
453454
454455 def get_model_attribute (self , attribute : ModelAttribute ):
455456 def e (attribute ):
@@ -537,20 +538,21 @@ def get_raw_attribute(self, param_name: str):
537538 return get_function (param_name )
538539
539540 def optimize (self ):
540- if self ._is_mip ():
541- mip_start = self .mip_start_values
542- if len (mip_start ) != 0 :
543- variables = list (mip_start .keys ())
544- values = list (mip_start .values ())
545- self .add_mip_start (variables , values )
546- mip_start .clear ()
547- if self ._has_nl ():
548- nl_start = self .nl_start_values
549- if len (nl_start ) != 0 :
550- variables = list (nl_start .keys ())
551- values = list (nl_start .values ())
552- self .add_nl_start (variables , values )
553- nl_start .clear ()
541+ is_mip = self ._is_mip ()
542+ is_nl = self ._has_nl ()
543+
544+ if is_mip or is_nl :
545+ variable_start_values = self .variable_start_values
546+ if len (variable_start_values ) != 0 :
547+ variables = list (variable_start_values .keys ())
548+ values = list (variable_start_values .values ())
549+ variable_start_values .clear ()
550+
551+ if is_mip :
552+ self .add_mip_start (variables , values )
553+ if is_nl :
554+ self .add_nl_start (variables , values )
555+
554556 super ().optimize ()
555557
556558 def cb_get_info (self , what ):
0 commit comments