Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 41 additions & 66 deletions impedance/models/circuits/circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,27 +96,24 @@ def fit(self, frequencies, impedance, bounds=None,
if len(frequencies) != len(impedance):
raise TypeError('length of frequencies and impedance do not match')

if self.initial_guess != []:
parameters, conf = circuit_fit(frequencies, impedance,
self.circuit, self.initial_guess,
constants=self.constants,
bounds=bounds,
weight_by_modulus=weight_by_modulus,
**kwargs)
self.parameters_ = parameters
if conf is not None:
self.conf_ = conf
else:
if self.initial_guess == []:
raise ValueError('No initial guess supplied')

parameters, conf = circuit_fit(frequencies, impedance,
self.circuit, self.initial_guess,
constants=self.constants,
bounds=bounds,
weight_by_modulus=weight_by_modulus,
**kwargs)
self.parameters_ = parameters
if conf is not None:
self.conf_ = conf

return self

def _is_fit(self):
""" check if model has been fit (parameters_ is not None) """
if self.parameters_ is not None:
return True
else:
return False
return (self.parameters_ is not None)

def predict(self, frequencies, use_initial=False):
""" Predict impedance using an equivalent circuit model
Expand All @@ -135,19 +132,16 @@ def predict(self, frequencies, use_initial=False):
"""
frequencies = np.array(frequencies, dtype=float)

if self._is_fit() and not use_initial:
return eval(buildCircuit(self.circuit, frequencies,
*self.parameters_,
constants=self.constants, eval_string='',
index=0)[0],
circuit_elements)
else:
parameters_for_fit = self.parameters_
if not self._is_fit() or use_initial:
warnings.warn("Simulating circuit based on initial parameters")
return eval(buildCircuit(self.circuit, frequencies,
*self.initial_guess,
constants=self.constants, eval_string='',
index=0)[0],
circuit_elements)
parameters_for_fit = self.initial_guess

return eval(buildCircuit(self.circuit, frequencies,
*parameters_for_fit,
constants=self.constants, eval_string='',
index=0)[0],
circuit_elements)

def get_param_names(self):
""" Converts circuit string to names and units """
Expand Down Expand Up @@ -236,6 +230,8 @@ def plot(self, ax=None, f_data=None, Z_data=None, kind='altair', **kwargs):
axes of the created nyquist plot
"""

f_pred = f_data if f_data is not None else np.logspace(5, -3)
Z_fit = self.predict(f_pred) if self._is_fit() else None
if kind == 'nyquist':
if ax is None:
_, ax = plt.subplots(figsize=(5, 5))
Expand All @@ -244,23 +240,12 @@ def plot(self, ax=None, f_data=None, Z_data=None, kind='altair', **kwargs):
ax = plot_nyquist(Z_data, ls='', marker='s', ax=ax, **kwargs)

if self._is_fit():
if f_data is not None:
f_pred = f_data
else:
f_pred = np.logspace(5, -3)

Z_fit = self.predict(f_pred)
ax = plot_nyquist(Z_fit, ls='-', marker='', ax=ax, **kwargs)
return ax
elif kind == 'bode':
if ax is None:
_, ax = plt.subplots(nrows=2, figsize=(5, 5))

if f_data is not None:
f_pred = f_data
else:
f_pred = np.logspace(5, -3)

if Z_data is not None:
if f_data is None:
raise ValueError('f_data must be specified if' +
Expand All @@ -269,7 +254,6 @@ def plot(self, ax=None, f_data=None, Z_data=None, kind='altair', **kwargs):
axes=ax, **kwargs)

if self._is_fit():
Z_fit = self.predict(f_pred)
ax = plot_bode(f_pred, Z_fit, ls='-', marker='',
axes=ax, **kwargs)
return ax
Expand All @@ -280,12 +264,6 @@ def plot(self, ax=None, f_data=None, Z_data=None, kind='altair', **kwargs):
plot_dict['data'] = {'f': f_data, 'Z': Z_data}

if self._is_fit():
if f_data is not None:
f_pred = f_data
else:
f_pred = np.logspace(5, -3)

Z_fit = self.predict(f_pred)
if self.name is not None:
name = self.name
else:
Expand All @@ -312,24 +290,20 @@ def save(self, filepath):

initial_guess = self.initial_guess

data_dict = {"Name": model_name,
"Circuit String": model_string,
"Initial Guess": initial_guess,
"Constants": self.constants,
"Fit": False}

if self._is_fit():
data_dict["Fit"] = True
parameters_ = list(self.parameters_)
model_conf_ = list(self.conf_)

data_dict = {"Name": model_name,
"Circuit String": model_string,
"Initial Guess": initial_guess,
"Constants": self.constants,
"Fit": True,
"Parameters": parameters_,
"Confidence": model_conf_,
}
else:
data_dict = {"Name": model_name,
"Circuit String": model_string,
"Initial Guess": initial_guess,
"Constants": self.constants,
"Fit": False}
data_dict.update({
"Parameters": parameters_,
"Confidence": model_conf_
})

with open(filepath, 'w') as f:
json.dump(data_dict, f)
Expand Down Expand Up @@ -364,12 +338,13 @@ def load(self, filepath, fitted_as_initial=False):
self.constants = model_constants
self.name = model_name

if json_data["Fit"]:
if fitted_as_initial:
self.initial_guess = np.array(json_data['Parameters'])
else:
self.parameters_ = np.array(json_data["Parameters"])
self.conf_ = np.array(json_data["Confidence"])
if not json_data["Fit"]:
return
if fitted_as_initial:
self.initial_guess = np.array(json_data['Parameters'])
else:
self.parameters_ = np.array(json_data["Parameters"])
self.conf_ = np.array(json_data["Confidence"])


class Randles(BaseCircuit):
Expand Down
2 changes: 1 addition & 1 deletion impedance/models/circuits/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def wrapper(p, f):
wrapper.__name__ = func.__name__
wrapper.__doc__ = func.__doc__

global circuit_elements
# global circuit_elements
if func.__name__ in ["s", "p"]:
raise ElementError("cannot redefine elements 's' (series)" +
"or 'p' (parallel)")
Expand Down
83 changes: 42 additions & 41 deletions impedance/models/circuits/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def opt_function(x):
np.hstack([Z.real, Z.imag]))

class BasinhoppingBounds(object):
""" Adapted from the basinhopping documetation
""" Adapted from the basinhopping documentation
https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.basinhopping.html
"""

Expand Down Expand Up @@ -293,24 +293,28 @@ def count_parens(string):
result = []
skipped = []
for i, sub_str in enumerate(split):
if i not in skipped:
if '(' not in sub_str and ')' not in sub_str:
result.append(sub_str)
else:
open_parens, closed_parens = count_parens(sub_str)
if open_parens == closed_parens:
result.append(sub_str)
else:
uneven = True
while i < len(split) - 1 and uneven:
sub_str += special + split[i+1]

open_parens, closed_parens = count_parens(sub_str)
uneven = open_parens != closed_parens

i += 1
skipped.append(i)
result.append(sub_str)
if i in skipped:
continue

if '(' not in sub_str and ')' not in sub_str:
result.append(sub_str)
continue

open_parens, closed_parens = count_parens(sub_str)
if open_parens == closed_parens:
result.append(sub_str)
continue

uneven = True
while i < len(split) - 1 and uneven:
sub_str += special + split[i+1]

open_parens, closed_parens = count_parens(sub_str)
uneven = open_parens != closed_parens

i += 1
skipped.append(i)
result.append(sub_str)
return result

parallel = parse_circuit(circuit, parallel=True)
Expand Down Expand Up @@ -338,10 +342,9 @@ def count_parens(string):
elem_number = check_and_eval(raw_elem).num_params
param_list = []
for j in range(elem_number):
current_elem = elem
if elem_number > 1:
current_elem = elem + '_{}'.format(j)
else:
current_elem = elem
current_elem += '_{}'.format(j)

if current_elem in constants.keys():
param_list.append(constants[current_elem])
Expand All @@ -352,13 +355,12 @@ def count_parens(string):
param_string += str(param_list)
new = raw_elem + '(' + param_string + ',' + str(frequencies) + ')'
eval_string += new

if i == len(split) - 1:
if len(split) > 1: # do not add closing brackets if single element
eval_string += '])'
else:
if i < len(split) - 1:
eval_string += ','

if len(split) > 1: # do not add closing brackets if single element
eval_string += '])'

return eval_string, index


Expand All @@ -383,14 +385,12 @@ def extract_circuit_elements(circuit):
for i, char in enumerate(p_string):
if char not in ints:
current_element.append(char)
else:
# min to prevent looking ahead past end of list
if p_string[min(i+1, length-1)] not in ints:
current_element.append(char)
extracted_elements.append(''.join(current_element))
current_element = []
else:
current_element.append(char)
continue
current_element.append(char)
# min to prevent looking ahead past end of list
if p_string[min(i+1, length-1)] not in ints:
extracted_elements.append(''.join(current_element))
current_element = []
extracted_elements.append(''.join(current_element))
return extracted_elements

Expand All @@ -410,12 +410,13 @@ def calculateCircuitLength(circuit):

"""
length = 0
if circuit:
extracted_elements = extract_circuit_elements(circuit)
for elem in extracted_elements:
raw_element = get_element_from_name(elem)
num_params = check_and_eval(raw_element).num_params
length += num_params
if not circuit:
return length
extracted_elements = extract_circuit_elements(circuit)
for elem in extracted_elements:
raw_element = get_element_from_name(elem)
num_params = check_and_eval(raw_element).num_params
length += num_params
return length


Expand Down
Loading