diff --git a/impedance/models/circuits/circuits.py b/impedance/models/circuits/circuits.py index b670d198..8300ea1c 100644 --- a/impedance/models/circuits/circuits.py +++ b/impedance/models/circuits/circuits.py @@ -96,27 +96,24 @@ def fit(self, frequencies, impedance, bounds=None, if len(frequencies) != len(impedance): raise TypeError('length of frequencies and impedance do not match') - if self.initial_guess != []: - parameters, conf = circuit_fit(frequencies, impedance, - self.circuit, self.initial_guess, - constants=self.constants, - bounds=bounds, - weight_by_modulus=weight_by_modulus, - **kwargs) - self.parameters_ = parameters - if conf is not None: - self.conf_ = conf - else: + if self.initial_guess == []: raise ValueError('No initial guess supplied') + parameters, conf = circuit_fit(frequencies, impedance, + self.circuit, self.initial_guess, + constants=self.constants, + bounds=bounds, + weight_by_modulus=weight_by_modulus, + **kwargs) + self.parameters_ = parameters + if conf is not None: + self.conf_ = conf + return self def _is_fit(self): """ check if model has been fit (parameters_ is not None) """ - if self.parameters_ is not None: - return True - else: - return False + return (self.parameters_ is not None) def predict(self, frequencies, use_initial=False): """ Predict impedance using an equivalent circuit model @@ -135,19 +132,16 @@ def predict(self, frequencies, use_initial=False): """ frequencies = np.array(frequencies, dtype=float) - if self._is_fit() and not use_initial: - return eval(buildCircuit(self.circuit, frequencies, - *self.parameters_, - constants=self.constants, eval_string='', - index=0)[0], - circuit_elements) - else: + parameters_for_fit = self.parameters_ + if not self._is_fit() or use_initial: warnings.warn("Simulating circuit based on initial parameters") - return eval(buildCircuit(self.circuit, frequencies, - *self.initial_guess, - constants=self.constants, eval_string='', - index=0)[0], - circuit_elements) + parameters_for_fit = self.initial_guess + + return eval(buildCircuit(self.circuit, frequencies, + *parameters_for_fit, + constants=self.constants, eval_string='', + index=0)[0], + circuit_elements) def get_param_names(self): """ Converts circuit string to names and units """ @@ -236,6 +230,8 @@ def plot(self, ax=None, f_data=None, Z_data=None, kind='altair', **kwargs): axes of the created nyquist plot """ + f_pred = f_data if f_data is not None else np.logspace(5, -3) + Z_fit = self.predict(f_pred) if self._is_fit() else None if kind == 'nyquist': if ax is None: _, ax = plt.subplots(figsize=(5, 5)) @@ -244,23 +240,12 @@ def plot(self, ax=None, f_data=None, Z_data=None, kind='altair', **kwargs): ax = plot_nyquist(Z_data, ls='', marker='s', ax=ax, **kwargs) if self._is_fit(): - if f_data is not None: - f_pred = f_data - else: - f_pred = np.logspace(5, -3) - - Z_fit = self.predict(f_pred) ax = plot_nyquist(Z_fit, ls='-', marker='', ax=ax, **kwargs) return ax elif kind == 'bode': if ax is None: _, ax = plt.subplots(nrows=2, figsize=(5, 5)) - if f_data is not None: - f_pred = f_data - else: - f_pred = np.logspace(5, -3) - if Z_data is not None: if f_data is None: raise ValueError('f_data must be specified if' + @@ -269,7 +254,6 @@ def plot(self, ax=None, f_data=None, Z_data=None, kind='altair', **kwargs): axes=ax, **kwargs) if self._is_fit(): - Z_fit = self.predict(f_pred) ax = plot_bode(f_pred, Z_fit, ls='-', marker='', axes=ax, **kwargs) return ax @@ -280,12 +264,6 @@ def plot(self, ax=None, f_data=None, Z_data=None, kind='altair', **kwargs): plot_dict['data'] = {'f': f_data, 'Z': Z_data} if self._is_fit(): - if f_data is not None: - f_pred = f_data - else: - f_pred = np.logspace(5, -3) - - Z_fit = self.predict(f_pred) if self.name is not None: name = self.name else: @@ -312,24 +290,20 @@ def save(self, filepath): initial_guess = self.initial_guess + data_dict = {"Name": model_name, + "Circuit String": model_string, + "Initial Guess": initial_guess, + "Constants": self.constants, + "Fit": False} + if self._is_fit(): + data_dict["Fit"] = True parameters_ = list(self.parameters_) model_conf_ = list(self.conf_) - - data_dict = {"Name": model_name, - "Circuit String": model_string, - "Initial Guess": initial_guess, - "Constants": self.constants, - "Fit": True, - "Parameters": parameters_, - "Confidence": model_conf_, - } - else: - data_dict = {"Name": model_name, - "Circuit String": model_string, - "Initial Guess": initial_guess, - "Constants": self.constants, - "Fit": False} + data_dict.update({ + "Parameters": parameters_, + "Confidence": model_conf_ + }) with open(filepath, 'w') as f: json.dump(data_dict, f) @@ -364,12 +338,13 @@ def load(self, filepath, fitted_as_initial=False): self.constants = model_constants self.name = model_name - if json_data["Fit"]: - if fitted_as_initial: - self.initial_guess = np.array(json_data['Parameters']) - else: - self.parameters_ = np.array(json_data["Parameters"]) - self.conf_ = np.array(json_data["Confidence"]) + if not json_data["Fit"]: + return + if fitted_as_initial: + self.initial_guess = np.array(json_data['Parameters']) + else: + self.parameters_ = np.array(json_data["Parameters"]) + self.conf_ = np.array(json_data["Confidence"]) class Randles(BaseCircuit): diff --git a/impedance/models/circuits/elements.py b/impedance/models/circuits/elements.py index 643ff696..1977c137 100644 --- a/impedance/models/circuits/elements.py +++ b/impedance/models/circuits/elements.py @@ -33,7 +33,7 @@ def wrapper(p, f): wrapper.__name__ = func.__name__ wrapper.__doc__ = func.__doc__ - global circuit_elements + # global circuit_elements if func.__name__ in ["s", "p"]: raise ElementError("cannot redefine elements 's' (series)" + "or 'p' (parallel)") diff --git a/impedance/models/circuits/fitting.py b/impedance/models/circuits/fitting.py index 5fe1c300..72fb3702 100644 --- a/impedance/models/circuits/fitting.py +++ b/impedance/models/circuits/fitting.py @@ -180,7 +180,7 @@ def opt_function(x): np.hstack([Z.real, Z.imag])) class BasinhoppingBounds(object): - """ Adapted from the basinhopping documetation + """ Adapted from the basinhopping documentation https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.basinhopping.html """ @@ -293,24 +293,28 @@ def count_parens(string): result = [] skipped = [] for i, sub_str in enumerate(split): - if i not in skipped: - if '(' not in sub_str and ')' not in sub_str: - result.append(sub_str) - else: - open_parens, closed_parens = count_parens(sub_str) - if open_parens == closed_parens: - result.append(sub_str) - else: - uneven = True - while i < len(split) - 1 and uneven: - sub_str += special + split[i+1] - - open_parens, closed_parens = count_parens(sub_str) - uneven = open_parens != closed_parens - - i += 1 - skipped.append(i) - result.append(sub_str) + if i in skipped: + continue + + if '(' not in sub_str and ')' not in sub_str: + result.append(sub_str) + continue + + open_parens, closed_parens = count_parens(sub_str) + if open_parens == closed_parens: + result.append(sub_str) + continue + + uneven = True + while i < len(split) - 1 and uneven: + sub_str += special + split[i+1] + + open_parens, closed_parens = count_parens(sub_str) + uneven = open_parens != closed_parens + + i += 1 + skipped.append(i) + result.append(sub_str) return result parallel = parse_circuit(circuit, parallel=True) @@ -338,10 +342,9 @@ def count_parens(string): elem_number = check_and_eval(raw_elem).num_params param_list = [] for j in range(elem_number): + current_elem = elem if elem_number > 1: - current_elem = elem + '_{}'.format(j) - else: - current_elem = elem + current_elem += '_{}'.format(j) if current_elem in constants.keys(): param_list.append(constants[current_elem]) @@ -352,13 +355,12 @@ def count_parens(string): param_string += str(param_list) new = raw_elem + '(' + param_string + ',' + str(frequencies) + ')' eval_string += new - - if i == len(split) - 1: - if len(split) > 1: # do not add closing brackets if single element - eval_string += '])' - else: + if i < len(split) - 1: eval_string += ',' + if len(split) > 1: # do not add closing brackets if single element + eval_string += '])' + return eval_string, index @@ -383,14 +385,12 @@ def extract_circuit_elements(circuit): for i, char in enumerate(p_string): if char not in ints: current_element.append(char) - else: - # min to prevent looking ahead past end of list - if p_string[min(i+1, length-1)] not in ints: - current_element.append(char) - extracted_elements.append(''.join(current_element)) - current_element = [] - else: - current_element.append(char) + continue + current_element.append(char) + # min to prevent looking ahead past end of list + if p_string[min(i+1, length-1)] not in ints: + extracted_elements.append(''.join(current_element)) + current_element = [] extracted_elements.append(''.join(current_element)) return extracted_elements @@ -410,12 +410,13 @@ def calculateCircuitLength(circuit): """ length = 0 - if circuit: - extracted_elements = extract_circuit_elements(circuit) - for elem in extracted_elements: - raw_element = get_element_from_name(elem) - num_params = check_and_eval(raw_element).num_params - length += num_params + if not circuit: + return length + extracted_elements = extract_circuit_elements(circuit) + for elem in extracted_elements: + raw_element = get_element_from_name(elem) + num_params = check_and_eval(raw_element).num_params + length += num_params return length