Skip to content

Commit da96e15

Browse files
committed
handling latex and resolving bugs
1 parent 1014f8f commit da96e15

File tree

4 files changed

+98
-21
lines changed

4 files changed

+98
-21
lines changed

sphinx_proof/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,21 @@ def setup(app: Sphinx) -> Dict[str, Any]:
7676
proof_node,
7777
singlehtml=(visit_proof_node, depart_proof_node),
7878
html=(visit_proof_node, depart_proof_node),
79+
latex=(visit_proof_node, depart_proof_node)
7980
)
8081
app.add_node(
8182
unenumerable_node,
8283
singlehtml=(visit_unenumerable_node, depart_unenumerable_node),
8384
html=(visit_unenumerable_node, depart_unenumerable_node),
85+
latex=(visit_unenumerable_node, depart_unenumerable_node)
8486
)
8587
app.add_enumerable_node(
8688
enumerable_node,
8789
"proof",
8890
None,
8991
singlehtml=(visit_enumerable_node, depart_enumerable_node),
9092
html=(visit_enumerable_node, depart_enumerable_node),
93+
latex=(visit_enumerable_node, depart_enumerable_node)
9194
)
9295

9396
return {

sphinx_proof/directive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def run(self) -> List[Node]:
9797
"prio": 0,
9898
"nonumber": True if "nonumber" in self.options else False,
9999
}
100-
100+
101101
return [node]
102102

103103

sphinx_proof/domain.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from docutils import nodes
2222
from .directive import ProofDirective
2323
from .proof_type import PROOF_TYPES
24+
from copy import copy
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -37,7 +38,10 @@ def generate(self, docnames=None) -> Tuple[Dict[str, Any], bool]:
3738
if not hasattr(self.domain.env, "proof_list"):
3839
return content, True
3940

41+
# import pdb;
42+
# pdb.set_trace()
4043
proofs = self.domain.env.proof_list
44+
# {'theorem-0': {'docname': 'start/overview', 'type': 'theorem', 'ids': ['theorem-0'], 'label': 'theorem-0', 'prio': 0, 'nonumber': False}}
4145

4246
# name, subtype, docname, typ, anchor, extra, qualifier, description
4347
for anchor, values in proofs.items():
@@ -71,11 +75,21 @@ class ProofDomain(Domain):
7175
name = "prf"
7276
label = "Proof Domain"
7377

74-
roles = {"ref": ProofXRefRole()}
78+
roles = {"ref": ProofXRefRole()} # role name -> role callable
7579

76-
indices = {ProofIndex}
80+
indices = {ProofIndex} # a list of index subclasses
7781

78-
directives = {**{"proof": ProofDirective}, **PROOF_TYPES}
82+
directives = {**{"proof": ProofDirective}, **PROOF_TYPES} # list of directives
83+
84+
enumerable_nodes = {} # type: Dict[Type[Node], Tuple[str, Callable]]
85+
86+
def __init__(self, env: "BuildEnvironment") -> None:
87+
super().__init__(env)
88+
89+
# set up enumerable nodes
90+
self.enumerable_nodes = copy(self.enumerable_nodes) # create a copy for this instance
91+
for node, settings in env.app.registry.enumerable_nodes.items():
92+
self.enumerable_nodes[node] = settings
7993

8094
def resolve_xref(
8195
self,
@@ -87,7 +101,13 @@ def resolve_xref(
87101
node: pending_xref,
88102
contnode: Element,
89103
) -> Element:
90-
104+
"""
105+
Resolve the pending_xref node with the given typ and target. This method should return a new node,
106+
to replace the xref node, containing the contnode which is the markup content of the cross-reference.
107+
If no resolution can be found, None can be returned; the xref node will then given to the missing-reference event,
108+
and if that yields no resolution, replaced by contnode.The method can also raise sphinx.environment.NoUri
109+
to suppress the missing-reference event being emitted.
110+
"""
91111
try:
92112
match = env.proof_list[target]
93113
except Exception:
@@ -98,7 +118,6 @@ def resolve_xref(
98118

99119
todocname = match["docname"]
100120
title = contnode[0]
101-
102121
if target in contnode[0]:
103122
number = ""
104123
if not env.proof_list[target]["nonumber"]:

sphinx_proof/nodes.py

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from sphinx.builders.html import HTMLTranslator
1111
from docutils import nodes
1212
from docutils.nodes import Node
13+
from sphinx.writers.latex import LaTeXTranslator
1314

15+
CR = '\n'
1416

1517
class proof_node(nodes.Admonition, nodes.Element):
1618
pass
@@ -25,35 +27,55 @@ class unenumerable_node(nodes.Admonition, nodes.Element):
2527

2628

2729
def visit_enumerable_node(self, node: Node) -> None:
28-
self.body.append(self.starttag(node, "div", CLASS="admonition"))
30+
if isinstance(self, LaTeXTranslator):
31+
self.body.append(CR + '\\begin{sphinxadmonition}{note}')
32+
else:
33+
self.body.append(self.starttag(node, "div", CLASS="admonition"))
2934

3035

3136
def depart_enumerable_node(self, node: Node) -> None:
3237
typ = node.attributes.get("type", "")
33-
34-
# Find index in list of 'Proof #'
35-
number = get_node_number(self, node)
36-
idx = self.body.index(f"Proof {number} ")
37-
self.body[idx] = f"{typ.title()} {number} "
38-
self.body.append("</div>")
38+
if isinstance(self, LaTeXTranslator):
39+
number = get_node_number_latex(self, node)
40+
text = node.children[0].astext()
41+
idx = list_rindex(self.body,CR + '\\begin{sphinxadmonition}{note}') + 2
42+
self.body.insert(idx, f"{typ.title()} {number}")
43+
self.body.append('\\end{sphinxadmonition}' + CR)
44+
else:
45+
# Find index in list of 'Proof #'
46+
number = get_node_number(self, node)
47+
idx = self.body.index(f"Proof {number} ")
48+
self.body[idx] = f"{typ.title()} {number} "
49+
self.body.append("</div>")
3950

4051

4152
def visit_unenumerable_node(self, node: Node) -> None:
42-
self.body.append(self.starttag(node, "div", CLASS="admonition"))
53+
if isinstance(self, LaTeXTranslator):
54+
self.body.append(CR + '\\begin{sphinxadmonition}{note}')
55+
else:
56+
self.body.append(self.starttag(node, "div", CLASS="admonition"))
4357

4458

4559
def depart_unenumerable_node(self, node: Node) -> None:
4660
typ = node.attributes.get("type", "")
4761
title = node.attributes.get("title", "")
48-
49-
if title == "":
50-
idx = len(self.body) - self.body[-1::-1].index('<p class="admonition-title">')
62+
if isinstance(self, LaTeXTranslator):
63+
if title == "":
64+
idx = (len(self.body) - self.body[-1::-1].index(CR + '\\begin{sphinxadmonition}{note}')) + 1
65+
else:
66+
text = node.children[0].astext()
67+
idx = list_rindex(self.body,CR + '\\begin{sphinxadmonition}{note}') + 2
68+
self.body.insert(idx, f"{typ.title()}")
69+
self.body.append('\\end{sphinxadmonition}' + CR)
5170
else:
52-
idx = self.body.index(title)
71+
if title == "":
72+
idx = len(self.body) - self.body[-1::-1].index('<p class="admonition-title">')
73+
else:
74+
idx = list_rindex(self.body,title)
5375

54-
element = f"<span>{typ.title()} </span>"
55-
self.body.insert(idx, element)
56-
self.body.append("</div>")
76+
element = f"<span>{typ.title()} </span>"
77+
self.body.insert(idx, element)
78+
self.body.append("</div>")
5779

5880

5981
def visit_proof_node(self, node: Node) -> None:
@@ -69,3 +91,36 @@ def get_node_number(self: HTMLTranslator, node: Node) -> str:
6991
ids = node.attributes.get("ids", [])[0]
7092
number = self.builder.fignumbers.get(key, {}).get(ids, ())
7193
return ".".join(map(str, number))
94+
95+
def find_parent(env, node , parent_tag):
96+
"""Find the parent node."""
97+
while True:
98+
node = node.parent
99+
if node is None:
100+
return None
101+
# parent should be a document in toc
102+
if (
103+
"docname" in node.attributes
104+
and env.titles[node.attributes["docname"]].astext().lower()
105+
in node.attributes["names"]
106+
):
107+
return node.attributes['docname']
108+
109+
if node.tagname == parent_tag:
110+
return node.attributes['docname']
111+
112+
return None
113+
114+
def get_node_number_latex(self, node: Node) -> str:
115+
key = "proof"
116+
docname = find_parent(self.builder.env, node, "section")
117+
ids = node.attributes.get("ids", [])[0]
118+
fignumbers = self.builder.env.toc_fignumbers.get(docname, {})
119+
number = fignumbers.get(key, {}).get(ids, ())
120+
return ".".join(map(str, number))
121+
122+
def list_rindex(li, x):
123+
for i in reversed(range(len(li))):
124+
if li[i] == x:
125+
return i
126+
raise ValueError("{} is not in list".format(x))

0 commit comments

Comments
 (0)