diff --git a/datastock/_generic_check.py b/datastock/_generic_check.py index 035da81..975e61b 100644 --- a/datastock/_generic_check.py +++ b/datastock/_generic_check.py @@ -420,7 +420,20 @@ def _check_vectbasis( e2=None, dim=None, tol=None, + direct=None, ): + """ Check a 2d or 3d set of unit vectors + + Check that: + - vectors are defined (if None, they can be inferred) + + Normalizes vectorto be unit vectors + Optionally (default) check the vectors forma direct basis + """ + + # ---------------- + # check inputs + # ---------------- # dim dim = _check_var(dim, 'dim', types=int, default=3, allowed=[2, 3]) @@ -428,6 +441,13 @@ def _check_vectbasis( # tol tol = _check_var(tol, 'tol', types=float, default=1.e-14, sign='>0.') + # direct + direct = _check_var(direct, 'direct', types=bool, default=True) + + # ---------------------- + # check what's provided + # ---------------------- + # check is provided if e0 is not None: e0 = _check_flat1darray(e0, 'e0', size=dim, dtype=float, norm=True) @@ -436,13 +456,23 @@ def _check_vectbasis( if e2 is not None: e2 = _check_flat1darray(e2, 'e2', size=dim, dtype=float, norm=True) + # preliminary + allnone = all([ee is None for ee in [e0, e1, e2][:dim]]) + if allnone is True: + lstr = [f"\t- {ee}" for ee in ['e0', 'e1', 'e2'][:dim]] + msg = ( + f"For a basis f dimension {dim}, provide at least one of:\n" + + "\n".join(lstr) + ) + raise Exception(msg) + + # ---------------------- + # dim = 2 + # ---------------------- + # vectors if dim == 2: - if e0 is None and e1 is None: - msg = "Please provide e0 and/or e1!" - raise Exception(msg) - # complete if missing if e0 is None: e0 = np.r_[e1[1], -e1[0]] @@ -455,16 +485,18 @@ def _check_vectbasis( raise Exception(msg) # direct - if np.abs(np.cross(e0, e1).tolist() - 1.) < tol: - msg = "Non-direct basis" - raise Exception(msg) + if direct is True: + if np.abs(np.cross(e0, e1).tolist() - 1.) < tol: + msg = "Non-direct basis" + raise Exception(msg) return e0, e1 + # ---------------------- + # dim = 3 + # ---------------------- + else: - if e0 is None and e1 is None and e2 is None: - msg = "Please provide at least e0, e1 or e2!" - raise Exception(msg) # complete if 2 missing if e0 is None and e1 is None: @@ -472,7 +504,7 @@ def _check_vectbasis( elif e0 is None and e2 is None: e2 = _get_vertical_unitvect(ee=e1) elif e1 is None and e2 is None: - e2 = _get_vertical_unitvect(ee=e0) + e2 = _get_horizontal_unitvect(ee=e0) # complete if 1 missing if e0 is None: @@ -502,9 +534,10 @@ def _check_vectbasis( raise Exception(msg) # direct - if not np.allclose(np.cross(e0, e1), e2, atol=tol, rtol=1e-6): - msg = "Non-direct basis" - raise Exception(msg) + if direct is True: + if not np.allclose(np.cross(e0, e1), e2, atol=tol, rtol=1e-6): + msg = "Non-direct basis" + raise Exception(msg) return e0, e1, e2