diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 72af194..c471bb0 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,6 +18,8 @@ jobs:
steps:
- uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch full history for setuptools-scm
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
@@ -58,6 +60,8 @@ jobs:
steps:
- uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch full history for setuptools-scm
- name: Install build dependencies
run: |
@@ -86,6 +90,8 @@ jobs:
steps:
- uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch full history for setuptools-scm
- name: Set up Python
uses: actions/setup-python@v5
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 3df56ac..babc49c 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -22,6 +22,8 @@ jobs:
steps:
- uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch full history for setuptools-scm
- name: Install build dependencies
run: |
@@ -60,6 +62,8 @@ jobs:
steps:
- uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch full history for setuptools-scm
- name: Create bin directory
run: mkdir -p dalla_data_processing/deduplication/bin
@@ -111,6 +115,8 @@ jobs:
steps:
- uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch full history for setuptools-scm
- name: Set up Python
uses: actions/setup-python@v5
@@ -167,6 +173,8 @@ jobs:
steps:
- uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch full history for setuptools-scm
- uses: actions/download-artifact@v4
with:
diff --git a/.gitignore b/.gitignore
index 5c952c5..54b501a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,9 @@ wheels/
.installed.cfg
*.egg
+# setuptools-scm version file
+dalla_data_processing/_version.py
+
# Virtual environments
venv/
env/
diff --git a/README.md b/README.md
index 030cb6e..41f91a9 100644
--- a/README.md
+++ b/README.md
@@ -10,49 +10,90 @@ A comprehensive Arabic data processing pipeline with deduplication, stemming, qu
## Installation
+### Quick Start (All Features)
+
+For most users, install with all features enabled:
+
Using uv
```bash
-# Install the package
-uv pip install dalla-data-processing
+uv pip install "dalla-data-processing[all]"
```
-
Using pip
```bash
-# Install the package
+pip install "dalla-data-processing[all]"
+```
+
+### Modular Installation (Advanced)
+
+Install only the components you need to keep dependencies minimal:
+
+```bash
+# Base installation (no processing features, only core dependencies)
pip install dalla-data-processing
+
+# Install specific features
+pip install "dalla-data-processing[dedup]" # Deduplication only
+pip install "dalla-data-processing[stem]" # Stemming only
+pip install "dalla-data-processing[quality]" # Quality checking only
+pip install "dalla-data-processing[readability]" # Readability scoring only
+pip install "dalla-data-processing[pack]" # Dataset packing only
+
+# Combine multiple features
+pip install "dalla-data-processing[dedup,stem,quality]"
```
+### Development Installation
-From Source
+From Source (with uv - recommended)
```bash
git clone https://github.com/U4RASD/dalla-data-processing.git
cd dalla-data-processing
-# Using uv
-uv pip install -e .
+# Install all features and dev dependencies
+uv sync --all-extras
-# Or using pip
-pip install -e .
+# Or install with specific extras only
+uv sync --extra dedup --extra stem
+```
+
+From Source (with pip)
+
+```bash
+git clone https://github.com/U4RASD/dalla-data-processing.git
+cd dalla-data-processing
+
+# Install with all features for development
+pip install -e ".[all,dev]"
```
## Components
+> **Note:** Each component requires its corresponding extra to be installed. Install with `[all]` to enable all features, or see [Modular Installation](#modular-installation-advanced) to install only what you need.
+
### 1. [Deduplication](dalla_data_processing/deduplication/README.md)
Detect and remove duplicate or near-duplicate documents from your datasets using the Onion algorithm.
+- **Requires:** `[dedup]` extra
### 2. [Stemming](dalla_data_processing/stemming/README.md)
Apply morphological analysis and stemming using CAMeL Tools.
+- **Requires:** `[stem]` extra
### 3. [Quality Checking](dalla_data_processing/quality/README.md)
Check text quality using morphological analysis to detect errors and foreign words.
+- **Requires:** `[quality]` extra
### 4. [Readability Scoring](dalla_data_processing/readability/README.md)
Calculate readability scores using Flesch Reading Ease and Osman methods.
Also includes ranking according to both scores
+- **Requires:** `[readability]` extra
+
+### 5. [Dataset Packing](dalla_data_processing/packing/README.md)
+Pack and prepare datasets for training.
+- **Requires:** `[pack]` extra
## Links
diff --git a/dalla_data_processing/__init__.py b/dalla_data_processing/__init__.py
index bd9bb00..9231721 100644
--- a/dalla_data_processing/__init__.py
+++ b/dalla_data_processing/__init__.py
@@ -8,31 +8,49 @@
- Readability scoring
"""
-__version__ = "0.0.1"
-
try:
- from dalla_data_processing.core.dataset import DatasetManager
-
- _has_dataset = True
+ from dalla_data_processing._version import version as __version__
except ImportError:
- _has_dataset = False
- DatasetManager = None
+ # Fallback for development without installation
+ try:
+ from importlib.metadata import PackageNotFoundError, version
-try:
- from dalla_data_processing.utils.tokenize import simple_word_tokenize
+ __version__ = version("dalla-data-processing")
+ except PackageNotFoundError:
+ __version__ = "0.0.0+unknown"
- _has_tokenize = True
-except ImportError:
- _has_tokenize = False
- simple_word_tokenize = None
-try:
- from dalla_data_processing.stemming import stem, stem_dataset
+# Lazy imports - only import when actually used, not at package load time
+def __getattr__(name):
+ """Lazy load heavy modules only when accessed."""
+ if name == "DatasetManager":
+ from dalla_data_processing.core.dataset import DatasetManager
+
+ return DatasetManager
+ elif name == "simple_word_tokenize":
+ from dalla_data_processing.utils.tokenize import simple_word_tokenize
+
+ return simple_word_tokenize
+ elif name == "stem":
+ from dalla_data_processing.stemming import stem
+
+ return stem
+ elif name == "stem_dataset":
+ from dalla_data_processing.stemming import stem_dataset
+
+ return stem_dataset
+ elif name == "DatasetPacker":
+ from dalla_data_processing.packing import DatasetPacker
+
+ return DatasetPacker
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
- _has_stemming = True
-except ImportError:
- _has_stemming = False
- stem = None
- stem_dataset = None
-__all__ = ["DatasetManager", "simple_word_tokenize", "stem", "stem_dataset", "__version__"]
+__all__ = [
+ "DatasetManager",
+ "simple_word_tokenize",
+ "stem",
+ "stem_dataset",
+ "DatasetPacker",
+ "__version__",
+]
diff --git a/dalla_data_processing/cli.py b/dalla_data_processing/cli.py
index 252b717..bf02f7e 100644
--- a/dalla_data_processing/cli.py
+++ b/dalla_data_processing/cli.py
@@ -1,5 +1,5 @@
"""
-Main CLI entry point for dalla-process.
+Main CLI entry point for dalla-data-processing (dalla-dp).
This module provides the unified command-line interface for all
Arabic data processing operations.
@@ -7,16 +7,18 @@
import sys
from pathlib import Path
+from typing import TYPE_CHECKING
import click
-from datasets import Dataset, DatasetDict
from dalla_data_processing import __version__
-from dalla_data_processing.core.dataset import DatasetManager
-from dalla_data_processing.utils import get_logger, setup_logging
+from dalla_data_processing.utils.logger import get_logger
-setup_logging(log_format="console", log_level="INFO")
-logger = get_logger(__name__)
+if TYPE_CHECKING:
+ from datasets import Dataset
+
+
+logger = get_logger("dalla.cli")
class Context:
@@ -30,16 +32,103 @@ def __init__(self):
self.verbose: bool = False
self.overwrite: bool = False
self.dataset: Dataset | None = None
- self.dataset_manager = DatasetManager()
+ self._dataset_manager = None
+
+ @property
+ def dataset_manager(self):
+ """Lazy load DatasetManager."""
+ if self._dataset_manager is None:
+ from dalla_data_processing.core.dataset import DatasetManager
+
+ self._dataset_manager = DatasetManager()
+ return self._dataset_manager
pass_context = click.make_pass_decorator(Context, ensure=True)
-@click.group(
- context_settings={"help_option_names": ["-h", "--help"], "allow_interspersed_args": True}
-)
-@click.version_option(version=__version__, prog_name="dalla-process")
+def common_options(func):
+ """Decorator to add common dataset processing options to commands.
+
+ This allows options to be specified either before or after the subcommand,
+ and ensures they appear in each subcommand's help text.
+ """
+ decorators = [
+ click.option(
+ "--input-dataset",
+ "-i",
+ type=click.Path(exists=True, path_type=Path),
+ help="Path to input HuggingFace dataset",
+ ),
+ click.option(
+ "--output-dataset",
+ "-o",
+ type=click.Path(path_type=Path),
+ help="Path to save output HuggingFace dataset",
+ ),
+ click.option(
+ "--column",
+ "-c",
+ default="text",
+ help="Column name to process (default: 'text')",
+ ),
+ click.option(
+ "--num-workers",
+ "-w",
+ type=int,
+ help="Number of parallel workers (default: auto)",
+ ),
+ click.option(
+ "--verbose",
+ "-v",
+ is_flag=True,
+ help="Enable verbose output",
+ ),
+ click.option(
+ "--quiet",
+ "-q",
+ is_flag=True,
+ help="Suppress non-error output",
+ ),
+ click.option(
+ "--overwrite",
+ is_flag=True,
+ help="Overwrite output dataset if it already exists",
+ ),
+ ]
+ for decorator in reversed(decorators):
+ func = decorator(func)
+ return func
+
+
+def _setup_context_and_logging(
+ ctx: Context,
+ input_dataset: Path | None,
+ output_dataset: Path | None,
+ column: str,
+ num_workers: int | None,
+ verbose: bool,
+ quiet: bool,
+ overwrite: bool,
+):
+ """Helper to populate context from command-level parameters and setup logging."""
+ ctx.input_dataset = input_dataset or ctx.input_dataset
+ ctx.output_dataset = output_dataset or ctx.output_dataset
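+ # Prefer the subcommand's column when it differs from the default; otherwise keep the group-level value (or the default)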
+ ctx.column = column if column != "text" else ctx.column or column
+ ctx.num_workers = num_workers or ctx.num_workers
+ ctx.verbose = verbose or ctx.verbose
+ ctx.overwrite = overwrite or ctx.overwrite
+
+ from dalla_data_processing.utils import setup_logging
+
+ if quiet:
+ setup_logging(log_format="console", log_level="ERROR")
+ elif ctx.verbose:
+ setup_logging(log_format="console", log_level="DEBUG")
+
+
+@click.group(context_settings={"help_option_names": ["-h", "--help"]})
+@click.version_option(version=__version__, prog_name="dalla-data-processing")
@click.option(
"--input-dataset",
"-i",
@@ -95,22 +184,12 @@ def cli(
"""
Dalla Data Processing - Unified Arabic Data Processing Pipeline
- A comprehensive toolkit for processing Arabic text data with support for:
- Deduplication using onion algorithm
- Stemming and morphological analysis
- Quality checking
- Readability scoring
+ - Dataset packing for training
- Examples:
-
- # Deduplicate a dataset
- dalla-dp -i ./data/raw -o ./data/deduped deduplicate
-
- # Stem text with 8 workers
- dalla-dp -i ./data/raw -o ./data/stemmed -w 8 stem
-
- # Check quality with custom column
- dalla-dp -i ./data/raw -o ./data/quality -c content quality-check
"""
ctx.input_dataset = input_dataset
ctx.output_dataset = output_dataset
@@ -119,13 +198,17 @@ def cli(
ctx.verbose = verbose
ctx.overwrite = overwrite
- if quiet:
- setup_logging(log_format="console", log_level="ERROR")
- elif verbose:
- setup_logging(log_format="console", log_level="DEBUG")
+ if quiet or verbose:
+ from dalla_data_processing.utils import setup_logging
+
+ if quiet:
+ setup_logging(log_format="console", log_level="ERROR")
+ elif verbose:
+ setup_logging(log_format="console", log_level="DEBUG")
@cli.command(context_settings={"help_option_names": ["-h", "--help"]})
+@common_options
@click.option(
"--threshold",
"-t",
@@ -161,6 +244,13 @@ def cli(
@pass_context
def deduplicate(
ctx: Context,
+ input_dataset: Path | None,
+ output_dataset: Path | None,
+ column: str,
+ num_workers: int | None,
+ verbose: bool,
+ quiet: bool,
+ overwrite: bool,
threshold: float,
return_pairs: bool,
keep_vert_files: bool,
@@ -169,20 +259,28 @@ def deduplicate(
onion_binary: str | None,
):
"""Remove duplicate entries using onion algorithm."""
+ _setup_context_and_logging(
+ ctx, input_dataset, output_dataset, column, num_workers, verbose, quiet, overwrite
+ )
_require_io_paths(ctx)
- click.echo(f"Loading dataset from {ctx.input_dataset}")
+ logger.info(f"Loading dataset from {ctx.input_dataset}")
dataset = ctx.dataset_manager.load(ctx.input_dataset)
dataset = _handle_dataset_dict(dataset)
mode = "pairs" if return_pairs else "filter"
- click.echo(f"Deduplicating with threshold={threshold}, mode={mode}")
+ logger.info(f"Deduplicating with threshold={threshold}, mode={mode}")
if calculate_scores:
- click.echo(" Phase 2: ON (calculating similarity scores)")
+ logger.info(" Phase 2: ON (calculating similarity scores)")
else:
- click.echo(" Phase 2: OFF (faster, sufficient for most use cases)")
+ logger.info(" Phase 2: OFF (faster, sufficient for most use cases)")
- from dalla_data_processing.deduplication import deduplicate_dataset
+ try:
+ from dalla_data_processing.deduplication import deduplicate_dataset
+ except ImportError:
+ logger.error("Missing dependencies for deduplication")
+ logger.error("Install with: pip install 'dalla-data-processing[dedup]'")
+ sys.exit(1)
deduplicated = deduplicate_dataset(
dataset,
@@ -195,28 +293,31 @@ def deduplicate(
onion_binary=Path(onion_binary) if onion_binary else None,
)
- click.echo(f"Saving deduplicated dataset to {ctx.output_dataset}")
+ logger.info(f"Saving deduplicated dataset to {ctx.output_dataset}")
ctx.dataset_manager.save(deduplicated, ctx.output_dataset, overwrite=ctx.overwrite)
+ from dalla_data_processing.core.dataset import DatasetManager
+
original_size = DatasetManager.get_size(dataset)
final_size = DatasetManager.get_size(deduplicated)
- click.echo(click.style("✓ Deduplication complete", fg="green"))
- click.echo(f" Original: {original_size:,} examples")
+ logger.info("✓ Deduplication complete")
+ logger.info(f" Original: {original_size:,} examples")
if return_pairs:
num_dups = sum(1 for ex in deduplicated if ex.get("is_duplicate", False))
- click.echo(
+ logger.info(
f" Documents with duplicates: {num_dups:,} ({num_dups / original_size * 100:.1f}%)"
)
- click.echo(" Added columns: duplicate_cluster, is_duplicate, duplicate_count")
+ logger.info(" Added columns: duplicate_cluster, is_duplicate, duplicate_count")
else:
removed = original_size - final_size
- click.echo(f" Removed: {removed:,} duplicates ({removed / original_size * 100:.1f}%)")
- click.echo(f" Final: {final_size:,} examples")
+ logger.info(f" Removed: {removed:,} duplicates ({removed / original_size * 100:.1f}%)")
+ logger.info(f" Final: {final_size:,} examples")
@cli.command(context_settings={"help_option_names": ["-h", "--help"]})
+@common_options
@click.option(
"--sep-token",
default="<+>",
@@ -245,19 +346,39 @@ def deduplicate(
)
@pass_context
def stem(
- ctx: Context, sep_token: str, normalize: bool, keep_diacritics: bool, model: str, use_gpu: bool
+ ctx: Context,
+ input_dataset: Path | None,
+ output_dataset: Path | None,
+ column: str,
+ num_workers: int | None,
+ verbose: bool,
+ quiet: bool,
+ overwrite: bool,
+ sep_token: str,
+ normalize: bool,
+ keep_diacritics: bool,
+ model: str,
+ use_gpu: bool,
):
"""Apply stemming and morphological analysis."""
+ _setup_context_and_logging(
+ ctx, input_dataset, output_dataset, column, num_workers, verbose, quiet, overwrite
+ )
_require_io_paths(ctx)
- click.echo(f"Loading dataset from {ctx.input_dataset}")
+ logger.info(f"Loading dataset from {ctx.input_dataset}")
dataset = ctx.dataset_manager.load(ctx.input_dataset)
dataset = _handle_dataset_dict(dataset)
- click.echo(f"Stemming {ctx.column} column (workers={ctx.num_workers or 'auto'})")
- click.echo(f"Model: {model.upper()}{' (GPU enabled)' if model == 'bert' and use_gpu else ''}")
+ logger.info(f"Stemming {ctx.column} column (workers={ctx.num_workers or 'auto'})")
+ logger.info(f"Model: {model.upper()}{' (GPU enabled)' if model == 'bert' and use_gpu else ''}")
- from dalla_data_processing.stemming import stem_dataset
+ try:
+ from dalla_data_processing.stemming import stem_dataset
+ except ImportError:
+ logger.error("Missing dependencies for stemming")
+ logger.error("Install with: pip install 'dalla-data-processing[stem]'")
+ sys.exit(1)
stemmed = stem_dataset(
dataset,
@@ -270,13 +391,14 @@ def stem(
use_gpu=use_gpu,
)
- click.echo(f"Saving stemmed dataset to {ctx.output_dataset}")
+ logger.info(f"Saving stemmed dataset to {ctx.output_dataset}")
ctx.dataset_manager.save(stemmed, ctx.output_dataset, overwrite=ctx.overwrite)
- click.echo(click.style("✓ Stemming complete", fg="green"))
+ logger.info("✓ Stemming complete")
@cli.command(context_settings={"help_option_names": ["-h", "--help"]})
+@common_options
@click.option(
"--min-score",
type=float,
@@ -300,18 +422,39 @@ def stem(
help="Use GPU for BERT model (only applicable when --model=bert)",
)
@pass_context
-def quality_check(ctx: Context, min_score: float, save_errors: bool, model: str, use_gpu: bool):
+def quality_check(
+ ctx: Context,
+ input_dataset: Path | None,
+ output_dataset: Path | None,
+ column: str,
+ num_workers: int | None,
+ verbose: bool,
+ quiet: bool,
+ overwrite: bool,
+ min_score: float,
+ save_errors: bool,
+ model: str,
+ use_gpu: bool,
+):
"""Check text quality and calculate scores."""
+ _setup_context_and_logging(
+ ctx, input_dataset, output_dataset, column, num_workers, verbose, quiet, overwrite
+ )
_require_io_paths(ctx)
- click.echo(f"Loading dataset from {ctx.input_dataset}")
+ logger.info(f"Loading dataset from {ctx.input_dataset}")
dataset = ctx.dataset_manager.load(ctx.input_dataset)
dataset = _handle_dataset_dict(dataset)
- click.echo(f"Checking quality of {ctx.column} column")
- click.echo(f"Model: {model.upper()}{' (GPU enabled)' if model == 'bert' and use_gpu else ''}")
+ logger.info(f"Checking quality of {ctx.column} column")
+ logger.info(f"Model: {model.upper()}{' (GPU enabled)' if model == 'bert' and use_gpu else ''}")
- from dalla_data_processing.quality import check_quality
+ try:
+ from dalla_data_processing.quality import check_quality
+ except ImportError:
+ logger.error("Missing dependencies for quality checking")
+ logger.error("Install with: pip install 'dalla-data-processing[quality]'")
+ sys.exit(1)
scored = check_quality(
dataset,
@@ -323,40 +466,61 @@ def quality_check(ctx: Context, min_score: float, save_errors: bool, model: str,
use_gpu=use_gpu,
)
- click.echo(f"Saving quality-checked dataset to {ctx.output_dataset}")
+ logger.info(f"Saving quality-checked dataset to {ctx.output_dataset}")
ctx.dataset_manager.save(scored, ctx.output_dataset, overwrite=ctx.overwrite)
+ from dalla_data_processing.core.dataset import DatasetManager
+
original_size = DatasetManager.get_size(dataset)
final_size = DatasetManager.get_size(scored)
- click.echo(click.style("✓ Quality check complete", fg="green"))
+ logger.info("✓ Quality check complete")
if min_score > 0:
removed = original_size - final_size
- click.echo(
+ logger.info(
f" Filtered {removed:,} low-quality examples ({removed / original_size * 100:.1f}%)"
)
@cli.command(context_settings={"help_option_names": ["-h", "--help"]})
+@common_options
@click.option(
"--add-ranks/--no-ranks",
default=True,
help="Add ranking and level columns (default: True)",
)
@pass_context
-def readability(ctx: Context, add_ranks: bool):
+def readability(
+ ctx: Context,
+ input_dataset: Path | None,
+ output_dataset: Path | None,
+ column: str,
+ num_workers: int | None,
+ verbose: bool,
+ quiet: bool,
+ overwrite: bool,
+ add_ranks: bool,
+):
"""Calculate readability scores using Flesch and Osman methods."""
+ _setup_context_and_logging(
+ ctx, input_dataset, output_dataset, column, num_workers, verbose, quiet, overwrite
+ )
_require_io_paths(ctx)
- click.echo(f"Loading dataset from {ctx.input_dataset}")
+ logger.info(f"Loading dataset from {ctx.input_dataset}")
dataset = ctx.dataset_manager.load(ctx.input_dataset)
dataset = _handle_dataset_dict(dataset)
- click.echo(f"Calculating readability scores for {ctx.column} column")
+ logger.info(f"Calculating readability scores for {ctx.column} column")
if add_ranks:
- click.echo(" Including ranking and difficulty levels (0-4)")
+ logger.info(" Including ranking and difficulty levels (0-4)")
- from dalla_data_processing.readability import score_readability
+ try:
+ from dalla_data_processing.readability import score_readability
+ except ImportError:
+ logger.error("Missing dependencies for readability scoring")
+ logger.error("Install with: pip install 'dalla-data-processing[readability]'")
+ sys.exit(1)
scored = score_readability(
dataset,
@@ -365,16 +529,168 @@ def readability(ctx: Context, add_ranks: bool):
num_proc=ctx.num_workers,
)
- click.echo(f"Saving scored dataset to {ctx.output_dataset}")
+ logger.info(f"Saving scored dataset to {ctx.output_dataset}")
ctx.dataset_manager.save(scored, ctx.output_dataset, overwrite=ctx.overwrite)
- click.echo(click.style("✓ Readability scoring complete", fg="green"))
+ logger.info("✓ Readability scoring complete")
if add_ranks:
- click.echo(" Added columns: flesch_score, osman_score, flesch_rank, osman_rank,")
- click.echo(" readability_level")
+ logger.info(" Added columns: flesch_score, osman_score, flesch_rank, osman_rank,")
+ logger.info(" readability_level")
+ else:
+ logger.info(" Added columns: flesch_score, osman_score")
+
+
+@cli.command(context_settings={"help_option_names": ["-h", "--help"]})
+@common_options
+@click.option(
+ "--config",
+ type=click.Path(exists=True, path_type=Path),
+ help="Path to config YAML file (optional)",
+)
+@click.option(
+ "--tokenizer-path",
+ type=str,
+ help="Path to tokenizer",
+)
+@click.option(
+ "--max-seq-length",
+ type=int,
+ help="Maximum sequence length for packing",
+)
+@click.option(
+ "--chunk-size-gb",
+ type=float,
+ help="Size of each processing chunk in GB",
+)
+@click.option(
+ "--subset-order",
+ multiple=True,
+ help="Subset processing order (can be specified multiple times)",
+)
+@click.option(
+ "--sft",
+ is_flag=True,
+ help="Enable SFT mode (uses tokenizer's chat template)",
+)
+@click.option(
+ "--rbpe",
+ is_flag=True,
+ help="Use R-BPE tokenizer",
+)
+@click.option(
+ "--text-column",
+ type=str,
+ help="Column name containing text data (defaults: 'text' for non-SFT, 'messages' for SFT)",
+)
+@pass_context
+def pack(
+ ctx: Context,
+ input_dataset: Path | None,
+ output_dataset: Path | None,
+ column: str,
+ num_workers: int | None,
+ verbose: bool,
+ quiet: bool,
+ overwrite: bool,
+ config: Path | None,
+ tokenizer_path: str | None,
+ max_seq_length: int | None,
+ chunk_size_gb: float | None,
+ subset_order: tuple[str, ...],
+ sft: bool,
+ rbpe: bool,
+ text_column: str | None,
+):
+ """Pack datasets for efficient LLM training.
+
+ Combines multiple examples into fixed-length sequences, optimizing for
+ efficient training. Preserves data integrity by ensuring no example is
+ split across multiple sequences.
+ """
+ _setup_context_and_logging(
+ ctx, input_dataset, output_dataset, column, num_workers, verbose, quiet, overwrite
+ )
+ _require_io_paths(ctx)
+
+ try:
+ import yaml
+ except ImportError:
+ logger.error("Missing dependencies for packing")
+ logger.error("Install with: pip install 'dalla-data-processing[pack]'")
+ sys.exit(1)
+
+ config_data = {}
+ if config:
+ with open(config) as f:
+ config_data = yaml.safe_load(f) or {}
+
+ if tokenizer_path:
+ config_data["tokenizer_path"] = tokenizer_path
+ if rbpe:
+ config_data["rbpe"] = True
+ if sft:
+ config_data["sft"] = True
+ if max_seq_length is not None:
+ config_data["max_seq_length"] = max_seq_length
+ if chunk_size_gb is not None:
+ config_data["chunk_size_gb"] = chunk_size_gb
+ if subset_order:
+ config_data["subset_order"] = list(subset_order)
+ if text_column:
+ config_data["text_column"] = text_column
+
+ if "tokenizer_path" not in config_data:
+ logger.error("Error: --tokenizer-path is required")
+ sys.exit(1)
+
+ if config_data.get("rbpe"):
+ try:
+ from rbpe import RBPETokenizer
+
+ tokenizer = RBPETokenizer.from_pretrained(config_data["tokenizer_path"])
+ except ImportError:
+ logger.error("Missing rbpe package")
+ logger.error("Install with: pip install rbpe")
+ sys.exit(1)
else:
- click.echo(" Added columns: flesch_score, osman_score")
+ try:
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(config_data["tokenizer_path"])
+ except ImportError:
+ logger.error("Missing transformers package")
+ logger.error("Install with: pip install transformers")
+ sys.exit(1)
+
+ if "text_column" not in config_data:
+ config_data["text_column"] = "messages" if config_data.get("sft", False) else "text"
+
+ logger.info("Starting dataset processing")
+ logger.info(f"Input path: {ctx.input_dataset}")
+ logger.info(f"Output directory: {ctx.output_dataset}")
+ logger.info(f"Using RBPE tokenizer: {config_data.get('rbpe', False)}")
+ logger.info(f"Chunk size: {config_data.get('chunk_size_gb', 2)}GB")
+ logger.info(f"Max sequence length: {config_data.get('max_seq_length', 2048)}")
+
+ from dalla_data_processing.packing import DatasetPacker
+
+ processor = DatasetPacker(
+ input_dataset=str(ctx.input_dataset),
+ output_dataset=str(ctx.output_dataset),
+ tokenizer=tokenizer,
+ subset_order=config_data.get("subset_order"),
+ num_workers=ctx.num_workers or 4,
+ chunk_size_gb=config_data.get("chunk_size_gb", 2.0),
+ max_seq_length=config_data.get("max_seq_length", 2048),
+ sft=config_data.get("sft", False),
+ rbpe=config_data.get("rbpe", False),
+ text_column=config_data.get("text_column"),
+ )
+
+ final_path = processor.process()
+ logger.info("Processing completed successfully!")
+ logger.info(f"Final dataset saved to: {final_path}")
@cli.command(context_settings={"help_option_names": ["-h", "--help"]})
@@ -388,31 +704,34 @@ def readability(ctx: Context, add_ranks: bool):
)
def info(dataset_path: Path, split: str | None):
"""Display information about a dataset."""
+ from dalla_data_processing.core.dataset import DatasetManager
+
dm = DatasetManager()
try:
dataset = dm.load(dataset_path, split=split)
dm.print_info(dataset)
except Exception as e:
- click.echo(click.style(f"Error loading dataset: {e}", fg="red"), err=True)
+ logger.error(f"Error loading dataset: {e}")
sys.exit(1)
def _handle_dataset_dict(dataset, split_preference: str = "train"):
"""Handle DatasetDict by selecting appropriate split."""
+ from datasets import DatasetDict
if isinstance(dataset, DatasetDict):
splits = list(dataset.keys())
- click.echo(f"Dataset has multiple splits: {', '.join(splits)}")
+ logger.info(f"Dataset has multiple splits: {', '.join(splits)}")
if split_preference in dataset:
- click.echo(
+ logger.info(
f"Using '{split_preference}' split ({len(dataset[split_preference])} examples)"
)
return dataset[split_preference]
else:
first_split = splits[0]
- click.echo(f"Using '{first_split}' split ({len(dataset[first_split])} examples)")
+ logger.info(f"Using '{first_split}' split ({len(dataset[first_split])} examples)")
return dataset[first_split]
else:
return dataset
@@ -421,19 +740,13 @@ def _handle_dataset_dict(dataset, split_preference: str = "train"):
def _require_io_paths(ctx: Context):
"""Ensure input and output paths are provided."""
if ctx.input_dataset is None:
- click.echo(
- click.style("Error: --input-dataset is required", fg="red"),
- err=True,
- )
- click.echo("Use --help for usage information")
+ logger.error("Error: --input-dataset is required")
+ logger.info("Use --help for usage information")
sys.exit(1)
if ctx.output_dataset is None:
- click.echo(
- click.style("Error: --output-dataset is required", fg="red"),
- err=True,
- )
- click.echo("Use --help for usage information")
+ logger.error("Error: --output-dataset is required")
+ logger.info("Use --help for usage information")
sys.exit(1)
@@ -442,10 +755,10 @@ def main():
try:
cli(obj=Context())
except KeyboardInterrupt:
- click.echo("\n" + click.style("Interrupted by user", fg="yellow"))
+ logger.warning("Interrupted by user")
sys.exit(130)
except Exception as e:
- click.echo(click.style(f"Error: {e}", fg="red"), err=True)
+ logger.error(f"Error: {e}")
if "--verbose" in sys.argv or "-v" in sys.argv:
raise
sys.exit(1)
diff --git a/dalla_data_processing/core/dataset.py b/dalla_data_processing/core/dataset.py
index bf19068..84a4fe0 100644
--- a/dalla_data_processing/core/dataset.py
+++ b/dalla_data_processing/core/dataset.py
@@ -1,8 +1,5 @@
"""
Dataset I/O utilities for unified HuggingFace dataset handling.
-
-This module provides a consistent interface for loading, saving, and manipulating
-HuggingFace datasets across all dalla-process components.
"""
from collections.abc import Callable
diff --git a/dalla_data_processing/deduplication/README.md b/dalla_data_processing/deduplication/README.md
index 50fa4cb..18072d3 100644
--- a/dalla_data_processing/deduplication/README.md
+++ b/dalla_data_processing/deduplication/README.md
@@ -2,6 +2,15 @@
Detect and remove duplicate or near-duplicate documents from your datasets using the Onion algorithm.
+## Installation
+
+This feature requires the `[dedup]` extra:
+
+```bash
+pip install "dalla-data-processing[dedup]"
+# or install all features: pip install "dalla-data-processing[all]"
+```
+
## CLI Usage
**Command:** `dalla-dp deduplicate [OPTIONS]`
diff --git a/dalla_data_processing/packing/README.md b/dalla_data_processing/packing/README.md
new file mode 100644
index 0000000..a0fe008
--- /dev/null
+++ b/dalla_data_processing/packing/README.md
@@ -0,0 +1,244 @@
+# Dataset Packing
+
+Pack datasets efficiently for LLM training by combining multiple examples into fixed-length sequences. Pre-pack your datasets locally to avoid wasting expensive GPU time on servers.
+
+## Why Use This Tool?
+
+When training large language models, dataset packing is essential for efficient training. This tool allows you to:
+- **Save GPU time**: Pack datasets on your local machine before uploading to training servers
+- **Preserve data integrity**: Ensures no example is split across multiple packed sequences
+- **Handle large datasets**: Process datasets in chunks to manage memory efficiently
+- **Flexible tokenization**: Support for both standard text data and chat-formatted (SFT) data
+
+## Installation
+
+This feature requires the `[pack]` extra:
+
+```bash
+pip install "dalla-data-processing[pack]"
+# or install all features: pip install "dalla-data-processing[all]"
+```
+
+## CLI Usage
+
+The packing functionality is integrated into the unified `dalla-dp` CLI:
+
+### Basic Usage
+
+```bash
+dalla-dp -i input_dataset -o output_dataset pack --tokenizer-path /path/to/tokenizer
+```
+
+### Common Options (from main CLI)
+
+- `-i, --input-dataset`: Path to input HuggingFace dataset (required)
+- `-o, --output-dataset`: Path to save output dataset (required)
+- `-w, --num-workers`: Number of parallel workers for packing (default: 4)
+- `-v, --verbose`: Enable verbose output
+- `--overwrite`: Overwrite output if it exists
+
+### Pack-Specific Options
+
+- `--config`: Path to config YAML file (optional)
+- `--tokenizer-path`: Path to tokenizer (required unless set in the config file)
+- `--max-seq-length`: Maximum sequence length (default: 2048)
+- `--chunk-size-gb`: Chunk size in GB (default: 2.0)
+- `--text-column`: Text column name (default: 'text' or 'messages' for SFT)
+- `--subset-order`: Subset processing order
+- `--sft`: Enable SFT mode
+- `--rbpe`: Use R-BPE tokenizer
+
+### Examples
+
+**Basic packing:**
+```bash
+dalla-dp -i my_dataset -o packed_dataset pack --tokenizer-path /path/to/tokenizer
+```
+
+**Using a config file:**
+```bash
+dalla-dp -i my_dataset -o packed_dataset pack --config pack_config.yaml
+```
+
+**Override config with command line:**
+```bash
+dalla-dp -i my_dataset -o packed_dataset pack \
+ --config pack_config.yaml \
+ --max-seq-length 4096
+```
+
+**With custom sequence length and workers:**
+```bash
+dalla-dp -i my_dataset -o packed_dataset -w 8 pack \
+ --tokenizer-path /path/to/tokenizer \
+ --max-seq-length 4096
+```
+
+**SFT mode with subset order:**
+```bash
+dalla-dp -i my_dataset -o packed_dataset pack \
+ --tokenizer-path /path/to/tokenizer \
+ --sft \
+ --subset-order train --subset-order validation
+```
+
+**Using a custom text column:**
+```bash
+dalla-dp -i my_dataset -o packed_dataset pack \
+ --tokenizer-path /path/to/tokenizer \
+ --text-column content
+```
+
+**With verbose output:**
+```bash
+dalla-dp -i my_dataset -o packed_dataset -v pack \
+ --tokenizer-path /path/to/tokenizer \
+ --chunk-size-gb 1.0
+```
+
+## Configuration File
+
+```yaml
+tokenizer_path: "/path/to/tokenizer"
+max_seq_length: 2048
+chunk_size_gb: 2.0
+sft: false
+rbpe: false
+text_column: "content"
+
+subset_order:
+ - "train"
+ - "validation"
+```
+
+CLI arguments override config values.
+
+## Python API Usage
+
+You can also use the packing functionality directly in Python:
+
+```python
+from dalla_data_processing.packing import DatasetPacker
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
+
+packer = DatasetPacker(
+ input_dataset="my_dataset",
+ output_dataset="packed_dataset",
+ tokenizer=tokenizer,
+ num_workers=4,
+ max_seq_length=2048,
+ chunk_size_gb=2.0,
+ text_column="content",
+)
+
+final_path = packer.process()
+```
+
+### API Parameters
+
+- `input_dataset` (str): Path to input dataset
+- `output_dataset` (str): Path for output dataset
+- `tokenizer`: HuggingFace tokenizer instance
+- `subset_order` (list[str], optional): Order to process subsets
+- `num_workers` (int): Number of parallel packing processes (default: 4)
+- `chunk_size_gb` (float): Size of processing chunks in GB (default: 2.0)
+- `max_seq_length` (int): Maximum sequence length (default: 2048)
+- `sft` (bool): Enable SFT mode (default: False)
+- `rbpe` (bool): Use R-BPE tokenizer (default: False)
+- `text_column` (str, optional): Text column name
+
+## The Packing Method
+
+- Packs multiple examples into a single sequence up to `max_seq_length`
+- Adds EOS token between examples
+- Pads remaining space with PAD tokens
+- **Guarantees no example is cut in the middle** - if an example doesn't fit, it starts in the next sequence
+- Best for preserving data integrity
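+
+A minimal sketch of this rule (illustrative only; the actual implementation additionally handles label masking and the R-BPE/SFT token settings):
+
+```python
+def pack(token_lists, max_seq_length, eos_id, pad_id):
+    """Greedy packing: EOS between examples, PAD to fill, never split an example."""
+    packed, current = [], []
+    for ids in token_lists:
+        if len(ids) + 1 > max_seq_length:  # too long even on its own: skipped
+            continue
+        if len(current) + len(ids) + 1 > max_seq_length:  # would overflow: flush
+            packed.append(current + [pad_id] * (max_seq_length - len(current)))
+            current = []
+        current = current + ids + [eos_id]  # EOS separates examples
+    if current:
+        packed.append(current + [pad_id] * (max_seq_length - len(current)))
+    return packed
+```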
+
+## SFT Mode
+
+Enable SFT mode when working with chat-formatted data:
+
+```bash
+dalla-dp -i my_dataset -o packed_dataset pack \
+ --tokenizer-path /path/to/tokenizer \
+ --sft
+```
+
+**When to use SFT mode:**
+- Your dataset has a `messages` field with chat conversations
+- Your tokenizer has a chat template defined
+- You're doing supervised fine-tuning (SFT) on conversational data
+
+**When NOT to use SFT mode:**
+- Continued pre-training (CPT) on plain text
+- Your tokenizer doesn't have a chat template
+- Your dataset only has a `text` field
+
+**Input format for SFT mode:**
+```python
+{
+ "messages": [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there!"}
+ ]
+}
+```
+
+## Custom Text Column
+
+Defaults: `text` (non-SFT) or `messages` (SFT). Override with `--text-column`.
+
+## Dataset Format
+
+### Input
+Your dataset should be in Hugging Face datasets format:
+```
+my_dataset/
+ train/
+ data-00000-of-00001.arrow
+ dataset_info.json
+ state.json
+ validation/
+ data-00000-of-00001.arrow
+ dataset_info.json
+ state.json
+ dataset_dict.json
+```
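+
+If you need to produce this layout, saving a `DatasetDict` with `save_to_disk` creates it; a toy example (illustrative data and path):
+
+```python
+from datasets import Dataset, DatasetDict
+
+DatasetDict({
+    "train": Dataset.from_dict({"text": ["نص تجريبي أول", "نص تجريبي ثان"]}),
+    "validation": Dataset.from_dict({"text": ["نص للتحقق"]}),
+}).save_to_disk("my_dataset")
+```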
+
+### Output
+The packed dataset will be saved to:
+```
+packed_dataset/
+ final_dataset/
+ train/
+ data-00000-of-00001.arrow
+ dataset_info.json
+ state.json
+```
+
+Each example in the final dataset will have:
+- `input_ids`: Token IDs (length = `max_seq_length`)
+- `labels`: Same as `input_ids` (or masked for SFT)
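+
+A quick way to sanity-check the result (paths taken from the examples above):
+
+```python
+from datasets import load_from_disk
+
+packed = load_from_disk("packed_dataset/final_dataset")["train"]
+example = packed[0]
+print(len(example["input_ids"]), len(example["labels"]))  # both equal max_seq_length
+```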
+
+## Memory Considerations
+
+The `--chunk-size-gb` parameter controls memory usage:
+- **Smaller values** (0.5-1 GB): Lower memory, more chunks
+- **Larger values** (2-4 GB): Higher memory, fewer chunks
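+
+Each subset is split into roughly `ceil(subset_size_gb / chunk_size_gb)` pieces, so a 10 GB subset with `--chunk-size-gb 2.0` is processed as 5 chunks.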
+
+## How It Works
+
+1. **Analyze**: Calculate sizes of dataset subsets
+2. **Split**: Divide datasets into manageable chunks
+3. **Tokenize**: Convert text to token IDs using parallel processing
+4. **Pack**: Combine multiple examples into fixed-length sequences
+5. **Concatenate**: Merge all packed chunks into final dataset
+
+The tool automatically:
+- Preserves subset ordering
+- Removes intermediate files to save disk space
+- Handles empty subsets gracefully
+- Skips examples longer than `max_seq_length`
diff --git a/dalla_data_processing/packing/__init__.py b/dalla_data_processing/packing/__init__.py
new file mode 100644
index 0000000..678838c
--- /dev/null
+++ b/dalla_data_processing/packing/__init__.py
@@ -0,0 +1,10 @@
+"""
+Dataset packing module for efficient LLM training.
+
+This module provides functionality to pack datasets by combining multiple
+examples into fixed-length sequences, optimizing for efficient training.
+"""
+
+from dalla_data_processing.packing.dataset_packer import DatasetPacker
+
+__all__ = ["DatasetPacker"]
diff --git a/dalla_data_processing/packing/dataset_packer.py b/dalla_data_processing/packing/dataset_packer.py
new file mode 100644
index 0000000..92015a9
--- /dev/null
+++ b/dalla_data_processing/packing/dataset_packer.py
@@ -0,0 +1,449 @@
+import contextlib
+import math
+import os
+import shutil
+from multiprocessing import Pool, cpu_count
+
+from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
+from tqdm import tqdm
+
+from dalla_data_processing.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+def get_directory_size(path):
+ total_size = 0
+ for dirpath, _dirnames, filenames in os.walk(path):
+ for filename in filenames:
+ filepath = os.path.join(dirpath, filename)
+ # Skip if it's a symbolic link
+ if not os.path.islink(filepath):
+ with contextlib.suppress(OSError):
+ total_size += os.path.getsize(filepath)
+ return total_size
+
+
+def remove_path(path):
+ """Safely remove a file, symlink, or directory tree."""
+ try:
+ if os.path.islink(path) or os.path.isfile(path):
+ os.remove(path)
+ elif os.path.isdir(path):
+ shutil.rmtree(path)
+ except Exception as e:
+ logger.warning("Failed to remove path", path=path, error=str(e))
+
+
+class DatasetPacker:
+ def __init__(
+ self,
+ input_dataset,
+ output_dataset,
+ tokenizer,
+ subset_order=None,
+ num_workers=4,
+ chunk_size_gb=2,
+ max_seq_length=2048,
+ sft=False,
+ rbpe=False,
+ text_column=None,
+ ):
+ self.input_dataset = input_dataset
+ self.output_dataset = output_dataset
+ self.tokenizer = tokenizer
+ self.num_workers = num_workers
+ self.chunk_size_bytes = int(chunk_size_gb * 1024**3)
+ self.max_seq_length = max_seq_length
+ self.rbpe = rbpe
+ self.subset_order = subset_order
+ self.sft = sft
+ # Set text_column: use provided value, or default based on sft mode
+ if text_column:
+ self.text_column = text_column
+ else:
+ self.text_column = "messages" if sft else "text"
+ if self.rbpe:
+ self.parallel = False
+ if self.sft:
+ self.add_special_tokens = False
+ self.append_concat_token = True
+ else:
+ self.add_special_tokens = True
+ self.append_concat_token = False
+ else:
+ self.parallel = True
+ self.add_special_tokens = True
+ self.append_concat_token = True
+ os.makedirs(output_dataset, exist_ok=True)
+
+ def get_directory_sizes(self, base_path):
+ # Handle both directory and direct dataset paths
+ if os.path.isfile(os.path.join(base_path, "dataset_info.json")):
+ # Direct dataset path
+ size_bytes = 0
+ for root, _dirs, files in os.walk(base_path):
+ for file in files:
+ file_path = os.path.join(root, file)
+ size_bytes += os.path.getsize(file_path)
+ return {"dataset": size_bytes}
+ else:
+ sizes = {}
+ dirs_to_check = self.subset_order if self.subset_order else os.listdir(base_path)
+ for dir_name in dirs_to_check:
+ path = os.path.join(base_path, dir_name)
+ if not os.path.exists(path):
+ logger.warning("Directory not found", directory=dir_name, base_path=base_path)
+ continue
+ size_bytes = 0
+ for root, _dirs, files in os.walk(path):
+ for file in files:
+ file_path = os.path.join(root, file)
+ size_bytes += os.path.getsize(file_path)
+ sizes[dir_name] = size_bytes
+ return sizes
+
+ def split_dataset(self, base_path, sizes):
+ counter = 0
+ splits = []
+
+ # Handle single dataset case
+ if "dataset" in sizes:
+ dataset = load_from_disk(base_path)
+ size = sizes["dataset"]
+ num_splits = math.ceil(size / self.chunk_size_bytes)
+ total_size = len(dataset)
+ subset_size = total_size // num_splits
+ remainder = total_size % num_splits
+ start_idx = 0
+
+ for i in range(num_splits):
+ current_size = subset_size + (1 if i < remainder else 0)
+ end_idx = start_idx + current_size
+
+ subset_data = dataset.select(range(start_idx, end_idx))
+ name = f"{self.output_dataset}/split_{counter}"
+
+ counter += 1
+ subset_data.save_to_disk(name)
+ size_bytes = get_directory_size(name)
+ splits.append((name, size_bytes))
+ del subset_data
+ start_idx = end_idx
+ return splits
+
+ # Multiple datasets case
+ logger.info("Splitting datasets", subset_order=self.subset_order)
+ for subset in self.subset_order or sizes.keys():
+ size = sizes[subset]
+ dataset = load_from_disk(f"{base_path}/{subset}")
+ num_splits = math.ceil(size / self.chunk_size_bytes)
+ total_size = len(dataset)
+ subset_size = total_size // num_splits
+ remainder = total_size % num_splits
+ start_idx = 0
+
+ for i in range(num_splits):
+ # Add one extra item to early splits if there's a remainder
+ current_size = subset_size + (1 if i < remainder else 0)
+ end_idx = start_idx + current_size
+
+ subset_data = dataset.select(range(start_idx, end_idx))
+ name = f"{self.output_dataset}/split_{counter}"
+
+ counter += 1
+ subset_data.save_to_disk(name)
+ size_bytes = get_directory_size(name)
+ splits.append((name, size_bytes))
+ del subset_data
+ start_idx = end_idx
+ return splits
+
+ def create_chunks(self, sizes):
+ chunks = []
+ current_chunk = []
+ current_size = 0
+
+ # Use subset_order instead of sorting
+ for dir_name, size in sizes:
+ if current_size + size > self.chunk_size_bytes and current_chunk:
+ chunks.append(current_chunk)
+ current_chunk = []
+ current_size = 0
+ current_chunk.append(dir_name)
+ current_size += size
+
+ if current_chunk:
+ chunks.append(current_chunk)
+ return chunks
+
+ def tokenize_batch(self, texts):
+ if self.sft:
+ return self.tokenizer.apply_chat_template(
+ texts,
+ truncation=False,
+ padding=False,
+ return_assistant_tokens_mask=True,
+ return_dict=True,
+ add_special_tokens=self.add_special_tokens,
+ )
+ else:
+ return self.tokenizer(texts, truncation=False, padding=False)
+
+ def process_chunk(self, chunk_ranges, chunk_idx):
+ # Load and concatenate datasets in chunk
+ dataset_splits = []
+ for range_name in chunk_ranges:
+ try:
+ split = load_from_disk(range_name)
+ dataset_splits.append(split)
+ logger.info("Loaded split", split_name=range_name)
+ except Exception as e:
+ logger.error("Error loading split", split_name=range_name, error=str(e))
+
+ if not dataset_splits:
+ return None
+ # Concatenate splits
+ concatenated = concatenate_datasets(dataset_splits)
+ del split
+ del dataset_splits
+ # delete splits
+ for split in chunk_ranges:
+ remove_path(split)
+ # Tokenize
+ logger.info("Tokenizing chunk", chunk_idx=chunk_idx)
+ texts = concatenated[self.text_column]
+
+ num_cores = max(1, cpu_count() - 1)  # Leave some cores free, but use at least one
+ text_chunk_size = max(1, len(texts) // num_cores)
+ text_chunks = [
+ texts[i : i + text_chunk_size] for i in range(0, len(texts), text_chunk_size)
+ ]
+
+ if self.parallel:
+ logger.info("Tokenizing in parallel", num_cores=num_cores)
+ pool = Pool(num_cores)
+ try:
+ tokenized_chunks = list(
+ tqdm(
+ pool.imap(self.tokenize_batch, text_chunks),
+ total=len(text_chunks),
+ desc="Tokenizing",
+ )
+ )
+ finally:
+ pool.close()
+ pool.join()
+ else:
+ tokenized_chunks = [
+ self.tokenize_batch(text_chunk)
+ for text_chunk in tqdm(text_chunks, desc="Tokenizing")
+ ]
+
+ input_ids = [i["input_ids"] for i in tokenized_chunks]
+ assistant_masks = (
+ [i["assistant_masks"] for i in tokenized_chunks] if self.sft else input_ids
+ )
+ all_input_ids = [item for sublist in input_ids for item in sublist]
+ all_assistant_masks = (
+ [item for sublist in assistant_masks for item in sublist] if self.sft else all_input_ids
+ )
+ zeros = False
+ for assistant_mask in all_assistant_masks:
+ if all(i == 0 for i in assistant_mask):
+ zeros = True
+ break
+ if zeros:
+ tokenized_dataset = Dataset.from_dict(
+ {"input_ids": all_input_ids, "labels": all_input_ids}
+ )
+ else:
+ logger.debug("Assistant masks not all zeros")
+ new_labels = []
+ for assistant_mask, input_id in zip(all_assistant_masks, all_input_ids, strict=True):
+ # If the assistant mask is 0 at a position, mask the label with -100; otherwise keep the input token id
+ new_labels.append(
+ [-100 if i == 0 else j for i, j in zip(assistant_mask, input_id, strict=True)]
+ )
+
+ tokenized_dataset = Dataset.from_dict(
+ {"input_ids": all_input_ids, "labels": new_labels}
+ )
+
+ # Save tokenized dataset
+ chunk_name = f"chunk_{chunk_idx}_tokenized"
+ tokenized_path = os.path.join(self.output_dataset, chunk_name)
+ tokenized_dataset.save_to_disk(tokenized_path)
+
+ return tokenized_path
+
+ def pack_sequences(self, input_path, output_path):
+ packed_sequences = []
+ batch_input_ids = []
+ batch_labels = []
+ batch_len = 0
+ skipped_examples = []
+ eos_token_id = self.tokenizer.eos_token_id
+ pad_token_id = self.tokenizer.pad_token_id
+ tokenized = load_from_disk(input_path)
+ logger.debug("Processing tokenized dataset", input_path=input_path)
+ for n, example in tqdm(enumerate(tokenized), desc="Packing sequences"):
+ masking = True
+ if example["input_ids"] == example["labels"]:
+ masking = False
+
+ example_len = len(example["input_ids"])
+
+ # Account for separator token if appending concat token
+ sep_len = 1 if self.append_concat_token else 0
+ if example_len + sep_len > self.max_seq_length:
+ skipped_examples.append(n)
+ continue
+
+ if batch_len + example_len + sep_len > self.max_seq_length:
+ # Pad and add current batch
+ batch_input_ids.extend([pad_token_id] * (self.max_seq_length - batch_len))
+ if masking:
+ batch_labels.extend([-100] * (self.max_seq_length - batch_len))
+ else:
+ batch_labels.extend([pad_token_id] * (self.max_seq_length - batch_len))
+ packed_sequences.append({"input_ids": batch_input_ids, "labels": batch_labels})
+ batch_input_ids = []
+ batch_labels = []
+ batch_len = 0
+
+ batch_input_ids.extend(example["input_ids"])
+ batch_labels.extend(example["labels"])
+ if self.append_concat_token:
+ batch_input_ids.append(eos_token_id) # Add separator token
+ if masking:
+ batch_labels.append(-100)
+ else:
+ batch_labels.append(eos_token_id)
+ batch_len += example_len + 1
+ else:
+ batch_len += example_len
+
+ # Handle last batch if not empty
+ if batch_input_ids:
+ batch_input_ids.extend([pad_token_id] * (self.max_seq_length - batch_len))
+ if masking:
+ batch_labels.extend([-100] * (self.max_seq_length - batch_len))
+ else:
+ batch_labels.extend([pad_token_id] * (self.max_seq_length - batch_len))
+ packed_sequences.append({"input_ids": batch_input_ids, "labels": batch_labels})
+ logger.info(
+ "Skipped examples that exceeded max sequence length", num_skipped=len(skipped_examples)
+ )
+ del tokenized
+ packed_dataset = Dataset.from_list(packed_sequences)
+ packed_dataset.save_to_disk(output_path)
+ remove_path(input_path)
+ return output_path
+
+ def _pack_dataset_wrapper(self, paths):
+ input_path, output_path = paths
+ try:
+ self.pack_sequences(input_path, output_path)
+ return output_path
+ except Exception as e:
+ logger.error("Error packing dataset", input_path=input_path, error=str(e))
+ return None
+
+ def pack_datasets_sequentially(self, tokenized_paths):
+ """Pack datasets one at a time without parallel processing"""
+ packed_paths = []
+ for input_path in tokenized_paths:
+ output_path = input_path.replace("_tokenized", "_packed")
+ logger.info("Packing dataset", input_path=input_path)
+ try:
+ self.pack_sequences(input_path, output_path)
+ packed_paths.append(output_path)
+ except Exception as e:
+ logger.error("Error packing dataset", input_path=input_path, error=str(e))
+ return packed_paths
+
+ def run_parallel_packing(self, tokenized_paths):
+ # Prepare input/output path pairs
+ pack_args = [
+ (input_path, input_path.replace("_tokenized", "_packed"))
+ for input_path in tokenized_paths
+ ]
+
+ pool = Pool(processes=self.num_workers)
+ try:
+ completed_paths = list(
+ tqdm(
+ pool.imap(self._pack_dataset_wrapper, pack_args),
+ total=len(pack_args),
+ desc="Packing datasets",
+ )
+ )
+ finally:
+ pool.close()
+ pool.join()
+
+ # Filter out None values (failed packing attempts)
+ logger.debug("Completed packing paths", completed_paths=completed_paths)
+ return [path for path in completed_paths if path is not None]
+
+ def concatenate_final_dataset(self, packed_paths):
+ # Create a mapping of chunk index to dataset
+ chunk_datasets = {}
+ for path in packed_paths:
+ # Extract chunk index from path
+ chunk_idx = int(os.path.basename(path).split("_")[1])
+ chunk_datasets[chunk_idx] = load_from_disk(path)
+
+ # Load datasets in the original chunk order
+ datasets = [chunk_datasets[i] for i in range(len(chunk_datasets))]
+ final_dataset = concatenate_datasets(datasets)
+ final_path = os.path.join(self.output_dataset, "final_dataset")
+ dataset_dict = DatasetDict({"train": final_dataset})
+ dataset_dict.save_to_disk(final_path)
+ return final_path
+
+ def process(self):
+ # Step 1: Analyze and create chunks
+ logger.info("Analyzing directory sizes")
+ log_data = {}
+ sizes = self.get_directory_sizes(self.input_dataset)
+ log_data["sizes"] = sizes
+
+ splits = self.split_dataset(self.input_dataset, sizes)
+ log_data["splits"] = splits
+
+ chunks = self.create_chunks(splits)
+ log_data["chunks"] = chunks
+
+ # Step 2: Process each chunk (concatenate and tokenize)
+ tokenized_paths = []
+ for i, chunk in enumerate(chunks):
+ logger.info("Processing chunk", chunk_num=i + 1, total_chunks=len(chunks))
+ tokenized_path = self.process_chunk(chunk, i)
+ if tokenized_path:
+ tokenized_paths.append(tokenized_path)
+
+ log_data["tokenized"] = tokenized_paths
+
+ # Step 3: Pack datasets (parallel or sequential)
+ if self.parallel:
+ logger.info("Packing datasets in parallel")
+ packed_paths = self.run_parallel_packing(tokenized_paths)
+ else:
+ logger.info("Packing datasets sequentially")
+ packed_paths = self.pack_datasets_sequentially(tokenized_paths)
+
+ log_data["packed"] = packed_paths
+
+ # Step 4: Concatenate final dataset
+ logger.info("Concatenating final dataset")
+ final_path = self.concatenate_final_dataset(packed_paths)
+ for path in packed_paths:
+ remove_path(path)
+
+ log_data["final"] = final_path
+ self.log_data = log_data
+ logger.info("Processing complete! Final dataset saved", final_path=final_path)
+ return final_path
diff --git a/dalla_data_processing/packing/pack_config.example.yaml b/dalla_data_processing/packing/pack_config.example.yaml
new file mode 100644
index 0000000..ed73582
--- /dev/null
+++ b/dalla_data_processing/packing/pack_config.example.yaml
@@ -0,0 +1,11 @@
+tokenizer_path: "/path/to/tokenizer"
+max_seq_length: 2048
+chunk_size_gb: 2.0
+rbpe: false
+sft: false
+
+# text_column: "content"
+
+subset_order:
+ - "train"
+ - "validation"
diff --git a/dalla_data_processing/quality/README.md b/dalla_data_processing/quality/README.md
index 7d1fefa..5d6f33c 100644
--- a/dalla_data_processing/quality/README.md
+++ b/dalla_data_processing/quality/README.md
@@ -2,6 +2,15 @@
Check text quality using morphological analysis to detect errors and foreign words.
+## Installation
+
+This feature requires the `[quality]` extra:
+
+```bash
+pip install "dalla-data-processing[quality]"
+# or install all features: pip install "dalla-data-processing[all]"
+```
+
## CLI Usage
**Command:** `dalla-dp quality-check [OPTIONS]`
diff --git a/dalla_data_processing/readability/README.md b/dalla_data_processing/readability/README.md
index 1edb3d6..39b987a 100644
--- a/dalla_data_processing/readability/README.md
+++ b/dalla_data_processing/readability/README.md
@@ -2,6 +2,15 @@
Calculate readability scores using Flesch Reading Ease and Osman methods.
+## Installation
+
+This feature requires the `[readability]` extra:
+
+```bash
+pip install "dalla-data-processing[readability]"
+# or install all features: pip install "dalla-data-processing[all]"
+```
+
## CLI Usage
**Command:** `dalla-dp readability [OPTIONS]`
diff --git a/dalla_data_processing/stemming/README.md b/dalla_data_processing/stemming/README.md
index 3c8d206..5ad4f5a 100644
--- a/dalla_data_processing/stemming/README.md
+++ b/dalla_data_processing/stemming/README.md
@@ -2,6 +2,15 @@
Apply morphological analysis and stemming using CAMeL Tools.
+## Installation
+
+This feature requires the `[stem]` extra:
+
+```bash
+pip install "dalla-data-processing[stem]"
+# or install all features: pip install "dalla-data-processing[all]"
+```
+
## CLI Usage
**Command:** `dalla-dp stem [OPTIONS]`
diff --git a/dalla_data_processing/stemming/__init__.py b/dalla_data_processing/stemming/__init__.py
index febd264..cc12608 100644
--- a/dalla_data_processing/stemming/__init__.py
+++ b/dalla_data_processing/stemming/__init__.py
@@ -1,668 +1,9 @@
"""Stemming and morphological analysis module.
-This module is completely self-contained and does not depend on external scripts.
-All necessary functions are included here.
+This module provides Arabic stemming and morphological tokenization
+functionality using CAMeL Tools disambiguators.
"""
-import os
-import re
-from collections import deque
-from types import MethodType
-
-from camel_tools.data.catalogue import Catalogue
-from camel_tools.disambig.bert import BERTUnfactoredDisambiguator
-from camel_tools.disambig.mle import MLEDisambiguator
-from camel_tools.utils.dediac import dediac_ar
-from datasets import Dataset
-
-from dalla_data_processing.utils.logger import get_logger
-from dalla_data_processing.utils.tokenize import simple_word_tokenize
-
-logger = get_logger(__name__)
-
-
-def normalize_arabic(text: str) -> str:
- """Normalize Arabic text."""
- _DIAC_RE = re.compile(
- r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06DC\u06DF-\u06E8\u06EA-\u06ED\u08D3-\u08FF]"
- )
- _TATWEEL_RE = re.compile(r"\u0640")
- _ALIF_RE = re.compile(r"[آأإٱ]")
- _ALIF_MAK_RE = re.compile(r"ى")
- _TEH_MARB_RE = re.compile(r"ة")
-
- text = _DIAC_RE.sub("", text)
- text = _TATWEEL_RE.sub("", text)
- text = _ALIF_RE.sub("ا", text)
- text = _ALIF_MAK_RE.sub("ي", text)
- text = _TEH_MARB_RE.sub("ه", text)
- return text
-
-
-def has_diacritics(word):
- """Check if word has diacritics."""
- diacritic_marks = {
- "\u064b",
- "\u064c",
- "\u064d",
- "\u064e",
- "\u064f",
- "\u0650",
- "\u0651",
- "\u0652",
- "\u0670",
- }
- return any(char in diacritic_marks for char in word)
-
-
-def apply_diacritics_to_segments_keep_markers(segments, diacritized_word, sep_token="<+>"):
- """Apply diacritics from original word to segmented tokens."""
- result = []
- diacritic_marks = {
- "\u064b",
- "\u064c",
- "\u064d",
- "\u064e",
- "\u064f",
- "\u0650",
- "\u0651",
- "\u0652",
- "\u0670",
- }
- sep_len = len(sep_token)
-
- leading_diacritics = []
- i = 0
- while i < len(diacritized_word) and diacritized_word[i] in diacritic_marks:
- leading_diacritics.append(diacritized_word[i])
- i += 1
-
- diacritic_index = len(leading_diacritics)
-
- for segment_idx, segment in enumerate(segments):
- if segment == sep_token:
- result.append(segment)
- else:
- diacritized_segment = []
-
- if segment_idx == 0 and leading_diacritics:
- diacritized_segment.extend(leading_diacritics)
-
- i = 0
- while i < len(segment):
- char = segment[i]
- if segment[i : i + sep_len] == sep_token:
- diacritized_segment.append(sep_token)
- i += sep_len
- continue
-
- if diacritic_index < len(diacritized_word):
- while (
- diacritic_index < len(diacritized_word)
- and diacritized_word[diacritic_index] in diacritic_marks
- ):
- diacritic_index += 1
-
- if (
- diacritic_index < len(diacritized_word)
- and diacritized_word[diacritic_index] == char
- ):
- diacritized_segment.append(char)
- diacritic_index += 1
-
- while (
- diacritic_index < len(diacritized_word)
- and diacritized_word[diacritic_index] in diacritic_marks
- ):
- diacritized_segment.append(diacritized_word[diacritic_index])
- diacritic_index += 1
- else:
- diacritized_segment.append(char)
- else:
- diacritized_segment.append(char)
-
- i += 1
-
- result.append("".join(diacritized_segment))
-
- return result
-
-
-def read_and_dediacritize(file_name):
- """Read words from file and dediacritize them."""
- words = []
- with open(file_name, encoding="utf-8") as file:
- for line in file:
- word = line.strip()
- dediacritized_word = dediac_ar(word)
- words.append(dediacritized_word)
- return words
-
-
-def par_is_utf8_encoded(paragraph):
- """Check if paragraph is UTF-8 encoded."""
- try:
- paragraph.encode("utf-8")
- return True
- except UnicodeEncodeError:
- return False
-
-
-def tokenize(text):
- """Tokenize text into words."""
- if par_is_utf8_encoded(text):
- text_list = simple_word_tokenize(text)
- return text_list
- else:
- return None
-
-
-def merge_alef_and_alef_lam(input_list, sep_token="<+>"):
- """Merge specific Arabic morpheme patterns."""
- pattern = [f"\u0644{sep_token}".encode(), f"\u0627\u0644{sep_token}".encode()]
- replacement = f"\u0644\u0644{sep_token}"
-
- modified_list = []
- i = 0
-
- while i < len(input_list):
- if i < len(input_list) - 1:
- current_element = input_list[i].encode("utf-8")
- next_element = input_list[i + 1].encode("utf-8")
-
- if current_element == pattern[0] and next_element == pattern[1]:
- modified_list.append(replacement)
- i += 2
- continue
-
- modified_list.append(input_list[i])
- i += 1
-
- return modified_list
-
-
-def process_NOAN_word(list_al_t, list_al, list_t, word, sep_token="<+>"):
- """Process words marked as NOAN (no analysis)."""
- alef_lam = b"\xd8\xa7\xd9\x84"
- taa_marbouta_detached = b"\xef\xba\x93"
- taa_marbouta_attached = b"\xd8\xa9"
- word_bytes = word.encode("utf-8")
-
- if (
- word_bytes.startswith(alef_lam)
- and (
- word_bytes.endswith(taa_marbouta_detached) or word_bytes.endswith(taa_marbouta_attached)
- )
- and word in list_al_t
- ):
- stripped_word = word[2:-1]
- first_part = word[0:2] + sep_token
- last_part = sep_token + word[-1]
- return [first_part, stripped_word, last_part]
-
- if word_bytes.startswith(alef_lam) and word in list_al:
- stripped_word = word[2:]
- first_part = word[0:2] + sep_token
- return [first_part, stripped_word]
-
- if word_bytes.endswith(taa_marbouta_detached) or word_bytes.endswith(taa_marbouta_attached):
- if word in list_t:
- stripped_word = word[:-1]
- last_part = sep_token + word[-1]
- return [stripped_word, last_part]
-
- return [word]
-
-
-def merge_tokens(tokens, original_word, sep_token="<+>"):
- """Merge tokenized segments back into a word."""
- parts = []
- sep_len = len(sep_token)
- for tok in tokens:
- if tok == sep_token:
- parts.append("_")
- elif tok.endswith(sep_token):
- tok = tok[:-sep_len]
- parts.append(tok)
- elif tok.startswith(sep_token):
- tok = tok[sep_len:]
- parts.append(tok)
- elif tok.endswith("+"):
- tok = tok[:-1]
- parts.append(tok)
- elif tok.startswith("+"):
- tok = tok[1:]
- parts.append(tok)
- else:
- parts.append(tok)
-
- merged_word = "".join(parts)
- return merged_word
-
-
-def split_token_on_t(list_toks, sep_token="<+>"):
- """Split tokens on taa marbouta character."""
- new_list = []
- taa_marbouta_detached = b"\xef\xba\x93"
- taa_marbouta_attached = b"\xd8\xa9"
- haa_attached = b"\xd9\x87"
-
- for token in list_toks:
- token_bytes = token.encode("utf-8")
- if (
- token_bytes.endswith(taa_marbouta_detached)
- or token_bytes.endswith(taa_marbouta_attached)
- or token_bytes.endswith(haa_attached)
- ):
- if token_bytes == b"\xd9\x87":
- token = sep_token + taa_marbouta_attached.decode("utf-8")
- new_list.append(token)
- else:
- part1 = token[:-1]
- part2 = sep_token + token[-1]
- new_list.append(part1)
- new_list.append(part2)
- else:
- new_list.append(token)
-
- return new_list
-
-
-def replace_separator(toks, sep_token="<+>"):
- """Replace + with sep_token in tokens."""
- result = list(toks)
-
- for i, tok in enumerate(result):
- if tok.startswith("+"):
- result[i] = sep_token + tok[1:]
- if tok.endswith("+"):
- result[i] = tok[:-1] + sep_token
- return result
-
-
-def morph_tokenize(
- words, disambiguator, list_al_t, list_al, list_t, scheme="d3tok", split=True, sep_token="<+>"
-):
- """Generate morphological tokens for a list of words."""
- disambig_words = disambiguator.disambiguate(words)
- result = deque()
- err_disambig = []
- err_camel = []
- has_diacritics_in_par = False
-
- for original, disambig_word in zip(words, disambig_words, strict=False):
- scored_analyses = disambig_word.analyses
- original_word = original
- dediac_word = dediac_ar(original_word)
-
- if has_diacritics(original_word):
- has_diacritics_in_par = True
-
- if not scored_analyses:
- result.append(original_word)
- continue
-
- analysis = scored_analyses[0].analysis
- tok = dediac_ar(analysis.get(scheme, None))
- tok_bw = dediac_ar(analysis.get("bwtok", None))
- seg_d3 = dediac_ar(analysis.get("d3seg", None))
-
- taa_marbouta_detached = b"\xef\xba\x93"
- taa_marbouta_attached = b"\xd8\xa9"
- original_word_bytes = dediac_word.encode("utf-8")
-
- if original_word_bytes.endswith(taa_marbouta_attached) or original_word_bytes.endswith(
- taa_marbouta_detached
- ):
- if "+ة_+" in tok_bw or "+ه" in tok_bw or "+ة" in tok_bw:
- toks = tok.split("_")
- toks = split_token_on_t(toks, sep_token)
- toks = replace_separator(toks, sep_token)
- toks = merge_alef_and_alef_lam(toks, sep_token)
- merged_toks = dediac_ar(merge_tokens(toks, dediac_word, sep_token))
-
- d3_seg_tok = seg_d3.split("_")
- d3_seg_tok = split_token_on_t(d3_seg_tok, sep_token)
- d3_seg_tok = replace_separator(d3_seg_tok, sep_token)
- d3_seg_tok = merge_alef_and_alef_lam(d3_seg_tok, sep_token)
- merged_toks_seg = dediac_ar(merge_tokens(d3_seg_tok, dediac_word, sep_token))
-
- bw_toks = tok_bw.split("_")
- bw_toks = split_token_on_t(bw_toks, sep_token)
- bw_toks = replace_separator(bw_toks, sep_token)
- bw_toks = merge_alef_and_alef_lam(bw_toks, sep_token)
- merged_toks_bw = dediac_ar(merge_tokens(bw_toks, dediac_word, sep_token))
-
- if merged_toks == dediac_word and len(toks) > 1:
- if has_diacritics(original):
- toks = apply_diacritics_to_segments_keep_markers(toks, original, sep_token)
- result.extend(toks)
- continue
-
- elif merged_toks_seg == dediac_word and len(d3_seg_tok) > 1:
- if has_diacritics(original):
- d3_seg_tok = apply_diacritics_to_segments_keep_markers(
- d3_seg_tok, original, sep_token
- )
- result.extend(d3_seg_tok)
- continue
-
- elif merged_toks_bw == dediac_word and len(bw_toks) > 1:
- if has_diacritics(original):
- bw_toks = apply_diacritics_to_segments_keep_markers(
- bw_toks, original, sep_token
- )
- result.extend(bw_toks)
- continue
-
- else:
- result.append(original_word)
- err_disambig.append(dediac_word)
- err_camel.append(merged_toks)
- continue
-
- if tok is None or "NOAN" in tok:
- tok = process_NOAN_word(list_al_t, list_al, list_t, dediac_word, sep_token)
- if has_diacritics(original):
- toks = apply_diacritics_to_segments_keep_markers(tok, original, sep_token)
- else:
- toks = tok
- result.extend(toks)
-
- elif split:
- tok = dediac_ar(tok)
- toks = tok.split("_")
- toks = replace_separator(toks, sep_token)
- toks = merge_alef_and_alef_lam(toks, sep_token)
- merged_toks = dediac_ar(merge_tokens(toks, dediac_word, sep_token))
-
- bw_toks = tok_bw.split("_")
- bw_toks = replace_separator(bw_toks, sep_token)
- bw_toks = merge_alef_and_alef_lam(bw_toks, sep_token)
- merged_toks_bw = dediac_ar(merge_tokens(bw_toks, dediac_word, sep_token))
-
- d3_seg_tok = seg_d3.split("_")
- d3_seg_tok = replace_separator(d3_seg_tok, sep_token)
- d3_seg_tok = merge_alef_and_alef_lam(d3_seg_tok, sep_token)
- merged_toks_seg = dediac_ar(merge_tokens(d3_seg_tok, dediac_word, sep_token))
-
- if merged_toks == dediac_word and len(toks) > 1:
- if has_diacritics(original):
- toks = apply_diacritics_to_segments_keep_markers(toks, original, sep_token)
- result.extend(toks)
- elif merged_toks_seg == dediac_word and len(d3_seg_tok) > 1:
- if has_diacritics(original):
- d3_seg_tok = apply_diacritics_to_segments_keep_markers(
- d3_seg_tok, original, sep_token
- )
- result.extend(d3_seg_tok)
- elif merged_toks_bw == dediac_word and len(bw_toks) > 1:
- if has_diacritics(original):
- bw_toks = apply_diacritics_to_segments_keep_markers(
- bw_toks, original, sep_token
- )
- result.extend(bw_toks)
- else:
- result.append(original_word)
- err_disambig.append(dediac_word)
- err_camel.append(merged_toks)
-
- else:
- tok = dediac_ar(tok)
- if tok == dediac_word:
- result.append(original_word)
- else:
- result.append(original_word)
- err_disambig.append(dediac_word)
- err_camel.append(tok)
-
- return list(result), err_disambig, err_camel, has_diacritics_in_par
-
-
-def stem_dataset(
- dataset: Dataset,
- column: str = "text",
- sep_token: str = "<+>",
- normalize: bool = False,
- keep_diacritics: bool = True,
- num_proc: int | None = None,
- model: str = "mle",
- use_gpu: bool = False,
-) -> Dataset:
- """
- Apply stemming and morphological analysis to dataset.
-
- Args:
- dataset: HuggingFace dataset
- column: Column to process
- sep_token: Separator token for morphological splits (default: '<+>')
- normalize: Apply Arabic normalization (default: False)
- keep_diacritics: Keep dediacritized column (default: True)
- num_proc: Number of parallel processes
- model: Disambiguator model to use - "mle" or "bert" (default: "mle")
- use_gpu: Whether to use GPU for BERT model (default: False)
-
- Returns:
- Dataset with {column}_stemmed and optionally {column}_dediac columns
-
- Example:
- >>> # Stem with defaults (MLE, keeps diacritics)
- >>> stemmed = stem_dataset(dataset)
- >>> # Result has 'text_stemmed' and 'text_dediac' columns
-
- >>> # Stem using BERT with GPU
- >>> stemmed = stem_dataset(dataset, model="bert", use_gpu=True)
-
- >>> # Stem without keeping diacritics
- >>> stemmed = stem_dataset(dataset, keep_diacritics=False)
- >>> # Result has only 'text_stemmed' column
- """
- model = model.lower()
- if model not in ["mle", "bert"]:
- raise ValueError(f"Invalid model '{model}'. Must be 'mle' or 'bert'")
-
- logger.info(f"Starting stemming of {len(dataset)} examples")
- logger.info(
- f"Model: {model.upper()}, Column: {column}, Sep token: {sep_token}, Normalize: {normalize}"
- )
- logger.info(f"Keep diacritics: {keep_diacritics}, Workers: {num_proc or 'auto'}")
- if model == "bert":
- logger.info(f"GPU: {use_gpu}")
-
- logger.info("Checking CAMeL Tools data packages...")
- catalogue = Catalogue.load_catalogue()
- try:
- catalogue.download_package("morphology-db-msa-r13")
- if model == "mle":
- catalogue.download_package("disambig-mle-calima-msa-r13")
- # For BERT, let it download automatically when pretrained() is called
- logger.info("CAMeL Tools data packages ready")
- except Exception as e:
- logger.warning(f"Could not verify CAMeL packages: {e}")
-
- logger.info("Loading additional words lists...")
- words_dir = os.path.join(os.path.dirname(__file__), "data")
- list_al_t = set(read_and_dediacritize(os.path.join(words_dir, "words_al_t.txt")))
- list_al = set(read_and_dediacritize(os.path.join(words_dir, "words_al.txt")))
- list_t = set(read_and_dediacritize(os.path.join(words_dir, "words_t.txt")))
- logger.info("Loaded word list entries")
-
- logger.info(f"Initializing {model.upper()} disambiguator...")
- if model == "mle":
- disambiguator = MLEDisambiguator.pretrained("calima-msa-r13", cache_size=1_000_000)
- else: # bert
- disambiguator = BERTUnfactoredDisambiguator.pretrained(use_gpu=use_gpu)
- logger.info("Disambiguator ready")
-
- def new_scored_analysis(self, word_dd):
- if word_dd in self._cache:
- return self._cache[word_dd]
- result = self._scored_analyses(word_dd)
- self._cache[word_dd] = result
- return result
-
- disambiguator._scored_analyses_cached = MethodType(new_scored_analysis, disambiguator)
- disambiguator._score_fn = disambiguator._scored_analyses_cached
-
- def process_row(row):
- text = row.get(column, "")
- if not text:
- row[f"{column}_stemmed"] = ""
- if keep_diacritics:
- row[f"{column}_dediac"] = ""
- return row
-
- word_list = tokenize(text)
- if word_list is None:
- row[f"{column}_stemmed"] = text
- if keep_diacritics:
- row[f"{column}_dediac"] = dediac_ar(text)
- return row
-
- tokenized, _, _, has_diacs = morph_tokenize(
- word_list, disambiguator, list_al_t, list_al, list_t, sep_token=sep_token
- )
-
- if tokenized is not None:
- tokenized = merge_alef_and_alef_lam(tokenized, sep_token)
- stemmed = "".join(tokenized)
-
- if normalize:
- stemmed = normalize_arabic(stemmed)
-
- row[f"{column}_stemmed"] = stemmed
-
- if keep_diacritics:
- row[f"{column}_dediac"] = dediac_ar(stemmed)
- else:
- row[f"{column}_stemmed"] = text
- if keep_diacritics:
- row[f"{column}_dediac"] = dediac_ar(text)
-
- return row
-
- logger.info("Starting morphological tokenization...")
- result = dataset.map(process_row, num_proc=num_proc, desc="Stemming")
-
- logger.info(f"Stemming complete! Processed {len(result)} examples")
- return result
-
-
-def stem(
- text: str | list[str],
- sep_token: str = "<+>",
- normalize: bool = False,
- keep_diacritics: bool = False,
- model: str = "mle",
- use_gpu: bool = False,
-) -> str | list[str]:
- """
- Stem Arabic text or list of texts.
-
- Args:
- text: Single string or list of strings to stem
- sep_token: Separator token for morphological splits (default: '<+>')
- normalize: Apply Arabic normalization (default: False)
- keep_diacritics: Keep diacritics in output (default: False)
- model: Disambiguator model to use - "mle" or "bert" (default: "mle")
- use_gpu: Whether to use GPU for BERT model (default: False)
-
- Returns:
- Stemmed text in the same format as input (string or list of strings)
-
- Example:
- >>> # Stem a single string
- >>> stemmed = stem("النص العربي")
- >>> # Returns: "ال<+>نص ال<+>عربي"
-
- >>> # Stem a list of strings
- >>> stemmed = stem(["النص العربي", "مثال آخر"])
- >>> # Returns: ["ال<+>نص ال<+>عربي", "مثال آخر"]
-
- >>> # Stem with BERT model and GPU
- >>> stemmed = stem("النص", model="bert", use_gpu=True)
- """
- # Validate model parameter
- model = model.lower()
- if model not in ["mle", "bert"]:
- raise ValueError(f"Invalid model '{model}'. Must be 'mle' or 'bert'")
-
- # Track whether input was a single string
- is_single_string = isinstance(text, str)
-
- # Convert single string to list for uniform processing
- text_list = [text] if is_single_string else text
-
- # Validate all items are strings
- if not all(isinstance(t, str) for t in text_list):
- raise TypeError("All items in text list must be strings")
-
- # Initialize disambiguator (cached globally if possible)
- logger.info(f"Initializing {model.upper()} disambiguator...")
- catalogue = Catalogue.load_catalogue()
- try:
- catalogue.download_package("morphology-db-msa-r13")
- if model == "mle":
- catalogue.download_package("disambig-mle-calima-msa-r13")
- except Exception as e:
- logger.warning(f"Could not verify CAMeL packages: {e}")
-
- if model == "mle":
- disambiguator = MLEDisambiguator.pretrained("calima-msa-r13", cache_size=1_000_000)
- else: # bert
- disambiguator = BERTUnfactoredDisambiguator.pretrained(use_gpu=use_gpu)
-
- # Add caching to disambiguator
- def new_scored_analysis(self, word_dd):
- if word_dd in self._cache:
- return self._cache[word_dd]
- result = self._scored_analyses(word_dd)
- self._cache[word_dd] = result
- return result
-
- disambiguator._scored_analyses_cached = MethodType(new_scored_analysis, disambiguator)
- disambiguator._score_fn = disambiguator._scored_analyses_cached
-
- # Load word lists
- words_dir = os.path.join(os.path.dirname(__file__), "data")
- list_al_t = set(read_and_dediacritize(os.path.join(words_dir, "words_al_t.txt")))
- list_al = set(read_and_dediacritize(os.path.join(words_dir, "words_al.txt")))
- list_t = set(read_and_dediacritize(os.path.join(words_dir, "words_t.txt")))
-
- # Process each text
- results = []
- for txt in text_list:
- if not txt:
- results.append("")
- continue
-
- word_list = tokenize(txt)
- if word_list is None:
- stemmed = dediac_ar(txt) if not keep_diacritics else txt
- results.append(stemmed)
- continue
-
- tokenized, _, _, has_diacs = morph_tokenize(
- word_list, disambiguator, list_al_t, list_al, list_t, sep_token=sep_token
- )
-
- if tokenized is not None:
- tokenized = merge_alef_and_alef_lam(tokenized, sep_token)
- stemmed = "".join(tokenized)
-
- if normalize:
- stemmed = normalize_arabic(stemmed)
-
- if not keep_diacritics:
- stemmed = dediac_ar(stemmed)
-
- results.append(stemmed)
- else:
- stemmed = dediac_ar(txt) if not keep_diacritics else txt
- results.append(stemmed)
-
- # Return in the same format as input
- return results[0] if is_single_string else results
-
+from dalla_data_processing.stemming.stemmer import stem, stem_dataset
__all__ = ["stem_dataset", "stem"]
diff --git a/dalla_data_processing/stemming/stemmer.py b/dalla_data_processing/stemming/stemmer.py
new file mode 100644
index 0000000..8bed063
--- /dev/null
+++ b/dalla_data_processing/stemming/stemmer.py
@@ -0,0 +1,679 @@
+"""Stemming and morphological analysis implementation.
+
+This module contains all the implementation details for Arabic stemming
+and morphological tokenization using CAMeL Tools.
+"""
+
+import os
+import re
+from collections import deque
+from types import MethodType
+
+from camel_tools.data.catalogue import Catalogue
+from camel_tools.disambig.bert import BERTUnfactoredDisambiguator
+from camel_tools.disambig.mle import MLEDisambiguator
+from camel_tools.utils.dediac import dediac_ar
+from datasets import Dataset
+
+from dalla_data_processing.utils.logger import get_logger
+from dalla_data_processing.utils.tokenize import simple_word_tokenize
+
+logger = get_logger(__name__)
+
+
+def normalize_arabic(text: str) -> str:
+ """Normalize Arabic text."""
+ _DIAC_RE = re.compile(
+ r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06DC\u06DF-\u06E8\u06EA-\u06ED\u08D3-\u08FF]"
+ )
+ _TATWEEL_RE = re.compile(r"\u0640")
+ _ALIF_RE = re.compile(r"[آأإٱ]")
+ _ALIF_MAK_RE = re.compile(r"ى")
+ _TEH_MARB_RE = re.compile(r"ة")
+
+ text = _DIAC_RE.sub("", text)
+ text = _TATWEEL_RE.sub("", text)
+ text = _ALIF_RE.sub("ا", text)
+ text = _ALIF_MAK_RE.sub("ي", text)
+ text = _TEH_MARB_RE.sub("ه", text)
+ return text
+
+
+def has_diacritics(word):
+ """Check if word has diacritics."""
+ diacritic_marks = {
+ "\u064b",
+ "\u064c",
+ "\u064d",
+ "\u064e",
+ "\u064f",
+ "\u0650",
+ "\u0651",
+ "\u0652",
+ "\u0670",
+ }
+ return any(char in diacritic_marks for char in word)
+
+
+def apply_diacritics_to_segments_keep_markers(segments, diacritized_word, sep_token="<+>"):
+ """Apply diacritics from original word to segmented tokens."""
+ result = []
+ diacritic_marks = {
+ "\u064b",
+ "\u064c",
+ "\u064d",
+ "\u064e",
+ "\u064f",
+ "\u0650",
+ "\u0651",
+ "\u0652",
+ "\u0670",
+ }
+ sep_len = len(sep_token)
+
+ leading_diacritics = []
+ i = 0
+ while i < len(diacritized_word) and diacritized_word[i] in diacritic_marks:
+ leading_diacritics.append(diacritized_word[i])
+ i += 1
+
+ diacritic_index = len(leading_diacritics)
+
+ for segment_idx, segment in enumerate(segments):
+ if segment == sep_token:
+ result.append(segment)
+ else:
+ diacritized_segment = []
+
+ if segment_idx == 0 and leading_diacritics:
+ diacritized_segment.extend(leading_diacritics)
+
+ i = 0
+ while i < len(segment):
+ char = segment[i]
+ if segment[i : i + sep_len] == sep_token:
+ diacritized_segment.append(sep_token)
+ i += sep_len
+ continue
+
+ if diacritic_index < len(diacritized_word):
+ while (
+ diacritic_index < len(diacritized_word)
+ and diacritized_word[diacritic_index] in diacritic_marks
+ ):
+ diacritic_index += 1
+
+ if (
+ diacritic_index < len(diacritized_word)
+ and diacritized_word[diacritic_index] == char
+ ):
+ diacritized_segment.append(char)
+ diacritic_index += 1
+
+ while (
+ diacritic_index < len(diacritized_word)
+ and diacritized_word[diacritic_index] in diacritic_marks
+ ):
+ diacritized_segment.append(diacritized_word[diacritic_index])
+ diacritic_index += 1
+ else:
+ diacritized_segment.append(char)
+ else:
+ diacritized_segment.append(char)
+
+ i += 1
+
+ result.append("".join(diacritized_segment))
+
+ return result
+
+
+def read_and_dediacritize(file_name):
+ """Read words from file and dediacritize them."""
+ words = []
+ with open(file_name, encoding="utf-8") as file:
+ for line in file:
+ word = line.strip()
+ dediacritized_word = dediac_ar(word)
+ words.append(dediacritized_word)
+ return words
+
+
+def par_is_utf8_encoded(paragraph):
+ """Check if paragraph is UTF-8 encoded."""
+ try:
+ paragraph.encode("utf-8")
+ return True
+ except UnicodeEncodeError:
+ return False
+
+
+def tokenize(text):
+ """Tokenize text into words."""
+ if par_is_utf8_encoded(text):
+ text_list = simple_word_tokenize(text)
+ return text_list
+ else:
+ return None
+
+
+def merge_alef_and_alef_lam(input_list, sep_token="<+>"):
+ """Merge specific Arabic morpheme patterns."""
+ pattern = [f"\u0644{sep_token}".encode(), f"\u0627\u0644{sep_token}".encode()]
+ replacement = f"\u0644\u0644{sep_token}"
+
+ modified_list = []
+ i = 0
+
+ while i < len(input_list):
+ if i < len(input_list) - 1:
+ current_element = input_list[i].encode("utf-8")
+ next_element = input_list[i + 1].encode("utf-8")
+
+ if current_element == pattern[0] and next_element == pattern[1]:
+ modified_list.append(replacement)
+ i += 2
+ continue
+
+ modified_list.append(input_list[i])
+ i += 1
+
+ return modified_list
+
+
+def process_NOAN_word(list_al_t, list_al, list_t, word, sep_token="<+>"):
+ """Process words marked as NOAN (no analysis)."""
+ alef_lam = b"\xd8\xa7\xd9\x84"
+ taa_marbouta_detached = b"\xef\xba\x93"
+ taa_marbouta_attached = b"\xd8\xa9"
+ word_bytes = word.encode("utf-8")
+
+ if (
+ word_bytes.startswith(alef_lam)
+ and (
+ word_bytes.endswith(taa_marbouta_detached) or word_bytes.endswith(taa_marbouta_attached)
+ )
+ and word in list_al_t
+ ):
+ stripped_word = word[2:-1]
+ first_part = word[0:2] + sep_token
+ last_part = sep_token + word[-1]
+ return [first_part, stripped_word, last_part]
+
+ if word_bytes.startswith(alef_lam) and word in list_al:
+ stripped_word = word[2:]
+ first_part = word[0:2] + sep_token
+ return [first_part, stripped_word]
+
+ if word_bytes.endswith(taa_marbouta_detached) or word_bytes.endswith(taa_marbouta_attached):
+ if word in list_t:
+ stripped_word = word[:-1]
+ last_part = sep_token + word[-1]
+ return [stripped_word, last_part]
+
+ return [word]
+
+
+def merge_tokens(tokens, original_word, sep_token="<+>"):
+ """Merge tokenized segments back into a word."""
+ parts = []
+ sep_len = len(sep_token)
+ for tok in tokens:
+ if tok == sep_token:
+ parts.append("_")
+ elif tok.endswith(sep_token):
+ tok = tok[:-sep_len]
+ parts.append(tok)
+ elif tok.startswith(sep_token):
+ tok = tok[sep_len:]
+ parts.append(tok)
+ elif tok.endswith("+"):
+ tok = tok[:-1]
+ parts.append(tok)
+ elif tok.startswith("+"):
+ tok = tok[1:]
+ parts.append(tok)
+ else:
+ parts.append(tok)
+
+ merged_word = "".join(parts)
+ return merged_word
+
+
+def split_token_on_t(list_toks, sep_token="<+>"):
+ """Split tokens on taa marbouta character."""
+ new_list = []
+ taa_marbouta_detached = b"\xef\xba\x93"
+ taa_marbouta_attached = b"\xd8\xa9"
+ haa_attached = b"\xd9\x87"
+
+ for token in list_toks:
+ token_bytes = token.encode("utf-8")
+ if (
+ token_bytes.endswith(taa_marbouta_detached)
+ or token_bytes.endswith(taa_marbouta_attached)
+ or token_bytes.endswith(haa_attached)
+ ):
+ if token_bytes == b"\xd9\x87":
+ token = sep_token + taa_marbouta_attached.decode("utf-8")
+ new_list.append(token)
+ else:
+ part1 = token[:-1]
+ part2 = sep_token + token[-1]
+ new_list.append(part1)
+ new_list.append(part2)
+ else:
+ new_list.append(token)
+
+ return new_list
+
+
+def replace_separator(toks, sep_token="<+>"):
+ """Replace + with sep_token in tokens."""
+ result = list(toks)
+
+ for i, tok in enumerate(result):
+ if tok.startswith("+"):
+ result[i] = sep_token + tok[1:]
+ if tok.endswith("+"):
+ result[i] = tok[:-1] + sep_token
+ return result
+
+
+def morph_tokenize(
+ words, disambiguator, list_al_t, list_al, list_t, scheme="d3tok", split=True, sep_token="<+>"
+):
+ """Generate morphological tokens for a list of words."""
+ disambig_words = disambiguator.disambiguate(words)
+ result = deque()
+ err_disambig = []
+ err_camel = []
+ has_diacritics_in_par = False
+
+ for original, disambig_word in zip(words, disambig_words, strict=False):
+ scored_analyses = disambig_word.analyses
+ original_word = original
+ dediac_word = dediac_ar(original_word)
+
+ if has_diacritics(original_word):
+ has_diacritics_in_par = True
+
+ if not scored_analyses:
+ result.append(original_word)
+ continue
+
+ analysis = scored_analyses[0].analysis
+ tok = dediac_ar(analysis.get(scheme, None))
+ tok_bw = dediac_ar(analysis.get("bwtok", None))
+ seg_d3 = dediac_ar(analysis.get("d3seg", None))
+
+ taa_marbouta_detached = b"\xef\xba\x93"
+ taa_marbouta_attached = b"\xd8\xa9"
+ original_word_bytes = dediac_word.encode("utf-8")
+
+ if original_word_bytes.endswith(taa_marbouta_attached) or original_word_bytes.endswith(
+ taa_marbouta_detached
+ ):
+ if "+ة_+" in tok_bw or "+ه" in tok_bw or "+ة" in tok_bw:
+ toks = tok.split("_")
+ toks = split_token_on_t(toks, sep_token)
+ toks = replace_separator(toks, sep_token)
+ toks = merge_alef_and_alef_lam(toks, sep_token)
+ merged_toks = dediac_ar(merge_tokens(toks, dediac_word, sep_token))
+
+ d3_seg_tok = seg_d3.split("_")
+ d3_seg_tok = split_token_on_t(d3_seg_tok, sep_token)
+ d3_seg_tok = replace_separator(d3_seg_tok, sep_token)
+ d3_seg_tok = merge_alef_and_alef_lam(d3_seg_tok, sep_token)
+ merged_toks_seg = dediac_ar(merge_tokens(d3_seg_tok, dediac_word, sep_token))
+
+ bw_toks = tok_bw.split("_")
+ bw_toks = split_token_on_t(bw_toks, sep_token)
+ bw_toks = replace_separator(bw_toks, sep_token)
+ bw_toks = merge_alef_and_alef_lam(bw_toks, sep_token)
+ merged_toks_bw = dediac_ar(merge_tokens(bw_toks, dediac_word, sep_token))
+
+ if merged_toks == dediac_word and len(toks) > 1:
+ if has_diacritics(original):
+ toks = apply_diacritics_to_segments_keep_markers(toks, original, sep_token)
+ result.extend(toks)
+ continue
+
+ elif merged_toks_seg == dediac_word and len(d3_seg_tok) > 1:
+ if has_diacritics(original):
+ d3_seg_tok = apply_diacritics_to_segments_keep_markers(
+ d3_seg_tok, original, sep_token
+ )
+ result.extend(d3_seg_tok)
+ continue
+
+ elif merged_toks_bw == dediac_word and len(bw_toks) > 1:
+ if has_diacritics(original):
+ bw_toks = apply_diacritics_to_segments_keep_markers(
+ bw_toks, original, sep_token
+ )
+ result.extend(bw_toks)
+ continue
+
+ else:
+ result.append(original_word)
+ err_disambig.append(dediac_word)
+ err_camel.append(merged_toks)
+ continue
+
+ if tok is None or "NOAN" in tok:
+ tok = process_NOAN_word(list_al_t, list_al, list_t, dediac_word, sep_token)
+ if has_diacritics(original):
+ toks = apply_diacritics_to_segments_keep_markers(tok, original, sep_token)
+ else:
+ toks = tok
+ result.extend(toks)
+
+ elif split:
+ tok = dediac_ar(tok)
+ toks = tok.split("_")
+ toks = replace_separator(toks, sep_token)
+ toks = merge_alef_and_alef_lam(toks, sep_token)
+ merged_toks = dediac_ar(merge_tokens(toks, dediac_word, sep_token))
+
+ bw_toks = tok_bw.split("_")
+ bw_toks = replace_separator(bw_toks, sep_token)
+ bw_toks = merge_alef_and_alef_lam(bw_toks, sep_token)
+ merged_toks_bw = dediac_ar(merge_tokens(bw_toks, dediac_word, sep_token))
+
+ d3_seg_tok = seg_d3.split("_")
+ d3_seg_tok = replace_separator(d3_seg_tok, sep_token)
+ d3_seg_tok = merge_alef_and_alef_lam(d3_seg_tok, sep_token)
+ merged_toks_seg = dediac_ar(merge_tokens(d3_seg_tok, dediac_word, sep_token))
+
+ if merged_toks == dediac_word and len(toks) > 1:
+ if has_diacritics(original):
+ toks = apply_diacritics_to_segments_keep_markers(toks, original, sep_token)
+ result.extend(toks)
+ elif merged_toks_seg == dediac_word and len(d3_seg_tok) > 1:
+ if has_diacritics(original):
+ d3_seg_tok = apply_diacritics_to_segments_keep_markers(
+ d3_seg_tok, original, sep_token
+ )
+ result.extend(d3_seg_tok)
+ elif merged_toks_bw == dediac_word and len(bw_toks) > 1:
+ if has_diacritics(original):
+ bw_toks = apply_diacritics_to_segments_keep_markers(
+ bw_toks, original, sep_token
+ )
+ result.extend(bw_toks)
+ else:
+ result.append(original_word)
+ err_disambig.append(dediac_word)
+ err_camel.append(merged_toks)
+
+ else:
+ tok = dediac_ar(tok)
+ if tok == dediac_word:
+ result.append(original_word)
+ else:
+ result.append(original_word)
+ err_disambig.append(dediac_word)
+ err_camel.append(tok)
+
+ return list(result), err_disambig, err_camel, has_diacritics_in_par
+
+
+def stem_dataset(
+ dataset: Dataset,
+ column: str = "text",
+ sep_token: str = "<+>",
+ normalize: bool = False,
+ keep_diacritics: bool = True,
+ num_proc: int | None = None,
+ model: str = "mle",
+ use_gpu: bool = False,
+) -> Dataset:
+ """
+ Apply stemming and morphological analysis to dataset.
+
+ Args:
+ dataset: HuggingFace dataset
+ column: Column to process
+ sep_token: Separator token for morphological splits (default: '<+>')
+ normalize: Apply Arabic normalization (default: False)
+ keep_diacritics: Keep dediacritized column (default: True)
+ num_proc: Number of parallel processes
+ model: Disambiguator model to use - "mle" or "bert" (default: "mle")
+ use_gpu: Whether to use GPU for BERT model (default: False)
+
+ Returns:
+ Dataset with {column}_stemmed and optionally {column}_dediac columns
+
+ Example:
+ >>> # Stem with defaults (MLE, keeps diacritics)
+ >>> stemmed = stem_dataset(dataset)
+ >>> # Result has 'text_stemmed' and 'text_dediac' columns
+
+ >>> # Stem using BERT with GPU
+ >>> stemmed = stem_dataset(dataset, model="bert", use_gpu=True)
+
+ >>> # Stem without keeping diacritics
+ >>> stemmed = stem_dataset(dataset, keep_diacritics=False)
+ >>> # Result has only 'text_stemmed' column
+ """
+ model = model.lower()
+ if model not in ["mle", "bert"]:
+ raise ValueError(f"Invalid model '{model}'. Must be 'mle' or 'bert'")
+
+ logger.info(f"Starting stemming of {len(dataset)} examples")
+ logger.info(
+ f"Model: {model.upper()}, Column: {column}, Sep token: {sep_token}, Normalize: {normalize}"
+ )
+ logger.info(f"Keep diacritics: {keep_diacritics}, Workers: {num_proc or 'auto'}")
+ if model == "bert":
+ logger.info(f"GPU: {use_gpu}")
+
+ logger.info("Checking CAMeL Tools data packages...")
+ catalogue = Catalogue.load_catalogue()
+ try:
+ catalogue.download_package("morphology-db-msa-r13")
+ if model == "mle":
+ catalogue.download_package("disambig-mle-calima-msa-r13")
+ # For BERT, let it download automatically when pretrained() is called
+ logger.info("CAMeL Tools data packages ready")
+ except Exception as e:
+ logger.warning(f"Could not verify CAMeL packages: {e}")
+
+ logger.info("Loading additional words lists...")
+ words_dir = os.path.join(os.path.dirname(__file__), "data")
+ list_al_t = set(read_and_dediacritize(os.path.join(words_dir, "words_al_t.txt")))
+ list_al = set(read_and_dediacritize(os.path.join(words_dir, "words_al.txt")))
+ list_t = set(read_and_dediacritize(os.path.join(words_dir, "words_t.txt")))
+ logger.info("Loaded word list entries")
+
+ logger.info(f"Initializing {model.upper()} disambiguator...")
+ if model == "mle":
+ disambiguator = MLEDisambiguator.pretrained("calima-msa-r13", cache_size=1_000_000)
+ else: # bert
+ disambiguator = BERTUnfactoredDisambiguator.pretrained(use_gpu=use_gpu)
+ logger.info("Disambiguator ready")
+
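+ # Memoize per-word analyses in the disambiguator's internal cache to avoid re-scoring
+ # repeated words; this relies on CAMeL Tools' private attributes (_cache, _scored_analyses, _score_fn).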
+ def new_scored_analysis(self, word_dd):
+ if word_dd in self._cache:
+ return self._cache[word_dd]
+ result = self._scored_analyses(word_dd)
+ self._cache[word_dd] = result
+ return result
+
+ disambiguator._scored_analyses_cached = MethodType(new_scored_analysis, disambiguator)
+ disambiguator._score_fn = disambiguator._scored_analyses_cached
+
+ def process_row(row):
+ text = row.get(column, "")
+ if not text:
+ row[f"{column}_stemmed"] = ""
+ if keep_diacritics:
+ row[f"{column}_dediac"] = ""
+ return row
+
+ word_list = tokenize(text)
+ if word_list is None:
+ row[f"{column}_stemmed"] = text
+ if keep_diacritics:
+ row[f"{column}_dediac"] = dediac_ar(text)
+ return row
+
+ tokenized, _, _, has_diacs = morph_tokenize(
+ word_list, disambiguator, list_al_t, list_al, list_t, sep_token=sep_token
+ )
+
+ if tokenized is not None:
+ tokenized = merge_alef_and_alef_lam(tokenized, sep_token)
+ stemmed = "".join(tokenized)
+
+ if normalize:
+ stemmed = normalize_arabic(stemmed)
+
+ row[f"{column}_stemmed"] = stemmed
+
+ if keep_diacritics:
+ row[f"{column}_dediac"] = dediac_ar(stemmed)
+ else:
+ row[f"{column}_stemmed"] = text
+ if keep_diacritics:
+ row[f"{column}_dediac"] = dediac_ar(text)
+
+ return row
+
+ logger.info("Starting morphological tokenization...")
+ result = dataset.map(process_row, num_proc=num_proc, desc="Stemming")
+
+ logger.info(f"Stemming complete! Processed {len(result)} examples")
+ return result
+
+
+def stem(
+ text: str | list[str],
+ sep_token: str = "<+>",
+ normalize: bool = False,
+ keep_diacritics: bool = False,
+ model: str = "mle",
+ use_gpu: bool = False,
+) -> str | list[str]:
+ """
+ Stem Arabic text or list of texts.
+
+ Args:
+ text: Single string or list of strings to stem
+ sep_token: Separator token for morphological splits (default: '<+>')
+ normalize: Apply Arabic normalization (default: False)
+ keep_diacritics: Keep diacritics in output (default: False)
+ model: Disambiguator model to use - "mle" or "bert" (default: "mle")
+ use_gpu: Whether to use GPU for BERT model (default: False)
+
+ Returns:
+ Stemmed text in the same format as input (string or list of strings)
+
+ Example:
+ >>> # Stem a single string
+ >>> stemmed = stem("النص العربي")
+ >>> # Returns: "ال<+>نص ال<+>عربي"
+
+ >>> # Stem a list of strings
+ >>> stemmed = stem(["النص العربي", "مثال آخر"])
+ >>> # Returns: ["ال<+>نص ال<+>عربي", "مثال آخر"]
+
+ >>> # Stem with BERT model and GPU
+ >>> stemmed = stem("النص", model="bert", use_gpu=True)
+ """
+ # Validate model parameter
+ model = model.lower()
+ if model not in ["mle", "bert"]:
+ raise ValueError(f"Invalid model '{model}'. Must be 'mle' or 'bert'")
+
+ # Track whether input was a single string
+ is_single_string = isinstance(text, str)
+
+ # Convert single string to list for uniform processing
+ text_list = [text] if is_single_string else text
+
+ # Validate all items are strings
+ if not all(isinstance(t, str) for t in text_list):
+ raise TypeError("All items in text list must be strings")
+
+ # Initialize the disambiguator (a new instance is created on each call)
+ logger.info(f"Initializing {model.upper()} disambiguator...")
+ catalogue = Catalogue.load_catalogue()
+ try:
+ catalogue.download_package("morphology-db-msa-r13")
+ if model == "mle":
+ catalogue.download_package("disambig-mle-calima-msa-r13")
+ except Exception as e:
+ logger.warning(f"Could not verify CAMeL packages: {e}")
+
+ if model == "mle":
+ disambiguator = MLEDisambiguator.pretrained("calima-msa-r13", cache_size=1_000_000)
+ else: # bert
+ disambiguator = BERTUnfactoredDisambiguator.pretrained(use_gpu=use_gpu)
+
+ # Add caching to disambiguator
+ def new_scored_analysis(self, word_dd):
+ if word_dd in self._cache:
+ return self._cache[word_dd]
+ result = self._scored_analyses(word_dd)
+ self._cache[word_dd] = result
+ return result
+
+ disambiguator._scored_analyses_cached = MethodType(new_scored_analysis, disambiguator)
+ disambiguator._score_fn = disambiguator._scored_analyses_cached
+
+ # Load word lists
+ words_dir = os.path.join(os.path.dirname(__file__), "data")
+ list_al_t = set(read_and_dediacritize(os.path.join(words_dir, "words_al_t.txt")))
+ list_al = set(read_and_dediacritize(os.path.join(words_dir, "words_al.txt")))
+ list_t = set(read_and_dediacritize(os.path.join(words_dir, "words_t.txt")))
+
+ # Process each text
+ results = []
+ for txt in text_list:
+ if not txt:
+ results.append("")
+ continue
+
+ word_list = tokenize(txt)
+ if word_list is None:
+ stemmed = dediac_ar(txt) if not keep_diacritics else txt
+ results.append(stemmed)
+ continue
+
+ tokenized, _, _, has_diacs = morph_tokenize(
+ word_list, disambiguator, list_al_t, list_al, list_t, sep_token=sep_token
+ )
+
+ if tokenized is not None:
+ tokenized = merge_alef_and_alef_lam(tokenized, sep_token)
+ stemmed = "".join(tokenized)
+
+ if normalize:
+ stemmed = normalize_arabic(stemmed)
+
+ if not keep_diacritics:
+ stemmed = dediac_ar(stemmed)
+
+ results.append(stemmed)
+ else:
+ stemmed = dediac_ar(txt) if not keep_diacritics else txt
+ results.append(stemmed)
+
+ # Return in the same format as input
+ return results[0] if is_single_string else results
diff --git a/dalla_data_processing/utils/__init__.py b/dalla_data_processing/utils/__init__.py
index acd2474..9912b42 100644
--- a/dalla_data_processing/utils/__init__.py
+++ b/dalla_data_processing/utils/__init__.py
@@ -5,6 +5,14 @@
"""
from dalla_data_processing.utils.logger import get_logger, logger, setup_logging
-from dalla_data_processing.utils.tokenize import simple_word_tokenize
__all__ = ["simple_word_tokenize", "logger", "get_logger", "setup_logging"]
+
+
+def __getattr__(name):
+ """Lazy load modules with optional dependencies."""
+ if name == "simple_word_tokenize":
+ from dalla_data_processing.utils.tokenize import simple_word_tokenize
+
+ return simple_word_tokenize
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
diff --git a/pyproject.toml b/pyproject.toml
index 0f02da0..cbf1851 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,10 +1,10 @@
[build-system]
-requires = ["setuptools>=61.0", "wheel"]
+requires = ["setuptools>=61.0", "setuptools-scm>=8.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "dalla-data-processing"
-version = "0.0.1"
+dynamic = ["version"]
description = "data processing pipeline with deduplication, stemming, quality checking, and readability scoring, used for the DALLA Models"
authors = [
{name = "Hadi Hamoud", email = "hhamoud@dohainstitute.edu.qa"},
@@ -25,13 +25,9 @@ classifiers = [
dependencies = [
"datasets>=2.14.0",
"transformers>=4.30.0",
- "camel-tools>=1.5.0",
"click>=8.0.0",
"tqdm>=4.65.0",
- "pandas>=2.0.0",
- "numpy>=1.24.0",
"pyarrow>=12.0.0",
- "textstat>=0.7.0",
"structlog>=24.0.0",
]
@@ -42,11 +38,28 @@ dev = [
"ruff>=0.1.0",
"pre-commit>=3.0.0",
]
+dedup = [
+ "camel-tools>=1.5.0",
+]
dedup-native = [
"cffi>=1.15.0",
]
+stem = [
+ "camel-tools>=1.5.0",
+]
+quality = [
+ "camel-tools>=1.5.0",
+]
+readability = [
+ "textstat>=0.7.0",
+]
+pack = [
+ "sentencepiece>=0.2.0",
+ "rbpe",
+ "pyyaml",
+]
all = [
- "dalla-data-processing[dev,dedup-native]",
+ "dalla-data-processing[dev,dedup,dedup-native,stem,quality,readability,pack]",
]
[project.scripts]
@@ -59,7 +72,7 @@ Repository = "https://github.com/U4RASD/dalla-data-processing"
"Bug Tracker" = "https://github.com/U4RASD/dalla-data-processing/issues"
[tool.setuptools]
-packages = ["dalla_data_processing", "dalla_data_processing.core", "dalla_data_processing.deduplication", "dalla_data_processing.stemming", "dalla_data_processing.quality", "dalla_data_processing.readability", "dalla_data_processing.utils"]
+packages = ["dalla_data_processing", "dalla_data_processing.core", "dalla_data_processing.deduplication", "dalla_data_processing.packing", "dalla_data_processing.stemming", "dalla_data_processing.quality", "dalla_data_processing.readability", "dalla_data_processing.utils"]
include-package-data = true
[tool.setuptools.package-data]
@@ -86,6 +99,9 @@ select = [
]
ignore = [
"E501",
+ "SIM102",
+ "N802",
+ "N806"
]
@@ -102,8 +118,11 @@ known-first-party = ["dalla_data_processing"]
"dalla_data_processing/deduplication/onion/**/*.py" = ["N", "SIM", "UP"]
"dalla_data_processing/stemming/__init__.py" = ["N802", "N806", "SIM102"]
-[tool.uv]
-dev-dependencies = [
+[tool.setuptools_scm]
+version_file = "dalla_data_processing/_version.py"
+
+[dependency-groups]
+dev = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"ruff>=0.1.0",