Skip to content
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,8 @@ src/nexusformat/_version.py
# miscellaneous system files
.directoryhash

# Claude settings
.claude

# uv
uv.lock
61 changes: 34 additions & 27 deletions src/nexusformat/nexus/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,15 @@ def natural_sort(key):

Parameters
----------
key : str
key : str or Path
String in the list to be sorted.

Returns
-------
list
List of components splitting embedded numbers as integers.
"""
return [int(t) if t.isdigit() else t for t in re.split(r'(\d+)', key)]
return [int(t) if t.isdigit() else t for t in re.split(r'(\d+)', str(key))]


class NeXusError(Exception):
Expand Down Expand Up @@ -2345,10 +2345,7 @@ def save(self, filename=None, mode='w-', **kwargs):
f.writefile(root)
root = f._root
root._file = f
if mode == 'w' or mode == 'w-':
root._mode = 'rw'
else:
root._mode = mode
root._mode = 'rw'
self.set_changed()
return root
else:
Expand Down Expand Up @@ -4383,17 +4380,9 @@ def plot_shape(self):
"""
Shape of NXfield for plotting.

Size-1 axes are removed from the shape for multidimensional
data.
Size-1 axes are removed from the shape.
"""
try:
_shape = list(self.shape)
if len(_shape) > 1:
while 1 in _shape:
_shape.remove(1)
return tuple(_shape)
except Exception:
return ()
return tuple(s for s in self.shape if s > 1)

@property
def plot_rank(self):
Expand Down Expand Up @@ -4614,7 +4603,10 @@ def __deepcopy__(self, memo={}):
dpcpy._vidx = copy(obj._vidx)
dpcpy._vpath = copy(obj._vpath)
dpcpy._vfiles = copy(obj._vfiles)
shape = (len(obj._vfiles),) + slice_shape(obj._vidx, obj._vshape)
if obj._vidx:
shape = (len(obj._vfiles),) + slice_shape(obj._vidx, obj._vshape)
else:
shape = (len(obj._vfiles),) + obj._vshape
dpcpy._create_virtual_data(shape=shape, idx=obj._vidx)
dpcpy._h5opts = copy(obj._h5opts)
dpcpy._changed = True
Expand Down Expand Up @@ -6364,12 +6356,15 @@ def __enter__(self):
Current NXroot instance.
"""
if self.nxfile:
self._current_mode = self._mode
self._mode = self._file.mode = 'rw'
self.nxfile.__enter__()
return self

def __exit__(self, *args):
"""Close the NeXus file."""
if self.nxfile:
self._mode = self._file.mode = self._current_mode
self.nxfile.__exit__()

def serialize(self):
Expand Down Expand Up @@ -7912,7 +7907,8 @@ def nxauxiliary_signals(self, signals):
signals = list(signals)
if all(isinstance(signal, str) for signal in signals):
self.attrs['auxiliary_signals'] = signals
elif all(isinstance(signal, NXfield) for signal in signals):
elif all(isinstance(signal, NXfield) or
isinstance(signal, NXlink)for signal in signals):
self.attrs['auxiliary_signals'] = [signal.nxname for signal
in signals]
else:
Expand Down Expand Up @@ -8002,6 +7998,12 @@ def __init__(self, *args, **kwargs):
from datetime import datetime as dt
self.date = dt.isoformat(dt.today())

def set_date(self, date=None):
"""Set the date to a specific value."""
from datetime import datetime as dt
if date is None:
date = dt.today()
self.date = dt.isoformat(date)

class NXnote(NXgroup):

Expand Down Expand Up @@ -8590,7 +8592,7 @@ def consolidate(files, data_path, scan_path=None, idx=None):
Data slice to be used in the virtual field, by default None
"""

if isinstance(files[0], str):
if isinstance(files[0], str) or isinstance(files[0], Path):
files = [nxload(f) for f in files]
if isinstance(data_path, NXdata):
data_path = data_path.nxpath
Expand All @@ -8602,6 +8604,8 @@ def consolidate(files, data_path, scan_path=None, idx=None):
else:
scan_files = [f for f in files if data_path in f
and f[data_path].nxsignal.exists()]
if len(scan_files) == 0:
raise NeXusError(f'{data_path} not found in files')
scan_file = scan_files[0]
if scan_path:
scan_values = [f[scan_path] for f in scan_files]
Expand All @@ -8616,15 +8620,18 @@ def consolidate(files, data_path, scan_path=None, idx=None):
else:
scan_axis = NXfield(range(len(scan_files)), name='file_index',
long_name='File Index')
signal = scan_file[data_path].nxsignal
axes = scan_file[data_path].nxaxes
if idx is not None:
axes = [axis[s] for axis, s in zip(axes, idx)]
sources = [f[signal.nxpath].nxfilename for f in scan_files]
scan_field = NXvirtualfield(signal, sources, shape=signal.shape, idx=idx,
name=signal.nxname)
scan_data = NXdata(scan_field, [scan_axis] + axes,
name=scan_file[data_path].nxname)
for signal in [s for s in scan_file[data_path].nxsignals if s.exists()]:
if idx is not None:
axes = [axis[s] for axis, s in zip(axes, idx)]
sources = [f[signal.nxpath].nxfilename for f in scan_files]
scan_field = NXvirtualfield(signal, sources, shape=signal.shape,
idx=idx, name=signal.nxname)
if signal.nxname == scan_file[data_path].nxsignal.nxname:
scan_data = NXdata(scan_field, [scan_axis] + axes,
name=scan_file[data_path].nxname)
else:
scan_data[signal.nxname] = scan_field
scan_data.title = data_path
return scan_data

Expand Down
2 changes: 1 addition & 1 deletion src/nexusformat/nexus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dateutil.parser import parse


name_pattern = re.compile(r'^[a-zA-Z0-9_]([a-zA-Z0-9_.]*[a-zA-Z0-9_])?$')
name_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')


def get_logger():
Expand Down
Loading