From c895e84f40e0bbb049302dd30b8ce3c4b108db79 Mon Sep 17 00:00:00 2001 From: Yury Luneff Date: Sun, 24 Nov 2024 16:30:28 +0300 Subject: [PATCH 1/2] Pass given device to the pipeline --- runorm/runorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runorm/runorm.py b/runorm/runorm.py index fe46918..cb90c74 100644 --- a/runorm/runorm.py +++ b/runorm/runorm.py @@ -44,7 +44,7 @@ def load(self, model_size="small", device="cpu", workdir=None): self.angl_model = T5ForConditionalGeneration.from_pretrained(self.paths["kirillizator"], cache_dir=self.workdir) self.tagger_model = BertForTokenClassification.from_pretrained(self.paths["tagger"], cache_dir=self.workdir) self.tagger_tokenizer = AutoTokenizer.from_pretrained(self.paths["tagger"], cache_dir=self.workdir) - self.tagger = pipeline("ner", model=self.tagger_model, tokenizer=self.tagger_tokenizer, aggregation_strategy="average") + self.tagger = pipeline("ner", model=self.tagger_model, tokenizer=self.tagger_tokenizer, aggregation_strategy="average", device=device) self.abbr_model.to(device) self.angl_model.to(device) self.abbr_model.eval() @@ -292,4 +292,4 @@ def norm(self, message): out = out + " " + final_answer #elapsed_time = time.time() - start - return out.strip() \ No newline at end of file + return out.strip() From 95676a0a7d973ea115f5e5eba66c65fbc4c3d8a9 Mon Sep 17 00:00:00 2001 From: Yury Luneff Date: Sun, 24 Nov 2024 16:35:36 +0300 Subject: [PATCH 2/2] fix re warnings --- runorm/runorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runorm/runorm.py b/runorm/runorm.py index cb90c74..5a27f5a 100644 --- a/runorm/runorm.py +++ b/runorm/runorm.py @@ -21,7 +21,7 @@ def __init__(self): self.rule_normalizer = RuleNormalizer() self.numbers_normalizer = Numbers2Words() self.re_tokens = re.compile(r"(?:[.,!?]|[а-яА-Я]\S*|-?\d\S*(?:\.\d+)?|[^а-яА-Я\d\s-]+)\s*") - self.re_normalization = re.compile(r"[^a-zA-Z0-9\sа-яА-ЯёЁ.,!?:;""''(){}\[\]«»„“”-]") + self.re_normalization = re.compile(r"[^a-zA-Z0-9\sа-яА-ЯёЁ.,!?:;""''(){}[]«»„“”-]") self.paths = { "tagger": "RUNorm/RUNorm-tagger", "kirillizator": "RUNorm/RUNorm-kirillizator", @@ -107,7 +107,7 @@ def construct_prompt(self, text, angl_mode=False): etid = 0 token_to_add = "" for token in self.process_sentence(text) + [""]: - if not re.search("[a-zA-Z\d]", token): + if not re.search(r"[a-zA-Z\d]", token): if token_to_add: end_match = re.search(r"(.+?)(\W*)$", token_to_add, re.M).groups() if self.is_english(end_match[0].strip()):