Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pynini
from pynini.lib import pynutil

from nemo_text_processing.inverse_text_normalization.zh.graph_utils import NEMO_DIGIT, NEMO_SIGMA, GraphFst
from nemo_text_processing.inverse_text_normalization.zh.graph_utils import NEMO_CHAR, NEMO_DIGIT, NEMO_SIGMA, GraphFst
from nemo_text_processing.inverse_text_normalization.zh.utils import get_abs_path


Expand Down Expand Up @@ -53,6 +53,7 @@ def __init__(self):
# grammar for tens, not the output for Cardinal grammar but for pure Arabic digits (used in other grammars)
graph_tens = (ties + graph_digits) | (pynini.cross(pynini.accep("零"), "0") + graph_digits)
graph_all = graph_tens | graph_teens | pynutil.insert("00")
graph_all = graph_all.optimize()

# grammar for hundreds 百
graph_hundreds_complex = (
Expand Down Expand Up @@ -336,7 +337,7 @@ def __init__(self):
graph_teens,
graph_digits,
zero,
)
).optimize()

# combined grammar; output consists only of Arabic numerals
graph_just_cardinals = pynini.union(
Expand All @@ -354,29 +355,42 @@ def __init__(self):
graph_teens,
graph_digits,
zero,
)
).optimize()

# delete unnecessary leading zeros
delete_leading_zeros = pynutil.delete(pynini.closure("0"))
stop_at_non_zero = pynini.difference(NEMO_DIGIT, "0")
rest_of_cardinal = pynini.closure(NEMO_DIGIT) | pynini.closure(NEMO_SIGMA)
rest_of_cardinal = (pynini.closure(NEMO_DIGIT) + pynini.closure(NEMO_CHAR, 1)) | (
pynini.closure(NEMO_DIGIT)
) # general use cases for other graphs
rest_of_cardinal_2 = (pynini.closure(NEMO_DIGIT) + pynini.closure(NEMO_CHAR, 1)) | (
pynini.closure(NEMO_DIGIT, 2)
) # for normal cardinal graph

# output for cardinal grammar without leading zero
clean_cardinal = delete_leading_zeros + stop_at_non_zero + rest_of_cardinal
clean_cardinal = clean_cardinal | "0"
graph = graph @ clean_cardinal # output for regular cardinals
self.for_ordinals = graph # used for ordinal grammars
clean_cardinal_2 = delete_leading_zeros + stop_at_non_zero + rest_of_cardinal_2
clean_just_cardinal = delete_leading_zeros + stop_at_non_zero + rest_of_cardinal

# Union zero with the graph so that the free-floating zero inserts do not overproduce 0s.
# TODO: Rewrite digits graphs so that we don't have free floating zero inserts.
self.for_ordinals = (graph | zero) @ clean_cardinal # used for ordinal grammars
self.for_ordinals = self.for_ordinals.optimize()

# output for pure Arabic numbers without leading zeros
clean_just_cardinal = delete_leading_zeros + stop_at_non_zero + rest_of_cardinal
clean_just_cardinal = clean_just_cardinal | "0"
graph_just_cardinals = graph_just_cardinals @ clean_just_cardinal # output for other grammars
self.just_cardinals = graph_just_cardinals # used for other grammars
self.just_cardinals = graph_just_cardinals | zero # used for other grammars
self.just_cardinals = self.just_cardinals.optimize()

# final grammar for cardinal output; tokenization
optional_minus_graph = (pynini.closure(pynutil.insert("negative: ") + pynini.cross("负", '"-"'))) | (
pynini.closure(pynutil.insert("negative: ") + pynini.cross("負", '"-"'))
)
final_graph = optional_minus_graph + pynutil.insert('integer: "') + graph + pynutil.insert('"')
final_graph = (
optional_minus_graph
+ pynutil.insert('integer: "')
+ ((graph | zero) @ clean_cardinal_2)
+ pynutil.insert('"')
)
final_graph = self.add_tokens(final_graph)
self.fst = final_graph
self.fst = final_graph.optimize()
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
)

graph = token_plus_punct + pynini.closure(delete_zero_or_one_space + token_plus_punct)
graph = token_plus_punct + pynini.closure(delete_space + token_plus_punct)
graph = delete_space + graph + delete_space

self.fst = graph.optimize()
Expand Down