Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 86 additions & 12 deletions sprm/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,23 @@ def __init__(self, mask: MaskStruct, cidx: int):
recs.sort(kind="heapsort", order=["cellid", "y", "x"])
LOGGER.debug(f"maxes: {np.max(recs['cellid'])} {np.max(recs['x'])} {np.max(recs['y'])}")
entries, counts = np.unique(recs["cellid"], return_counts=True)
self.counts= counts
self.counts= np.stack((entries, counts), axis=1).copy()
self.vals = np.vstack((recs["y"], recs["x"])).copy()

def cell_iter(self):
def __iter__(self):
cell_offset = 0
coord_offset = 0
while cell_offset < len(self.counts):
cell_sz = self.counts[cell_offset]
yield self.vals[:, coord_offset:coord_offset+cell_sz]
cell_id, cell_sz = self.counts[cell_offset]
yield (cell_id,
self.vals[:, coord_offset:coord_offset+cell_sz])
cell_offset += 1
coord_offset += cell_sz

def cell_iter(self):
for cell_id, cell_mtx in self:
yield cell_mtx

def cells_only_iter(self):
"""
Like cell_iter but excludes the background "cell" at index 0
Expand All @@ -52,16 +57,43 @@ def subset_iter(self, restriction_set: set):
Yields both the cell index and the pixel matrix, with the guarantee
that the cell index does exist in the set.
"""
for idx, cell_mtx in enumerate(self.cell_iter()):
if idx in restriction_set:
yield idx, cell_mtx
for cell_id, cell_mtx in self:
if cell_id in restriction_set:
yield cell_id, cell_mtx

def background(self):
return self.vals[:, 0:self.counts[0]]
return self.vals[:, 0:self.counts[0][1]]

def __len__(self):
return len(self.counts)

def __eq__(self, other):
return (
isinstance(other, type(self))
and np.array_equal(self.counts, other.counts)
and np.array_equal(self.vals, other.vals)
)

def save(self, fname: str):
np.savez(fname, vals=self.vals, counts=self.counts)

@classmethod
def load(cls, fname: str):
npzfile = np.load(fname)
if ("vals" not in npzfile.files
or "counts" not in npzfile.files):
raise RuntimeError(
f"{fname} is not a {cls.__name__}"
)
inst = cls.__new__(cls)
inst.vals = npzfile["vals"].copy()
if inst.vals.shape[0] != 2:
raise RuntimeError(
f"{fname} has the wrong shape to be a {cls.__name__}"
)
inst.counts = npzfile["counts"].copy()
return inst


class CellTable3D:
def __init__(self, mask: MaskStruct, cidx: int):
Expand All @@ -80,30 +112,72 @@ def __init__(self, mask: MaskStruct, cidx: int):
LOGGER.debug(f"maxes: {np.max(recs['cellid'])} {np.max(recs['x'])}"
f" {np.max(recs['y'])} {np.max(recs['z'])}")
entries, counts = np.unique(recs["cellid"], return_counts=True)
self.counts= counts
self.counts= np.stack((entries, counts), axis=1).copy()
self.vals = np.vstack((recs["z"], recs["y"], recs["x"])).copy()

def cell_iter(self):
def __iter__(self):
cell_offset = 0
coord_offset = 0
while cell_offset < len(self.counts):
cell_sz = self.counts[cell_offset]
yield self.vals[:, coord_offset:coord_offset+cell_sz]
cell_id, cell_sz = self.counts[cell_offset]
yield (cell_id,
self.vals[:, coord_offset:coord_offset+cell_sz])
cell_offset += 1
coord_offset += cell_sz

def cell_iter(self):
for cell_id, cell_mtx in self:
yield cell_mtx

def cells_only_iter(self):
"""
Like cell_iter but excludes the background "cell" at index 0
"""
return itertools.islice(self.cell_iter(), 1)

def subset_iter(self, restriction_set: set):
"""
Iterate only through the cell indices included in restriction_set.
Yields both the cell index and the pixel matrix, with the guarantee
that the cell index does exist in the set.
"""
for cell_id, cell_mtx in self:
if cell_id in restriction_set:
yield cell_id, cell_mtx

def background(self):
return self.vals[:, 0:self.counts[0]]

def __len__(self):
return len(self.counts)

def __eq__(self, other):
return (
isinstance(other, type(self))
and np.array_equal(self.counts, other.counts)
and np.array_equal(self.vals, other.vals)
)

def save(self, fname: str):
np.savez(fname, vals=self.vals, counts=self.counts)

@classmethod
def load(cls, fname: str):
npzfile = np.load(fname)
if ("vals" not in npzfile.files
or "counts" not in npzfile.files):
raise RuntimeError(
f"{fname} is not a {cls.__name__}"
)
inst = cls.__new__(cls)
inst.vals = npzfile["vals"].copy()
if inst.vals.shape[0] != 3:
raise RuntimeError(
f"{fname} has the wrong shape to be a {cls.__name__}"
)
inst.counts = npzfile["counts"].copy()
return inst


class IMGstruct:
"""
Expand Down