Skip to content

Commit bdb00fe

Browse files
committed
Fix initial value of nonlinear optimization in COPT
1 parent 0049661 commit bdb00fe

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

src/pyoptinterface/_src/copt.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ def init_default_env():
8787
VariableAttribute.Value: lambda model, v: model.get_variable_info(v, "Value"),
8888
VariableAttribute.LowerBound: lambda model, v: model.get_variable_info(v, "LB"),
8989
VariableAttribute.UpperBound: lambda model, v: model.get_variable_info(v, "UB"),
90-
VariableAttribute.PrimalStart: lambda model, v: model.mip_start_values.get(v, None),
90+
VariableAttribute.PrimalStart: lambda model, v: model.variable_start_values.get(
91+
v, None
92+
),
9193
VariableAttribute.Domain: lambda model, v: model.get_variable_type(v),
9294
VariableAttribute.Name: lambda model, v: model.get_variable_name(v),
9395
VariableAttribute.IISLowerBound: lambda model, v: model._get_variable_lowerbound_IIS(
@@ -102,8 +104,7 @@ def init_default_env():
102104

103105

104106
def set_variable_start(model, v, val):
105-
model.mip_start_values[v] = val
106-
model.nl_start_values[v] = val
107+
model.variable_start_values[v] = val
107108

108109

109110
variable_attribute_set_func_map = {
@@ -380,7 +381,7 @@ def __init__(self, env=None):
380381

381382
# We must keep a reference to the environment to prevent it from being garbage collected
382383
self._env = env
383-
self.mip_start_values: Dict[VariableIndex, float] = dict()
384+
self.variable_start_values: Dict[VariableIndex, float] = dict()
384385
self.nl_start_values: Dict[VariableIndex, float] = dict()
385386

386387
def add_variables(self, *args, **kwargs):
@@ -449,7 +450,7 @@ def _is_mip(self):
449450
def _has_nl(self):
450451
nlconstrs = self.get_raw_attribute_int("NLConstrs")
451452
hasnlobj = self.get_raw_attribute_int("HasNLObj")
452-
return nlconstrs > 0 and hasnlobj > 0
453+
return nlconstrs > 0 or hasnlobj > 0
453454

454455
def get_model_attribute(self, attribute: ModelAttribute):
455456
def e(attribute):
@@ -537,20 +538,21 @@ def get_raw_attribute(self, param_name: str):
537538
return get_function(param_name)
538539

539540
def optimize(self):
540-
if self._is_mip():
541-
mip_start = self.mip_start_values
542-
if len(mip_start) != 0:
543-
variables = list(mip_start.keys())
544-
values = list(mip_start.values())
545-
self.add_mip_start(variables, values)
546-
mip_start.clear()
547-
if self._has_nl():
548-
nl_start = self.nl_start_values
549-
if len(nl_start) != 0:
550-
variables = list(nl_start.keys())
551-
values = list(nl_start.values())
552-
self.add_nl_start(variables, values)
553-
nl_start.clear()
541+
is_mip = self._is_mip()
542+
is_nl = self._has_nl()
543+
544+
if is_mip or is_nl:
545+
variable_start_values = self.variable_start_values
546+
if len(variable_start_values) != 0:
547+
variables = list(variable_start_values.keys())
548+
values = list(variable_start_values.values())
549+
variable_start_values.clear()
550+
551+
if is_mip:
552+
self.add_mip_start(variables, values)
553+
if is_nl:
554+
self.add_nl_start(variables, values)
555+
554556
super().optimize()
555557

556558
def cb_get_info(self, what):

0 commit comments

Comments
 (0)