1 change: 1 addition & 0 deletions pyproject.toml
@@ -114,6 +114,7 @@ plugins = "numpy.typing.mypy_plugin"
[[tool.mypy.overrides]]
module = [
"scipy",
"scipy.io",
"scipy.sparse",
"scipy.sparse.linalg",
"scipy.optimize",
8 changes: 6 additions & 2 deletions pyttb/__init__.py
@@ -13,10 +13,10 @@

from pyttb.cp_als import cp_als
from pyttb.cp_apr import cp_apr
from pyttb.export_data import export_data
from pyttb.export_data import export_data, export_data_bin, export_data_mat
from pyttb.gcp_opt import gcp_opt
from pyttb.hosvd import hosvd
from pyttb.import_data import import_data
from pyttb.import_data import import_data, import_data_bin, import_data_mat
from pyttb.khatrirao import khatrirao
from pyttb.ktensor import ktensor
from pyttb.matlab import matlab_support
@@ -42,9 +42,13 @@ def ignore_warnings(ignore=True):
cp_als.__name__,
cp_apr.__name__,
export_data.__name__,
export_data_bin.__name__,
export_data_mat.__name__,
gcp_opt.__name__,
hosvd.__name__,
import_data.__name__,
import_data_bin.__name__,
import_data_mat.__name__,
khatrirao.__name__,
ktensor.__name__,
matlab_support.__name__,
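With the re-exports above, the four new binary IO helpers should be available directly from the top-level `pyttb` namespace, next to the existing text-based `import_data`/`export_data`. A minimal sanity-check sketch (only names taken from this diff are used):

```python
import pyttb as ttb

# The binary IO helpers added in this PR are re-exported at package level,
# so callers do not need to reach into pyttb.export_data / pyttb.import_data.
for name in ("export_data_bin", "export_data_mat", "import_data_bin", "import_data_mat"):
    assert hasattr(ttb, name), f"{name} is not re-exported"
```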
112 changes: 111 additions & 1 deletion pyttb/export_data.py
@@ -6,14 +6,23 @@

from __future__ import annotations

from typing import TextIO
from enum import Enum
from typing import Any, TextIO

import numpy as np
from scipy.io import savemat

import pyttb as ttb
from pyttb.pyttb_utils import Shape, parse_shape


class ExportFormat(Enum):
"""Export format enumeration."""

NUMPY = "numpy"
MATLAB = "matlab"


def export_data(
data: ttb.tensor | ttb.ktensor | ttb.sptensor | np.ndarray,
filename: str,
@@ -56,6 +65,107 @@ def export_data(
export_array(fp, data, fmt_data)


def export_data_bin(
data: ttb.tensor | ttb.ktensor | ttb.sptensor | np.ndarray,
filename: str,
index_base: int = 1,
):
"""Export tensor-related data to a binary file."""
_export_data_binary(data, filename, ExportFormat.NUMPY, index_base)


def export_data_mat(
data: ttb.tensor | ttb.ktensor | ttb.sptensor | np.ndarray,
filename: str,
index_base: int = 1,
):
"""Export tensor-related data to a matlab compatible binary file."""
_export_data_binary(data, filename, ExportFormat.MATLAB, index_base)


def _export_data_binary(
data: ttb.tensor | ttb.ktensor | ttb.sptensor | np.ndarray,
filename: str,
export_format: ExportFormat,
index_base: int = 1,
):
"""Export tensor-related data to a binary file using specified format."""
if not isinstance(data, (ttb.tensor, ttb.sptensor, ttb.ktensor, np.ndarray)):
raise NotImplementedError(f"Invalid data type for export: {type(data)}")

# Prepare data for export based on type
if isinstance(data, ttb.tensor):
export_data_dict = _prepare_tensor_data(data)
elif isinstance(data, ttb.sptensor):
export_data_dict = _prepare_sptensor_data(data, index_base)
elif isinstance(data, ttb.ktensor):
export_data_dict = _prepare_ktensor_data(data)
elif isinstance(data, np.ndarray):
export_data_dict = _prepare_matrix_data(data)
else:
raise NotImplementedError(f"Unsupported data type: {type(data)}")

# Save using appropriate format
if export_format == ExportFormat.NUMPY:
with open(filename, "wb") as fp:
np.savez(fp, allow_pickle=False, **export_data_dict)
elif export_format == ExportFormat.MATLAB:
savemat(filename, export_data_dict)
else:
raise ValueError(f"Unsupported export format: {export_format}")


def _create_header(data_type: str) -> np.ndarray:
"""Create consistent header for tensor data."""
# TODO encode version information
return np.array([data_type, "F"])


def _prepare_sptensor_data(data: ttb.sptensor, index_base: int = 1) -> dict[str, Any]:
"""Prepare sparse tensor data for export."""
return {
"header": _create_header("sptensor"),
"shape": np.array(data.shape),
"nnz": np.array([data.nnz]),
"subs": data.subs + index_base,
"vals": data.vals,
}


def _prepare_tensor_data(data: ttb.tensor) -> dict[str, Any]:
"""Prepare dense tensor data for export."""
return {
"header": _create_header("tensor"),
"data": data.data,
}


def _prepare_matrix_data(data: np.ndarray) -> dict[str, Any]:
"""Prepare matrix data for export."""
return {
"header": _create_header("matrix"),
"data": data,
}


def _prepare_ktensor_data(data: ttb.ktensor) -> dict[str, Any]:
"""Prepare ktensor data for export."""
factor_matrices = data.factor_matrices
num_factor_matrices = len(factor_matrices)

export_dict = {
"header": _create_header("ktensor"),
"weights": data.weights,
"num_factor_matrices": num_factor_matrices,
}

# Add individual factor matrices for NumPy compatibility
for i in range(num_factor_matrices):
export_dict[f"factor_matrix_{i}"] = factor_matrices[i]

return export_dict
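The `factor_matrix_{i}` keys above are needed because `np.savez` with `allow_pickle=False` cannot store a Python list of differently shaped factor matrices under a single key. An illustrative sketch of the resulting archive layout (the file name and factor matrices are arbitrary):

```python
import numpy as np
import pyttb as ttb

# Rank-2 ktensor with three arbitrary factor matrices.
K = ttb.ktensor(
    [np.ones((2, 2)), np.ones((3, 2)), np.ones((4, 2))],
    np.array([1.0, 2.0]),
)
ttb.export_data_bin(K, "ktensor_dump.npz")

with np.load("ktensor_dump.npz", allow_pickle=False) as npz:
    # Expect one entry per factor matrix plus the bookkeeping keys:
    # ['factor_matrix_0', 'factor_matrix_1', 'factor_matrix_2',
    #  'header', 'num_factor_matrices', 'weights']
    print(sorted(npz.files))
```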


def export_size(fp: TextIO, shape: Shape):
"""Export the size of something to a file."""
shape = parse_shape(shape)
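For the sparse case, a short usage sketch of `export_data_bin` showing the effect of the default `index_base=1` on the stored subscripts (the tensor values and file name are made up for illustration):

```python
import numpy as np
import pyttb as ttb

# Arbitrary 2 x 3 x 4 sparse tensor with two nonzeros.
subs = np.array([[0, 0, 0], [1, 2, 3]])
vals = np.array([[10.5], [1.5]])
S = ttb.sptensor(subs, vals, (2, 3, 4))

ttb.export_data_bin(S, "sptensor_dump.npz")

with np.load("sptensor_dump.npz", allow_pickle=False) as npz:
    print(npz["header"][0])  # -> sptensor
    print(npz["subs"])       # subscripts shifted to one-based indexing
    print(npz["vals"].ravel())
```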
110 changes: 109 additions & 1 deletion pyttb/import_data.py
@@ -10,8 +10,10 @@
from typing import TextIO

import numpy as np
from scipy.io import loadmat

import pyttb as ttb
from pyttb.pyttb_utils import to_memory_order


def import_data(
@@ -65,12 +67,118 @@ def import_data(
fp.readline().strip() # Skip factor type
fac_shape = import_shape(fp)
fac = import_array(fp, np.prod(fac_shape))
fac = np.reshape(fac, np.array(fac_shape))
fac = to_memory_order(np.reshape(fac, np.array(fac_shape)), order="F")
factor_matrices.append(fac)
return ttb.ktensor(factor_matrices, weights, copy=False)
raise ValueError("Failed to load tensor data") # pragma: no cover


def import_data_bin(
filename: str,
index_base: int = 1,
) -> ttb.sptensor | ttb.ktensor | ttb.tensor | np.ndarray:
"""Import tensor-related data from a binary file."""

def load_bin_data(filename: str):
npzfile = np.load(filename, allow_pickle=False)
return {
"header": npzfile["header"][0],
"data": npzfile.get("data"),
"shape": tuple(npzfile["shape"]) if "shape" in npzfile else None,
"subs": npzfile.get("subs"),
"vals": npzfile.get("vals"),
"num_factor_matrices": int(npzfile["num_factor_matrices"])
if "num_factor_matrices" in npzfile
else None,
"factor_matrices": [
npzfile[f"factor_matrix_{i}"]
for i in range(int(npzfile["num_factor_matrices"]))
]
if "num_factor_matrices" in npzfile
else None,
"weights": npzfile.get("weights"),
}

return _import_tensor_data(filename, index_base, load_bin_data)


def import_data_mat(
filename: str,
index_base: int = 1,
) -> ttb.sptensor | ttb.ktensor | ttb.tensor | np.ndarray:
"""Import tensor-related data from a MATLAB file."""

def load_mat_data(filename: str):
mat_data = loadmat(filename)
header = mat_data["header"][0]
return {
"header": header.split()[0],
"data": mat_data.get("data"),
"shape": tuple(mat_data["shape"][0]) if "shape" in mat_data else None,
"subs": mat_data.get("subs"),
"vals": mat_data.get("vals"),
"num_factor_matrices": int(mat_data["num_factor_matrices"])
if "num_factor_matrices" in mat_data
else None,
"factor_matrices": [
mat_data[f"factor_matrix_{i}"]
for i in range(int(mat_data["num_factor_matrices"]))
]
if "num_factor_matrices" in mat_data
else None,
"weights": mat_data.get("weights").flatten()
if "weights" in mat_data
else None,
}

return _import_tensor_data(filename, index_base, load_mat_data)


def _import_tensor_data(
filename: str,
index_base: int,
data_loader,
) -> ttb.sptensor | ttb.ktensor | ttb.tensor | np.ndarray:
"""Generalized function to import tensor data from different file formats.

Parameters
----------
filename:
File to import.
index_base:
Index basing allows interoperability (primarily between Python and MATLAB).
data_loader:
Function that loads and structures the data from the file.
"""
# Check if file exists
if not os.path.isfile(filename):
raise FileNotFoundError(f"File path {filename} does not exist.")

loaded_data = data_loader(filename)
data_type = loaded_data["header"]

if data_type not in ["tensor", "sptensor", "matrix", "ktensor"]:
raise ValueError(f"Invalid data type found: '{data_type}'")

if data_type == "tensor":
data = loaded_data["data"]
return ttb.tensor(data)
elif data_type == "sptensor":
shape = loaded_data["shape"]
subs = loaded_data["subs"] - index_base
vals = loaded_data["vals"]
return ttb.sptensor(subs, vals, shape)
elif data_type == "matrix":
data = loaded_data["data"]
return data
elif data_type == "ktensor":
factor_matrices = loaded_data["factor_matrices"]
weights = loaded_data["weights"]
return ttb.ktensor(factor_matrices, weights)

raise ValueError(f"Invalid data type found: {data_type}")


def import_type(fp: TextIO) -> str:
"""Extract IO data type."""
return fp.readline().strip().split(" ")[0]
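A round-trip sketch tying the exporters and importers together; assuming this branch is installed, both paths should reconstruct an equivalent dense tensor (`isequal` is pyttb's existing equality check, and the file names are arbitrary):

```python
import numpy as np
import pyttb as ttb

T = ttb.tensor(np.arange(24, dtype=float).reshape((2, 3, 4), order="F"))

ttb.export_data_bin(T, "roundtrip.npz")   # NumPy .npz on disk
ttb.export_data_mat(T, "roundtrip.mat")   # MATLAB-compatible .mat

# Each importer dispatches on the stored 'header' entry.
T_npz = ttb.import_data_bin("roundtrip.npz")
T_mat = ttb.import_data_mat("roundtrip.mat")

assert T.isequal(T_npz)
assert T.isequal(T_mat)
```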