Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions pathways/typing/mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def create_cart_diagram(root: Node) -> str:
links = []

for node in root.preorder():
probabilities = getattr(node, "class_probabilities", None)
probabilities = node.class_probabilities

if node.is_leaf and probabilities:
prob_shapes, prob_links = create_segment_probability_stack(
Expand Down Expand Up @@ -174,9 +174,10 @@ def create_segment_probability_stack(
return shapes, links


def create_form_diagram(root: Node, *, skip_notes: bool = False) -> str:
def create_form_diagram(root: Node, *, skip_notes: bool = False, threshold: float = 0.0) -> str:
"""Create mermaid diagram for typing form."""
header = "flowchart TD"
threshold = threshold / 100.0
shapes = {
"segment": "stadium",
"select_one": "rectangle",
Expand All @@ -195,14 +196,20 @@ def create_form_diagram(root: Node, *, skip_notes: bool = False) -> str:
continue

is_segment_leaf = node.name == "segment"
probabilities = getattr(node, "class_probabilities", None)
probabilities = node.class_probabilities

if is_segment_leaf and probabilities:
prob_shapes, prob_links = create_segment_probability_stack(
node, probabilities, "circle"
)
shapes_lst.extend(prob_shapes)
links.extend(prob_links)
max_prob = max(probabilities.values())
if max_prob < threshold:
prob_shapes, prob_links = create_segment_probability_stack(
node, probabilities, "circle"
)
shapes_lst.extend(prob_shapes)
links.extend(prob_links)
else:
shape_label = get_form_shape_label(node)
shape = draw_shape(node.uid, shape_label, "circle")
shapes_lst.append(shape)
else:
shape_type = "circle" if is_segment_leaf else shapes[node.question.type]
shape_label = get_form_shape_label(node)
Expand Down
40 changes: 34 additions & 6 deletions pathways/typing/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,48 @@ def add_segment_note(


def add_segment_notes(
root: Node, settings_config: dict, segments_config: dict | None = None
root: Node,
settings_config: dict,
segments_config: dict | None = None,
low_confidence_threshold: float = 0.0,
) -> Node:
"""Add notes once segments are assigned."""
"""Add notes once segments are assigned.

If confidence_threshold is provided (percentage), calculate max probability. If max_probability < threshold,
segment + dead-end note will be applied. Otherwise, only segment note is applied.
"""
low_confidence_threshold = low_confidence_threshold / 100
new_root = copy.deepcopy(root)
note_label = {
key.replace("segment_note", "label"): value
key.replace("segment_note", "label"): value.replace("\\n", "\n") if isinstance(value, str) else value
for key, value in settings_config.items()
if key.startswith("segment_note")
}
low_conf_label = {
key.replace("deadend_note", "label"): value.replace("\\n", "\n") if isinstance(value, str) else value
for key, value in settings_config.items()
if key.startswith("deadend_note")
}

for node in new_root.preorder():
if node.is_leaf and node.name == "segment":
add_segment_note(node, note_label, segments_config)
return new_root
use_low_conf = False
if low_confidence_threshold > 0 and node.class_probabilities:
max_prob = max(node.class_probabilities.values())
use_low_conf = max_prob < low_confidence_threshold
final_label = note_label.copy()
if use_low_conf:
for key, seg_note in final_label.items():
low_conf_note = low_conf_label.get(
key,
"\n[Low segment assignment confidence]\n"
"We recommend stopping this survey and starting with a new respondent."
)
final_label[key] = seg_note + low_conf_note

add_segment_note(node, final_label, segments_config)

return new_root

def enforce_relevance(root: Node) -> Node:
"""Enforce relevance rules for the node.
Expand Down Expand Up @@ -280,7 +308,7 @@ def exit_deadends(

# create note for dead-end
deadend_label = {
key.replace("deadend_note", "label"): value
key.replace("deadend_note", "label"): value.replace("\\n", "\n") if isinstance(value, str) else value
for key, value in settings_config.items()
if key.startswith("deadend_note")
}
Expand Down