Skip to content

Commit 06cd767

Browse files
committed
implemented correct smart trimming
1 parent 064a14a commit 06cd767

File tree

1 file changed

+66
-33
lines changed

1 file changed

+66
-33
lines changed

conversion2025/mathpix_to_llm_with_lines_to_api.ipynb

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@
400400
"class InlineMath(Markdown):\n",
401401
" def __init__(self, content):\n",
402402
" super().__init__(content)\n",
403+
" self.delimiter_size = 2\n",
403404
"\n",
404405
" def __str__(self):\n",
405406
" return f\"InlineMath({self.content!r})\"\n",
@@ -411,6 +412,7 @@
411412
"class DisplayMath(Markdown): \n",
412413
" def __init__(self, content):\n",
413414
" super().__init__(content)\n",
415+
" self.delimiter_size = 4\n",
414416
"\n",
415417
" def __str__(self):\n",
416418
" return f\"DisplayMath({self.content!r})\"\n",
@@ -1231,25 +1233,49 @@
12311233
"# but overall the position should be fairly accurate.\n",
12321234
"\n",
12331235
"def improve_trim(text: str, start: int, end: int) -> str:\n",
1234-
" markdown_classes = convert_markdown_to_classes_by_lines(text)\n",
1235-
" index = 0\n",
1236+
" markdown_classes = convert_markdown_to_classes(text)\n",
1237+
" # print(markdown_classes)\n",
1238+
" text_index = 0\n",
1239+
" class_index = 0\n",
1240+
" improved_start = -1\n",
1241+
" improved_end = -1\n",
12361242
"\n",
1237-
" for i in range(len(markdown_classes)):\n",
1238-
" structure = markdown_classes[i]\n",
1243+
" while class_index < len(markdown_classes):\n",
1244+
" structure = markdown_classes[class_index]\n",
12391245
"\n",
12401246
" match structure:\n",
12411247
" case RegularText():\n",
1242-
" if len(structure.content) + index < start:\n",
1243-
" # start is not in this structure\n",
1244-
" index += len(structure.content) + 1\n",
1245-
" continue\n",
1246-
" else:\n",
1247-
" # start is in this structure\n",
1248-
" structure_length = len(structure.content)\n",
1249-
" structure.content = structure.content[start - index:]\n",
1250-
" index += structure_length + 1\n",
1251-
" continue\n",
1252-
" return \"\"\n",
1248+
" structure_length = len(structure.content)\n",
1249+
" temp_improved_start = 0\n",
1250+
" temp_improved_end = len(structure.content)\n",
1251+
" if text_index <= start and structure_length + text_index > start:\n",
1252+
" improved_start = class_index\n",
1253+
" temp_improved_start = start - text_index\n",
1254+
"\n",
1255+
" if text_index <= end and structure_length + text_index >= end:\n",
1256+
" improved_end = class_index\n",
1257+
" temp_improved_end = end - text_index\n",
1258+
" \n",
1259+
" structure.content = structure.content[temp_improved_start:temp_improved_end]\n",
1260+
" text_index += structure_length\n",
1261+
"\n",
1262+
" case InlineMath() | DisplayMath():\n",
1263+
" structure_length = len(structure.content) + structure.delimiter_size\n",
1264+
" if text_index <= start and structure_length + text_index > start:\n",
1265+
" improved_start = class_index\n",
1266+
" if text_index <= end and structure_length + text_index >= end:\n",
1267+
" improved_end = class_index\n",
1268+
" text_index += structure_length\n",
1269+
" \n",
1270+
" class_index += 1\n",
1271+
"\n",
1272+
"\n",
1273+
" ret = markdown_classes[improved_start:improved_end + 1]\n",
1274+
" # print(ret)\n",
1275+
" # print(len(text), start, end)\n",
1276+
" # print(improved_start, improved_end)\n",
1277+
"\n",
1278+
" return convert_classes_to_markdown(ret)\n",
12531279
"\n"
12541280
]
12551281
},
@@ -1337,9 +1363,12 @@
13371363
" \"\"\"\n",
13381364
"\n",
13391365
"\n",
1340-
"def trim_question(question: Set_Question_With_Solution) -> Set_Question_With_Solution:\n",
1366+
"def trim_question(question: tuple[int, Set_Question_With_Solution]) -> Set_Question_With_Solution:\n",
1367+
" question_number, question = question\n",
1368+
" question_number += 1\n",
13411369
"\n",
13421370
" def trim_question_content(content_text: str) -> str:\n",
1371+
"\n",
13431372
" if content_text == \"\":\n",
13441373
" return content_text\n",
13451374
"\n",
@@ -1368,20 +1397,22 @@
13681397
" try:\n",
13691398
" parsed_output = content_parser.parse(response.content)\n",
13701399
" start = content_text.index(parsed_output.start)\n",
1371-
" end = content_text.index(parsed_output.end)\n",
1372-
" print(\"Successfully trimmed the stem.\")\n",
1400+
" end = content_text.index(parsed_output.end) + len(parsed_output.end)\n",
1401+
" print(f\"Successfully trimmed the stem of question {question_number}.\")\n",
13731402
"\n",
1374-
" return content_text[start:end + len(parsed_output.end) + 1].strip()\n",
1403+
" return improve_trim(content_text, start, end)\n",
13751404
" except Exception as e:\n",
1376-
" print(f\"Error parsing LLM response as JSON for trimming content:\")\n",
1405+
" print(f\"Error parsing LLM response as JSON for trimming content of question {question_number}:\")\n",
13771406
" print(f\"Retrying... Attempt No.{attempt_idx + 1}\")\n",
13781407
" time.sleep(2)\n",
13791408
" else:\n",
13801409
" print(\"Final LLM Response:\")\n",
13811410
" print(response.content)\n",
13821411
" raise Exception(\"Failed to parse LLM response as JSON after multiple attempts for trimming content.\")\n",
13831412
"\n",
1384-
" def trim_question_part(part_text: str) -> str:\n",
1413+
" def trim_question_part(part: tuple[int, str]) -> str:\n",
1414+
" part_number, part_text = part\n",
1415+
" part_number += 1\n",
13851416
" if part_text == \"\":\n",
13861417
" return part_text\n",
13871418
" \n",
@@ -1410,20 +1441,22 @@
14101441
" try:\n",
14111442
" parsed_output = part_parser.parse(response.content)\n",
14121443
" start = part_text.index(parsed_output.start)\n",
1413-
" end = part_text.index(parsed_output.end)\n",
1414-
" print(\"Successfully trimmed part\")\n",
1444+
" end = part_text.index(parsed_output.end) + len(parsed_output.end)\n",
1445+
" print(f\"Successfully trimmed part of question {question_number}, part {part_number}.\")\n",
14151446
"\n",
1416-
" return part_text[start:end + len(parsed_output.end) + 1].strip()\n",
1447+
" return improve_trim(part_text, start, end)\n",
14171448
" except Exception as e:\n",
1418-
" print(f\"Error parsing LLM response as JSON for trimming part:\")\n",
1449+
" print(f\"Error parsing LLM response as JSON for trimming part for question {question_number}, part {part_number}\")\n",
14191450
" print(f\"Retrying... Attempt No.{attempt_idx + 1}\")\n",
14201451
" time.sleep(2)\n",
14211452
" else:\n",
14221453
" print(\"Final LLM Response:\")\n",
14231454
" print(response.content)\n",
14241455
" raise Exception(\"Failed to parse LLM response as JSON after multiple attempts for trimming part.\")\n",
14251456
"\n",
1426-
" def trim_question_part_solution(solution_text: str) -> str:\n",
1457+
" def trim_question_part_solution(solution: tuple[int, str]) -> str:\n",
1458+
" part_number, solution_text = solution\n",
1459+
" part_number += 1\n",
14271460
" if solution_text == \"\":\n",
14281461
" return solution_text\n",
14291462
" \n",
@@ -1452,12 +1485,12 @@
14521485
" try:\n",
14531486
" parsed_output = solution_parser.parse(response.content)\n",
14541487
" start = solution_text.index(parsed_output.start)\n",
1455-
" end = solution_text.index(parsed_output.end)\n",
1456-
" print(\"Successfully trimmed part-solution.\")\n",
1488+
" end = solution_text.index(parsed_output.end) + len(parsed_output.end)\n",
1489+
" print(f\"Successfully trimmed part-solution for question {question_number}, part {part_number}.\")\n",
14571490
"\n",
1458-
" return solution_text[start:end + len(parsed_output.end) + 1].strip()\n",
1491+
" return improve_trim(solution_text, start, end)\n",
14591492
" except Exception as e:\n",
1460-
" print(f\"Error parsing LLM response as JSON for trimming solution part:\")\n",
1493+
" print(f\"Error parsing LLM response as JSON for trimming solution part for question {question_number}, part {part_number}\")\n",
14611494
" print(f\"Retrying... Attempt No.{attempt_idx + 1}\")\n",
14621495
" time.sleep(2)\n",
14631496
"\n",
@@ -1469,15 +1502,15 @@
14691502
" question.content = trim_question_content(question.content)\n",
14701503
"\n",
14711504
" with concurrent.futures.ThreadPoolExecutor() as executor:\n",
1472-
" question.parts = list(executor.map(trim_question_part, question.parts))\n",
1473-
" question.parts_solutions = list(executor.map(trim_question_part_solution, question.parts_solutions))\n",
1505+
" question.parts = list(executor.map(trim_question_part, enumerate(question.parts)))\n",
1506+
" question.parts_solutions = list(executor.map(trim_question_part_solution, enumerate(question.parts_solutions)))\n",
14741507
"\n",
14751508
" return question\n",
14761509
"\n",
14771510
"def trim_text(set_questions: Set_Lines) -> Set_Lines:\n",
14781511
"\n",
14791512
" with concurrent.futures.ThreadPoolExecutor() as executor:\n",
1480-
" set_questions.questions = list(executor.map(trim_question, set_questions.questions))\n",
1513+
" set_questions.questions = list(executor.map(trim_question, enumerate(set_questions.questions)))\n",
14811514
"\n",
14821515
" return set_questions\n"
14831516
]

0 commit comments

Comments
 (0)