From 9952a8da78aa004a13d46764b17138dd4795bda0 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Thu, 16 Jan 2025 16:32:49 -0800 Subject: [PATCH 1/9] tests passing --- README.md | 7 +++-- pyproject.toml | 26 +++++++++++++++++ setup.py | 41 --------------------------- zschema/keys.py | 2 +- zschema/{tests.py => test_zschema.py} | 16 +++++------ 5 files changed, 40 insertions(+), 52 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.py rename zschema/{tests.py => test_zschema.py} (99%) diff --git a/README.md b/README.md index 5c91054..d510bec 100644 --- a/README.md +++ b/README.md @@ -188,8 +188,11 @@ https://github.com/zmap/zschema/blob/master/zschema/leaves.py#L25. Running Tests ============= -Tests are run with [nose](http://nose.readthedocs.io/en/latest/). Run them via -`python setup.py test`. +Tests are run with [pytest](https://docs.pytest.org/en/stable/). Run them via: +```zsh +pip3 install setuptools +python setup.py test +``` License and Copyright diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..04a230a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools>=42"] +build-backend = "setuptools.build_meta" + +[project] +name = "zschema" +description = "A schema language for JSON documents that allows validation and compilation into various database engines" +version = "0.11.0" +authors = [ { name = "ZMap Team"} ] +license = { text = "Apache License, Version 2.0" } # Replace with the actual license +keywords = ["python", "json", "schema", "bigquery", "elasticsearch"] + +dependencies = [ + "future", + "python-dateutil", + "pytz", + "six" +] + +[project.optional-dependencies] +tests = [ + "pytest" +] + +[project.scripts] +zschema = "zschema.__main__:main" diff --git a/setup.py b/setup.py deleted file mode 100644 index f1c148c..0000000 --- a/setup.py +++ /dev/null @@ -1,41 +0,0 @@ -# -*- coding: utf-8 -*- - -from setuptools import setup - -import os.path - -base_dir = os.path.dirname(__file__) - -about = dict() -with open(os.path.join(base_dir, "zschema", "__init__.py")) as f: - exec(f.read(), about) - -setup( - name = "zschema", - description = "A schema language for JSON documents that allows validation and compilation into various database engines", - version = about["__version__"], - license = about["__license__"], - author = about["__author__"], - author_email = about["__email__"], - keywords = "python json schema bigquery elasticsearch", - - install_requires = [ - "future", - "python-dateutil", - "pytz", - "six", - ], - - packages = [ - "zschema", - ], - - entry_points={ - 'console_scripts': [ - 'zschema = zschema.__main__:main', - ] - }, - - tests_require = [ 'nose' ], - test_suite = 'nose.collector' -) diff --git a/zschema/keys.py b/zschema/keys.py index f621d56..f608b2b 100644 --- a/zschema/keys.py +++ b/zschema/keys.py @@ -211,7 +211,7 @@ def _handle_validation_exception(policy, e): logging.error(e.message) raise e elif policy == "warn": - logging.warn(e.message) + logging.warning(e.message) elif policy == "ignore": pass else: diff --git a/zschema/tests.py b/zschema/test_zschema.py similarity index 99% rename from zschema/tests.py rename to zschema/test_zschema.py index b1b360a..1f930bc 100644 --- a/zschema/tests.py +++ b/zschema/test_zschema.py @@ -1,4 +1,4 @@ -import collections +from collections.abc import Sized import datetime import json import os @@ -352,10 +352,10 @@ def assertBigQuerySchemaEqual(self, a, b): if a == b: return else: - self.assertEquals(type(a), type(b)) - if isinstance(a, collections.Sized) \ - and isinstance(b, collections.Sized): - self.assertEquals(len(a), len(b)) + self.assertEqual(type(a), type(b)) + if isinstance(a, Sized) \ + and isinstance(b, Sized): + self.assertEqual(len(a), len(b)) if isinstance(a, list) and isinstance(b, list): name_ordered_a = sorted(a, key=lambda x: x['name']) name_ordered_b = sorted(b, key=lambda x: x['name']) @@ -366,7 +366,7 @@ def assertBigQuerySchemaEqual(self, a, b): self.assertIn(k, b) self.assertBigQuerySchemaEqual(a[k], b[k]) else: - self.assertEquals(a, b) + self.assertEqual(a, b) def setUp(self): self.maxDiff=10000 @@ -514,7 +514,7 @@ def test_merge_recursive(self): "b":String() }) }) - self.assertEquals(a.merge(b).to_dict(), c.to_dict()) + self.assertEqual(a.merge(b).to_dict(), c.to_dict()) def test_extends(self): host = Record({ @@ -988,7 +988,7 @@ def test_bad_root(self): self.SCHEMA.validate(bad1, policy="error") self.assertTrue(False, "bad1 failed to fail") except DataValidationException as e: - self.assertEquals(e.path, ["b"]) + self.assertEqual(e.path, ["b"]) def test_bad_a_key(self): From bd518c82c18dc0e61a2d2ecc14240ecb8cebc7a9 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Thu, 16 Jan 2025 16:34:58 -0800 Subject: [PATCH 2/9] remove defunct travis.ci --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index d510bec..74d7c92 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ ZSchema ======= -[![Build Status](https://travis-ci.org/zmap/zschema.svg?branch=master)](https://travis-ci.org/zmap/zschema) - ZSchema is a generic (meta-)schema language for defining database schemas. It facilitates (1) validating JSON documents against a schema definition and (2) compilating a schema to multiple database engines. For example, if you wanted From 75e82823cd466fb0f29a9511e4a96278fb47f306 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Thu, 16 Jan 2025 16:43:30 -0800 Subject: [PATCH 3/9] add GH CI --- .github/workflows/ci.yml | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..7cc7760 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,33 @@ +name: Run Unit Tests + +on: + push: + branches: + - main + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + + steps: + # Step 1: Check out the code + - name: Checkout code + uses: actions/checkout@v3 + + # Step 2: Set up Python + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' # Specify the Python version you want to use + + # Step 3: Install dependencies + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools pytest + pip install -r requirements.txt || true # In case requirements.txt doesn't exist + + # Step 4: Run tests + - name: Run tests + run: | + pip3 install ".[tests]" \ No newline at end of file From fab403afc1dc36fc177d6b80d26c62571986df8a Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Thu, 16 Jan 2025 16:43:38 -0800 Subject: [PATCH 4/9] update test command --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 74d7c92..89622f6 100644 --- a/README.md +++ b/README.md @@ -188,8 +188,7 @@ Running Tests Tests are run with [pytest](https://docs.pytest.org/en/stable/). Run them via: ```zsh -pip3 install setuptools -python setup.py test +pip3 install ".[tests]" ``` From 1e403027cf4a577ed70b2dfda1bb07fb94975c6d Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Thu, 16 Jan 2025 16:51:36 -0800 Subject: [PATCH 5/9] fix warn caused by regex --- zschema/leaves.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zschema/leaves.py b/zschema/leaves.py index 55d6419..44018a0 100644 --- a/zschema/leaves.py +++ b/zschema/leaves.py @@ -540,7 +540,7 @@ def _validate(self, name, value, path=_NO_ARG): if isinstance(value, datetime.datetime): dt = value elif isinstance(value, int): - dt = datetime.datetime.utcfromtimestamp(value) + dt = datetime.datetime.fromtimestamp(value, datetime.timezone.utc) else: dt = dateutil.parser.parse(value) except (ValueError, TypeError): @@ -577,7 +577,7 @@ class OID(String): VALID = "1.3.6.1.4.868.2.4.1" INVALID = "hello" - OID_REGEX = re.compile("[[0-9]+\\.]*") + OID_REGEX = re.compile(r"^(\d+\.)+\d+$") def _is_oid(self, data): return bool(self.OID_REGEX.match(data)) From 0e3f7b18ef93e240a3f512c8d90c5c149b12329d Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Thu, 16 Jan 2025 16:53:24 -0800 Subject: [PATCH 6/9] fix test command in CI and README --- .github/workflows/ci.yml | 3 ++- README.md | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7cc7760..5946939 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,4 +30,5 @@ jobs: # Step 4: Run tests - name: Run tests run: | - pip3 install ".[tests]" \ No newline at end of file + pip3 install ".[tests]" + pytest \ No newline at end of file diff --git a/README.md b/README.md index 89622f6..2d76a66 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,7 @@ Running Tests Tests are run with [pytest](https://docs.pytest.org/en/stable/). Run them via: ```zsh pip3 install ".[tests]" +pytest ``` From 61c7e48d7f9beb89d50539b1d190b865da52c963 Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Thu, 16 Jan 2025 17:02:25 -0800 Subject: [PATCH 7/9] make ci manually triggerable --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5946939..eedf7a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,6 +5,7 @@ on: branches: - main pull_request: + workflow_dispatch: jobs: test: From bc15818699553594d8df4521585845c93b02853e Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Thu, 16 Jan 2025 17:09:43 -0800 Subject: [PATCH 8/9] fix datetime timezone warnings --- zschema/leaves.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/zschema/leaves.py b/zschema/leaves.py index 44018a0..817600e 100644 --- a/zschema/leaves.py +++ b/zschema/leaves.py @@ -516,6 +516,13 @@ class DateTime(Leaf): # dateutil.parser.parse(int) throws...? is this intended to be a unix epoch offset? EXPECTED_CLASS = string_types + (int, datetime.datetime) + TZINFOS = { + "EDT": datetime.timezone(datetime.timedelta(hours=-4)), # Eastern Daylight Time + "EST": datetime.timezone(datetime.timedelta(hours=-5)), # Eastern Standard Time + 'CDT': datetime.timezone(datetime.timedelta(hours=-5)), # Central Daylight Time + 'CST': datetime.timezone(datetime.timedelta(hours=-6)), # Central Standard Time + } + VALID = "Wed Jul 8 08:52:01 EDT 2015" INVALID = "Wed DNE 35 08:52:01 EDT 2015" @@ -526,14 +533,14 @@ def __init__(self, *args, **kwargs): super(DateTime, self).__init__(*args, **kwargs) if self.min_value: - self._min_value_dt = dateutil.parser.parse(self.min_value) + self._min_value_dt = dateutil.parser.parse(self.min_value, tzinfos=self.TZINFOS) else: - self._min_value_dt = dateutil.parser.parse(self.MIN_VALUE) + self._min_value_dt = dateutil.parser.parse(self.MIN_VALUE, tzinfos=self.TZINFOS) if self.max_value: - self._max_value_dt = dateutil.parser.parse(self.max_value) + self._max_value_dt = dateutil.parser.parse(self.max_value, tzinfos=self.TZINFOS) else: - self._max_value_dt = dateutil.parser.parse(self.MAX_VALUE) + self._max_value_dt = dateutil.parser.parse(self.MAX_VALUE, tzinfos=self.TZINFOS) def _validate(self, name, value, path=_NO_ARG): try: @@ -542,7 +549,7 @@ def _validate(self, name, value, path=_NO_ARG): elif isinstance(value, int): dt = datetime.datetime.fromtimestamp(value, datetime.timezone.utc) else: - dt = dateutil.parser.parse(value) + dt = dateutil.parser.parse(value, tzinfos=self.TZINFOS) except (ValueError, TypeError): # Either `datetime.utcfromtimestamp` or `dateutil.parser.parse` above # may raise on invalid input. From 31375a2f23f8c54bffc8eb6502edbf7ed0af975b Mon Sep 17 00:00:00 2001 From: phillip-stephens Date: Wed, 22 Jan 2025 17:06:45 -0800 Subject: [PATCH 9/9] fixed some more python2 issues, tested with a zgrab2 integration test --- zschema/__main__.py | 108 +++--- zschema/compounds.py | 268 +++++++++------ zschema/example.py | 29 +- zschema/keys.py | 68 ++-- zschema/leaves.py | 172 ++++++---- zschema/registry.py | 4 + zschema/test_zschema.py | 730 ++++++++++++++++++++-------------------- 7 files changed, 761 insertions(+), 618 deletions(-) diff --git a/zschema/__main__.py b/zschema/__main__.py index 7d77e58..873df45 100644 --- a/zschema/__main__.py +++ b/zschema/__main__.py @@ -1,16 +1,16 @@ import sys +import importlib.util import os.path import json import zschema.registry import argparse -from imp import load_source from importlib import import_module from site import addsitedir -from leaves import * -from keys import * -from compounds import * +from .leaves import * +from .keys import * +from .compounds import * commands = [ "bigquery", @@ -20,43 +20,59 @@ "docs-es", "validate", "flat", - "json" + "json", ] cmdList = ", ".join(commands) parser = argparse.ArgumentParser( prog="zschema", - description="Process a zschema definition. " - "VERSION: %s" % zschema.__version__) - -parser.add_argument("command", - metavar="command", choices=commands, - help="The command to execute; one of [ %s ]" % cmdList) - -parser.add_argument("schema", - help="The name of the schema in the zschema.registry. " - "For backwards compatibility, a filename can be " - "prefixed with a colon, as in 'schema.py:my-type'.") - -parser.add_argument("target", nargs="?", - help="Only used for the validate command. " - "The input JSON file that will be checked against " - "the schema.") + description="Process a zschema definition. " "VERSION: %s" % zschema.__version__, +) + +parser.add_argument( + "command", + metavar="command", + choices=commands, + help="The command to execute; one of [ %s ]" % cmdList, +) + +parser.add_argument( + "schema", + help="The name of the schema in the zschema.registry. " + "For backwards compatibility, a filename can be " + "prefixed with a colon, as in 'schema.py:my-type'.", +) + +parser.add_argument( + "target", + nargs="?", + help="Only used for the validate command. " + "The input JSON file that will be checked against " + "the schema.", +) parser.add_argument("--module", help="The name of a module to import.") -parser.add_argument("--validation-policy", help="What to do when a validation " - "error occurs. This only overrides the top-level Record. It does not " - "override subrecords. Default: error.", choices=["ignore", "warn", "error"], - default=None) - -parser.add_argument("--validation-policy-override", help="Override validation " - "policy for all levels of the schema.", choices=["ignore", "warn", "error"], - default=None) - -parser.add_argument("--path", nargs="*", - help="Additional PYTHONPATH directories to include.") +parser.add_argument( + "--validation-policy", + help="What to do when a validation " + "error occurs. This only overrides the top-level Record. It does not " + "override subrecords. Default: error.", + choices=["ignore", "warn", "error"], + default=None, +) + +parser.add_argument( + "--validation-policy-override", + help="Override validation " "policy for all levels of the schema.", + choices=["ignore", "warn", "error"], + default=None, +) + +parser.add_argument( + "--path", nargs="*", help="Additional PYTHONPATH directories to include." +) args = parser.parse_args() @@ -71,7 +87,7 @@ def main(): # Backwards compatibility: given "file.py:schema", load file.py. if ":" in schema: path, recname = schema.split(":") - load_source('module', path) + load_source("module", path) schema = recname if args.module: @@ -82,31 +98,39 @@ def main(): record.set("validation_policy", args.validation_policy) command = args.command if command == "bigquery": - print json.dumps(record.to_bigquery()) + print(json.dumps(record.to_bigquery())) elif command == "elasticsearch": - print json.dumps(record.to_es(recname)) + print(json.dumps(record.to_es(recname))) elif command == "proto": - print record.to_proto(recname) + print(record.to_proto(recname)) elif command == "docs-es": - print json.dumps(record.docs_es(recname)) + print(json.dumps(record.docs_es(recname))) elif command == "docs-bq": - print json.dumps(record.docs_bq(recname)) + print(json.dumps(record.docs_bq(recname))) elif command == "json": - print record.to_json() + print(record.to_json()) elif command == "flat": for r in record.to_flat(): - print json.dumps(r) + print(json.dumps(r)) elif command == "validate": if not os.path.exists(args.target): sys.stderr.write("Invalid test file. %s does not exist.\n" % args.target) sys.exit(1) with open(args.target) as fd: for line in fd: - record.validate(json.loads(line.strip()), - args.validation_policy_override) + record.validate( + json.loads(line.strip()), args.validation_policy_override + ) else: usage() +def load_source(name, path): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + if __name__ == "__main__": main() diff --git a/zschema/compounds.py b/zschema/compounds.py index 52bb3c0..aae0811 100644 --- a/zschema/compounds.py +++ b/zschema/compounds.py @@ -14,14 +14,17 @@ def _is_valid_object(name, object_): if not isinstance(object_, Keyable): raise Exception("Invalid schema. %s is not a Keyable." % name) + def _proto_message_name(string): if string != string.lower(): return string string = "".join(w.capitalize() for w in string.split("_")) return string + def _proto_indent(string, n): - return "\n".join(n*" " + s for s in string.split("\n")) + return "\n".join(n * " " + s for s in string.split("\n")) + # Track protobuf message definitions that have been emitted. _proto_messages = OrderedDict() @@ -43,18 +46,19 @@ def __init__(self, object_, max_items=_NO_ARG, min_items=_NO_ARG, *args, **kwarg def exclude_bigquery(self): # If the child type is excluded, that is the same as excluding this -- # it's not clear what it would mean otherwise, from a schema perspective - return super(ListOf, self).exclude_bigquery \ - or self.object_.exclude_bigquery + return super(ListOf, self).exclude_bigquery or self.object_.exclude_bigquery @property def exclude_elasticsearch(self): - return super(ListOf, self).exclude_elasticsearch \ - or self.object_.exclude_elasticsearch + return ( + super(ListOf, self).exclude_elasticsearch + or self.object_.exclude_elasticsearch + ) def print_indent_string(self, name, indent): tabs = "\t" * indent if indent else "" print('{}{}"{:s}"'.format(tabs, name, self.__class__.__name__)) - self.object_.print_indent_string(self.key_to_string(name), indent+1) + self.object_.print_indent_string(self.key_to_string(name), indent + 1) def to_bigquery(self, name): retv = self.object_.to_bigquery(name) @@ -88,7 +92,9 @@ def docs_es(self, parent_category=None): retv["doc"] = self.doc return retv - def validate(self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ARG): + def validate( + self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ARG + ): calculated_policy = self._calculate_policy(name, policy, parent_policy) if not path: path = [] @@ -97,12 +103,18 @@ def validate(self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ m = "%s: %s is not a list" % (name, str(value)) raise DataValidationException(m, path=path) if self.max_items > 0 and len(value) > self.max_items: - m = "%s: %s has too many values (max: %i)" % (name, str(value), - self.max_items) + m = "%s: %s has too many values (max: %i)" % ( + name, + str(value), + self.max_items, + ) raise DataValidationException(m, path=path) if self.min_items > 0 and len(value) < self.min_items: - m = "%s: %s has too few values (min: %i)" % (name, str(value), - self.min_items) + m = "%s: %s has too few values (min: %i)" % ( + name, + str(value), + self.min_items, + ) raise DataValidationException(m, path=path) except DataValidationException as e: self._handle_validation_exception(calculated_policy, e) @@ -110,27 +122,31 @@ def validate(self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ return for i, item in enumerate(value): try: - self.object_.validate(name, item, policy, calculated_policy, path=path + [i]) + self.object_.validate( + name, item, policy, calculated_policy, path=path + [i] + ) except DataValidationException as e: self._handle_validation_exception(calculated_policy, e) def to_dict(self): - return {"type":"list", "list_of":self.object_.to_json()} + return {"type": "list", "list_of": self.object_.to_json()} def to_flat(self, parent, name): for rec in self.object_.to_flat(parent, name, repeated=True): yield rec -def ListOfType(object_, - required=_NO_ARG, - max_items=_NO_ARG, - doc=_NO_ARG, - desc=_NO_ARG, - examples=_NO_ARG, - category=_NO_ARG, - validation_policy=_NO_ARG, - pr_ignore=_NO_ARG): +def ListOfType( + object_, + required=_NO_ARG, + max_items=_NO_ARG, + doc=_NO_ARG, + desc=_NO_ARG, + examples=_NO_ARG, + category=_NO_ARG, + validation_policy=_NO_ARG, + pr_ignore=_NO_ARG, +): _is_valid_object("Anonymous ListOf", object_) t = type("ListOf", (ListOf,), {}) t.set_default("object_", object_) @@ -151,9 +167,16 @@ class SubRecord(Keyable): TYPE_NAME = None ES_NESTED = False - def __init__(self, definition=_NO_ARG, extends=_NO_ARG, - allow_unknown=_NO_ARG, type_name=_NO_ARG, es_nested=_NO_ARG, - *args, **kwargs): + def __init__( + self, + definition=_NO_ARG, + extends=_NO_ARG, + allow_unknown=_NO_ARG, + type_name=_NO_ARG, + es_nested=_NO_ARG, + *args, + **kwargs + ): super(SubRecord, self).__init__(*args, **kwargs) self.set("definition", definition) self.set("allow_unknown", allow_unknown) @@ -183,8 +206,10 @@ def new(self, **kwargs): # "ca": Certificate.new(doc="The CA certificate."), # "host": Certificate.new(doc="The host certificate.", required=True) # }) - e = "WARN: .new() is deprecated and will be removed in a "\ - "future release. Schemas should use SubRecordTypes.\n" + e = ( + "WARN: .new() is deprecated and will be removed in a " + "future release. Schemas should use SubRecordTypes.\n" + ) sys.stderr.write(e) return SubRecord({}, extends=self, **kwargs) @@ -195,8 +220,10 @@ def to_flat(self, parent, name, repeated=False): mode = "required" else: mode = "nullable" - this_name = ".".join([parent, self.key_to_es(name)]) if parent else self.key_to_es(name) - yield {"type":self.__class__.__name__, "name":this_name, "mode":mode} + this_name = ( + ".".join([parent, self.key_to_es(name)]) if parent else self.key_to_es(name) + ) + yield {"type": self.__class__.__name__, "name": this_name, "mode": mode} for subname, doc in sorted(self.definition.items()): for item in doc.to_flat(this_name, self.key_to_es(subname)): yield item @@ -206,7 +233,7 @@ def merge(self, other): newdef = {} l_keys = set(self.definition.keys()) r_keys = set(other.definition.keys()) - for key in (l_keys | r_keys): + for key in l_keys | r_keys: l_value = self.definition.get(key, None) r_value = other.definition.get(key, None) if not l_value: @@ -227,23 +254,26 @@ def merge(self, other): return self def to_bigquery(self, name): - fields = [v.to_bigquery(k) \ - for (k,v) in sorted(self.definition.items()) \ - if not v.exclude_bigquery - ] + fields = [ + v.to_bigquery(k) + for (k, v) in sorted(self.definition.items()) + if not v.exclude_bigquery + ] retv = { - "name":self.key_to_bq(name), - "type":"RECORD", - "fields":fields, - "mode":"REQUIRED" if self.required else "NULLABLE" + "name": self.key_to_bq(name), + "type": "RECORD", + "fields": fields, + "mode": "REQUIRED" if self.required else "NULLABLE", } return retv def to_proto(self, name, indent): - if self.type_name is not None: # named message type -- produced at top level, once + if ( + self.type_name is not None + ): # named message type -- produced at top level, once message_type = _proto_message_name(self.type_name) anon = False - else: # anonymous message type -- nests within containing message + else: # anonymous message type -- nests within containing message message_type = _proto_message_name(self.key_to_proto(name)) + "Struct" anon = True @@ -264,7 +294,7 @@ def to_proto(self, name, indent): n = 0 proto = [] retvs = explicits - for (v, i) in retvs: + for v, i in retvs: if v["message"]: proto += [v["message"]] if i is not None: @@ -272,22 +302,26 @@ def to_proto(self, name, indent): else: n += 1 proto += ["%s = %d;" % (v["field"], n)] - proto_def = "message %s {\n%s\n}" % \ - (message_type, _proto_indent("\n".join(proto), indent+1)) + proto_def = "message %s {\n%s\n}" % ( + message_type, + _proto_indent("\n".join(proto), indent + 1), + ) if not anon: _proto_messages[message_type] = proto_def proto_def = "" return { "message": proto_def, - "field": "%s %s" % (message_type, self.key_to_proto(name)) + "field": "%s %s" % (message_type, self.key_to_proto(name)), } def docs_bq(self, parent_category=None): category = self.category or parent_category retv = self._docs_common(category) - fields = { self.key_to_bq(k): v.docs_bq(parent_category=category) \ - for (k,v) in sorted(self.definition.items()) \ - if not v.exclude_bigquery } + fields = { + self.key_to_bq(k): v.docs_bq(parent_category=category) + for (k, v) in sorted(self.definition.items()) + if not v.exclude_bigquery + } retv["fields"] = fields return retv @@ -295,12 +329,14 @@ def print_indent_string(self, name, indent): tabs = "\t" * indent if indent else "" print("{}{:s}:subrecord:".format(tabs, self.key_to_string(name))) for name, value in sorted(self.definition.items()): - value.print_indent_string(name, indent+1) + value.print_indent_string(name, indent + 1) def to_es(self): - p = {self.key_to_es(k): v.to_es() \ - for k, v in sorted(self.definition.items()) \ - if not v.exclude_elasticsearch} + p = { + self.key_to_es(k): v.to_es() + for k, v in sorted(self.definition.items()) + if not v.exclude_elasticsearch + } retv = {"properties": p} if self.es_nested: retv["type"] = "nested" @@ -318,18 +354,26 @@ def _docs_common(self, category): def docs_es(self, parent_category=None): category = self.category or parent_category retv = self._docs_common(category) - retv["fields"] = { self.key_to_es(k): v.docs_es(parent_category=category) \ - for k, v in sorted(self.definition.items()) \ - if not v.exclude_elasticsearch } + retv["fields"] = { + self.key_to_es(k): v.docs_es(parent_category=category) + for k, v in sorted(self.definition.items()) + if not v.exclude_elasticsearch + } return retv def to_dict(self): source = sorted(self.definition.items()) p = {self.key_to_es(k): v.to_dict() for k, v in source} - return {"type":"subrecord", "subfields": p, "doc":self.doc, "required":self.required} - + return { + "type": "subrecord", + "subfields": p, + "doc": self.doc, + "required": self.required, + } - def validate(self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ARG): + def validate( + self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ARG + ): calculated_policy = self._calculate_policy(name, policy, parent_policy) if not path: path = [] @@ -345,11 +389,17 @@ def validate(self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ for subkey, subvalue in sorted(value.items()): try: if not self.allow_unknown and subkey not in self.definition: - raise DataValidationException("%s: %s is not a valid subkey" % - (name, subkey), path=path) + raise DataValidationException( + "%s: %s is not a valid subkey" % (name, subkey), path=path + ) if subkey in self.definition: - self.definition[subkey].validate(subkey, subvalue, - policy, calculated_policy, path=path + [subkey]) + self.definition[subkey].validate( + subkey, + subvalue, + policy, + calculated_policy, + path=path + [subkey], + ) except DataValidationException as e: self._handle_validation_exception(calculated_policy, e) @@ -363,12 +413,18 @@ def _get_copy_default(cls, k): v = cls._INIT_DEFAULTS.get(k, _NO_ARG) return copy.deepcopy(v) - def __init__(self, definition=_NO_ARG, extends=_NO_ARG, - allow_unknown=_NO_ARG, type_name=_NO_ARG, *args, **kwargs): - definition = definition or self._get_copy_default('definition') - allow_unknown = allow_unknown or self._get_copy_default( - 'allow_unknown') - type_name = type_name or self._get_copy_default('type_name') + def __init__( + self, + definition=_NO_ARG, + extends=_NO_ARG, + allow_unknown=_NO_ARG, + type_name=_NO_ARG, + *args, + **kwargs + ): + definition = definition or self._get_copy_default("definition") + allow_unknown = allow_unknown or self._get_copy_default("allow_unknown") + type_name = type_name or self._get_copy_default("type_name") for k, v in self._INIT_DEFAULTS.items(): if k in {"definition", "allow_unknown", "type_name"}: # These keys are managed by the constructor @@ -389,21 +445,27 @@ def _set_default_at_init(cls, k, v): cls.set_default(k, v) -def SubRecordType(definition, - required=_NO_ARG, - type_name=_NO_ARG, - doc=_NO_ARG, - desc=_NO_ARG, - allow_unknown=_NO_ARG, - exclude=_NO_ARG, - category=_NO_ARG, - validation_policy=_NO_ARG, - pr_ignore=_NO_ARG): - #import pdb; pdb.set_trace() +def SubRecordType( + definition, + required=_NO_ARG, + type_name=_NO_ARG, + doc=_NO_ARG, + desc=_NO_ARG, + allow_unknown=_NO_ARG, + exclude=_NO_ARG, + category=_NO_ARG, + validation_policy=_NO_ARG, + pr_ignore=_NO_ARG, +): + # import pdb; pdb.set_trace() name = type_name if type_name else "SubRecordType" - t = type(name, (_SubRecordDefaulted,), { - "_INIT_DEFAULTS": dict(), - }) + t = type( + name, + (_SubRecordDefaulted,), + { + "_INIT_DEFAULTS": dict(), + }, + ) t._set_default_at_init("definition", definition) t._set_default_at_init("type_name", type_name) t._set_default_at_init("required", required) @@ -419,15 +481,23 @@ def SubRecordType(definition, class NestedListOf(ListOf): - def __init__(self, object_, subrecord_name, max_items=10, doc=None, category=None, *args, **kwargs): - super(NestedListOf, self).__init__(object_, max_items=max_items, - doc=doc, category=category, *args, **kwargs) + def __init__( + self, + object_, + subrecord_name, + max_items=10, + doc=None, + category=None, + *args, + **kwargs + ): + super(NestedListOf, self).__init__( + object_, max_items=max_items, doc=doc, category=category, *args, **kwargs + ) self.set("subrecord_name", subrecord_name) def to_bigquery(self, name): - subr = SubRecord({ - self.subrecord_name:ListOf(self.object_) - }) + subr = SubRecord({self.subrecord_name: ListOf(self.object_)}) retv = subr.to_bigquery(self.key_to_bq(name)) retv["mode"] = "REPEATED" if self.doc: @@ -435,9 +505,7 @@ def to_bigquery(self, name): return retv def docs_bq(self, parent_category=None): - subr = SubRecord({ - self.subrecord_name: ListOf(self.object_) - }) + subr = SubRecord({self.subrecord_name: ListOf(self.object_)}) category = self.category or parent_category retv = subr.docs_bq(parent_category=category) retv["repeated"] = True @@ -455,7 +523,7 @@ def to_es(self, name): subrecord = SubRecord.to_es(self) if self.es_dynamic_policy != None: subrecord["dynamic"] = self.es_dynamic_policy - return {name:subrecord} + return {name: subrecord} def docs_es(self, name, parent_category=None): category = self.category or parent_category @@ -463,10 +531,7 @@ def docs_es(self, name, parent_category=None): def to_bigquery(self): source = sorted(self.definition.items()) - return [s.to_bigquery(name) \ - for (name, s) in source \ - if not s.exclude_bigquery - ] + return [s.to_bigquery(name) for (name, s) in source if not s.exclude_bigquery] def to_proto(self, name): self.type_name = name @@ -476,7 +541,9 @@ def to_proto(self, name): import "google/protobuf/timestamp.proto"; -""" + "\n".join(_proto_messages.values()) +""" + "\n".join( + _proto_messages.values() + ) def docs_bq(self, name, parent_category=None): category = self.category or parent_category @@ -491,10 +558,14 @@ def validate(self, value, policy=_NO_ARG, path=_NO_ARG): policy = _NO_ARG if not path: path = [] - calculated_policy = self._calculate_policy("root", policy, self.validation_policy) + calculated_policy = self._calculate_policy( + "root", policy, self.validation_policy + ) # ^ note: record explicitly does not take a parent_policy if not isinstance(value, dict): - raise DataValidationException("record is not a dict:\n{}".format(value), path=path) + raise DataValidationException( + "record is not a dict:\n{}".format(value), path=path + ) for subkey, subvalue in sorted(value.items()): try: if subkey not in self.definition: @@ -525,4 +596,3 @@ def to_flat(self): @classmethod def from_json(cls, j): return cls({(k, __encode(v)) for k, v in sorted(j.items())}) - diff --git a/zschema/example.py b/zschema/example.py index c8e7cde..7ac8f21 100644 --- a/zschema/example.py +++ b/zschema/example.py @@ -3,18 +3,19 @@ from zschema.leaves import Boolean, DateTime, IPv4Address, String, Unsigned32BitInteger -heartbleed = SubRecord({ - "heartbeat_support":Boolean(), - "heartbleed_vulnerable":Boolean(), - "timestamp":DateTime() -}) +heartbleed = SubRecord( + { + "heartbeat_support": Boolean(), + "heartbleed_vulnerable": Boolean(), + "timestamp": DateTime(), + } +) -host = Record({ - "ipstr":IPv4Address(required=True), - "ip":Unsigned32BitInteger(), - Port(443):SubRecord({ - "tls":String(), - "heartbleed":heartbleed - }), - "tags":ListOf(String()) -}) +host = Record( + { + "ipstr": IPv4Address(required=True), + "ip": Unsigned32BitInteger(), + Port(443): SubRecord({"tls": String(), "heartbleed": heartbleed}), + "tags": ListOf(String()), + } +) diff --git a/zschema/keys.py b/zschema/keys.py index f608b2b..e06a65b 100644 --- a/zschema/keys.py +++ b/zschema/keys.py @@ -3,9 +3,11 @@ from six import string_types import logging +import sys _keyable_counter = 0 + class _NO_ARG(object): __nonzero__ = lambda _: False __bool__ = lambda _: False @@ -17,6 +19,7 @@ def __new__(cls): cls._instance = retv return retv + _NO_ARG = _NO_ARG() @@ -116,7 +119,9 @@ def __call__(self, *args, **kwargs): values for any keyword arguments. """ if args and self.args: - raise Exception("Positional arguments already bound during TypeFactory creation.") + raise Exception( + "Positional arguments already bound during TypeFactory creation." + ) if self.args and not args: args = self.args return self.cls(*args, **TypeFactoryFactory._left_merge(kwargs, self.kwargs)) @@ -125,19 +130,19 @@ def __call__(self, *args, **kwargs): class Keyable(object): VALID_ES_INDEXES = [ - "analyzed", # full-text - "not_analyzed", # searchable, not full-text, - "no", # field is not searchable + "analyzed", # full-text + "not_analyzed", # searchable, not full-text, + "no", # field is not searchable ] VALID_ES_ANALYZERS = [ - "standard", # The standard analyzer is the default analyzer that Elasticsearch uses. - # It is the best general choice for analyzing text that may be in any language. - # It splits the text on word boundaries, as defined by the Unicode Consortium, - # and removes most punctuation. - "simple", # The simple analyzer splits the text on anything that isn't a letter, - # and lowercases the terms. It would produce - "whitespace", # The whitespace analyzer splits the text on whitespace. It doesn't lowercase. + "standard", # The standard analyzer is the default analyzer that Elasticsearch uses. + # It is the best general choice for analyzing text that may be in any language. + # It splits the text on word boundaries, as defined by the Unicode Consortium, + # and removes most punctuation. + "simple", # The simple analyzer splits the text on anything that isn't a letter, + # and lowercases the terms. It would produce + "whitespace", # The whitespace analyzer splits the text on whitespace. It doesn't lowercase. ] # defaults @@ -215,7 +220,9 @@ def _handle_validation_exception(policy, e): elif policy == "ignore": pass else: - raise Exception("Invalid validation policy. Must be one of: error, warn, ignore") + raise Exception( + "Invalid validation policy. Must be one of: error, warn, ignore" + ) @staticmethod def _validate_policy(name, policy): @@ -281,6 +288,7 @@ def with_args(cls, *args, **kwargs): def _populate_types_by_name(cls): if cls._types_by_name: return + def __iter_classes(kls): try: for klass in kls.__subclasses__(): @@ -294,14 +302,26 @@ def __iter_classes(kls): except: pass yield kls + for klass in __iter_classes(Keyable): Keyable._types_by_name[klass.__name__] = klass - - def __init__(self, required=_NO_ARG, desc=_NO_ARG, doc=_NO_ARG, category=_NO_ARG, - exclude=_NO_ARG, deprecated=_NO_ARG, ignore=_NO_ARG, - examples=_NO_ARG, metadata=_NO_ARG, validation_policy=_NO_ARG, pr_index=_NO_ARG, - pr_ignore=_NO_ARG, es_dynamic_policy=_NO_ARG): + def __init__( + self, + required=_NO_ARG, + desc=_NO_ARG, + doc=_NO_ARG, + category=_NO_ARG, + exclude=_NO_ARG, + deprecated=_NO_ARG, + ignore=_NO_ARG, + examples=_NO_ARG, + metadata=_NO_ARG, + validation_policy=_NO_ARG, + pr_index=_NO_ARG, + pr_ignore=_NO_ARG, + es_dynamic_policy=_NO_ARG, + ): global _keyable_counter self.set("required", required) self.set("desc", desc) @@ -320,16 +340,18 @@ def __init__(self, required=_NO_ARG, desc=_NO_ARG, doc=_NO_ARG, category=_NO_ARG self.set("es_dynamic_policy", es_dynamic_policy) if self.DEPRECATED_TYPE: - e = "WARN: %s is deprecated and will be removed in a "\ - "future release\n" % self.__class__.__name__ + e = ( + "WARN: %s is deprecated and will be removed in a " + "future release\n" % self.__class__.__name__ + ) sys.stderr.write(e) def to_dict(self): retv = { - "required":self.required, - "doc":self.doc, - "type":self.__class__.__name__, - "metadata":self.metadata, + "required": self.required, + "doc": self.doc, + "type": self.__class__.__name__, + "metadata": self.metadata, "examples": self.examples, } return retv diff --git a/zschema/leaves.py b/zschema/leaves.py index 817600e..6889e68 100644 --- a/zschema/leaves.py +++ b/zschema/leaves.py @@ -22,37 +22,41 @@ class Leaf(Keyable): ES_INDEX = None ES_ANALYZER = None - def __init__(self, - required=_NO_ARG, - es_index=_NO_ARG, - es_analyzer=_NO_ARG, - desc=_NO_ARG, - doc=_NO_ARG, - examples=_NO_ARG, - es_include_raw=_NO_ARG, - deprecated=_NO_ARG, - ignore=_NO_ARG, - category=_NO_ARG, - exclude=_NO_ARG, - metadata=_NO_ARG, - units=_NO_ARG, - min_value=_NO_ARG, - max_value=_NO_ARG, - validation_policy=_NO_ARG, - pr_index=_NO_ARG, - pr_ignore=_NO_ARG): - Keyable.__init__(self, - required=required, - desc=desc, - doc=doc, - category=category, - exclude=exclude, - deprecated=deprecated, - ignore=ignore, - examples=examples, - validation_policy=validation_policy, - pr_index=pr_index, - pr_ignore=pr_ignore) + def __init__( + self, + required=_NO_ARG, + es_index=_NO_ARG, + es_analyzer=_NO_ARG, + desc=_NO_ARG, + doc=_NO_ARG, + examples=_NO_ARG, + es_include_raw=_NO_ARG, + deprecated=_NO_ARG, + ignore=_NO_ARG, + category=_NO_ARG, + exclude=_NO_ARG, + metadata=_NO_ARG, + units=_NO_ARG, + min_value=_NO_ARG, + max_value=_NO_ARG, + validation_policy=_NO_ARG, + pr_index=_NO_ARG, + pr_ignore=_NO_ARG, + ): + Keyable.__init__( + self, + required=required, + desc=desc, + doc=doc, + category=category, + exclude=exclude, + deprecated=deprecated, + ignore=ignore, + examples=examples, + validation_policy=validation_policy, + pr_index=pr_index, + pr_ignore=pr_ignore, + ) self.set("es_index", es_index) self.set("es_analyzer", es_analyzer) self.set("units", units) @@ -72,14 +76,12 @@ def to_dict(self): return retv def to_es(self): - retv = {"type":self.ES_TYPE} + retv = {"type": self.ES_TYPE} self.add_not_empty(retv, "index", "es_index") self.add_not_empty(retv, "analyzer", "es_analyzer") self.add_not_empty(retv, "search_analyzer", "es_search_analyzer") if self.es_include_raw: - retv["fields"] = { - "raw":{"type":"keyword"} - } + retv["fields"] = {"raw": {"type": "keyword"}} return retv def _docs_common(self, parent_category): @@ -108,7 +110,7 @@ def to_bigquery(self, name): if not self._check_valid_name(name): raise Exception("Invalid field name: %s" % name) mode = "REQUIRED" if self.required else "NULLABLE" - retv = {"name":self.key_to_bq(name), "type":self.BQ_TYPE, "mode":mode} + retv = {"name": self.key_to_bq(name), "type": self.BQ_TYPE, "mode": mode} if self.doc: retv["doc"] = self.doc return retv @@ -118,12 +120,11 @@ def to_proto(self, name, indent): raise Exception("Invalid field name: %s" % name) return { "message": "", - "field": "%s %s" % (self.PR_TYPE, self.key_to_proto(name)) + "field": "%s %s" % (self.PR_TYPE, self.key_to_proto(name)), } def to_string(self, name): - return "%s: %s" % (self.key_to_string(name), - self.__class__.__name__.lower()) + return "%s: %s" % (self.key_to_string(name), self.__class__.__name__.lower()) def to_flat(self, parent, name, repeated=False): if not self._check_valid_name(name): @@ -136,19 +137,19 @@ def to_flat(self, parent, name, repeated=False): mode = "nullable" full_name = ".".join([parent, name]) if parent else name yield { - "name":full_name, - "type":self.__class__.__name__, + "name": full_name, + "type": self.__class__.__name__, "es_type": self.ES_TYPE, - "documentation":self.doc, - "mode":mode + "documentation": self.doc, + "mode": mode, } if self.es_include_raw: yield { - "name":full_name + ".raw", - "type":self.__class__.__name__, - "documentation":self.doc, + "name": full_name + ".raw", + "type": self.__class__.__name__, + "documentation": self.doc, "es_type": self.ES_TYPE, - "mode":mode + "mode": mode, } def print_indent_string(self, name, indent): @@ -158,7 +159,9 @@ def print_indent_string(self, name, indent): val = tabs + val print(val) - def validate(self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ARG): + def validate( + self, name, value, policy=_NO_ARG, parent_policy=_NO_ARG, path=_NO_ARG + ): calculated_policy = self._calculate_policy(name, policy, parent_policy) try: self._raising_validate(name, value, path=path) @@ -172,8 +175,7 @@ def _raising_validate(self, name, value, path=_NO_ARG): raise Exception("Invalid field name: %s" % name) if value is None: if self.required: - msg = "{:s} is a required field, but received None".format( - name) + msg = "{:s} is a required field, but received None".format(name) raise DataValidationException(msg, path=path) else: return @@ -241,7 +243,8 @@ class WhitespaceAnalyzedString(AnalyzedString): } }' """ - ES_ANALYZER="lower_whitespace" + + ES_ANALYZER = "lower_whitespace" ES_INCLUDE_RAW = True @@ -256,8 +259,8 @@ class HexString(Leaf): INVALID = "asdfasdfa" VALID = "003a929e3e0bd48a1e7567714a1e0e9d4597fe9087b4ad39deb83ab10c5a0278" - #ES_SEARCH_ANALYZER = "lower_whitespace" - HEX_REGEX = re.compile('(?:[A-Fa-f0-9][A-Fa-f0-9])+') + # ES_SEARCH_ANALYZER = "lower_whitespace" + HEX_REGEX = re.compile("(?:[A-Fa-f0-9][A-Fa-f0-9])+") def _is_hex(self, s): return bool(self.HEX_REGEX.match(s)) @@ -298,8 +301,8 @@ def _docs_common(self, parent_category): del retv["examples"] return retv -class HTML(AnalyzedString): +class HTML(AnalyzedString): """ curl -XPUT 'localhost:9200/ipv4/_settings' -d '{ "analysis" : { @@ -313,6 +316,7 @@ class HTML(AnalyzedString): } }' """ + ES_ANALYZER = "html" @@ -379,13 +383,17 @@ class _Integer(Leaf): def _validate(self, name, value, path=_NO_ARG): max_ = 2**self.BITS - 1 - min_ = -2**self.BITS + 1 + min_ = -(2**self.BITS) + 1 if value > max_: - raise DataValidationException("%s: %s is larger than max (%s)" % (\ - name, str(value), str(max_)), path=path) + raise DataValidationException( + "%s: %s is larger than max (%s)" % (name, str(value), str(max_)), + path=path, + ) if value < min_: - raise DataValidationException("%s: %s is smaller than min (%s)" % (\ - name, str(value), str(min_)), path=path) + raise DataValidationException( + "%s: %s is smaller than min (%s)" % (name, str(value), str(min_)), + path=path, + ) class Signed32BitInteger(_Integer): @@ -404,7 +412,7 @@ class Signed8BitInteger(_Integer): PR_TYPE = "int32" BITS = 8 - INVALID = 2**8+5 + INVALID = 2**8 + 5 VALID = 34 @@ -436,7 +444,7 @@ class Signed64BitInteger(_Integer): PR_TYPE = "int64" EXPECTED_CLASS = (int,) - INVALID = int(2)**68 + INVALID = int(2) ** 68 VALID = int(10) BITS = 64 @@ -485,7 +493,9 @@ class Binary(Leaf): ES_INDEX = "no" EXPECTED_CLASS = string_types - B64_REGEX = re.compile('^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$') + B64_REGEX = re.compile( + "^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$" + ) def _is_base64(self, data): return bool(self.B64_REGEX.match(data)) @@ -519,8 +529,8 @@ class DateTime(Leaf): TZINFOS = { "EDT": datetime.timezone(datetime.timedelta(hours=-4)), # Eastern Daylight Time "EST": datetime.timezone(datetime.timedelta(hours=-5)), # Eastern Standard Time - 'CDT': datetime.timezone(datetime.timedelta(hours=-5)), # Central Daylight Time - 'CST': datetime.timezone(datetime.timedelta(hours=-6)), # Central Standard Time + "CDT": datetime.timezone(datetime.timedelta(hours=-5)), # Central Daylight Time + "CST": datetime.timezone(datetime.timedelta(hours=-6)), # Central Standard Time } VALID = "Wed Jul 8 08:52:01 EDT 2015" @@ -533,14 +543,22 @@ def __init__(self, *args, **kwargs): super(DateTime, self).__init__(*args, **kwargs) if self.min_value: - self._min_value_dt = dateutil.parser.parse(self.min_value, tzinfos=self.TZINFOS) + self._min_value_dt = dateutil.parser.parse( + self.min_value, tzinfos=self.TZINFOS + ) else: - self._min_value_dt = dateutil.parser.parse(self.MIN_VALUE, tzinfos=self.TZINFOS) + self._min_value_dt = dateutil.parser.parse( + self.MIN_VALUE, tzinfos=self.TZINFOS + ) if self.max_value: - self._max_value_dt = dateutil.parser.parse(self.max_value, tzinfos=self.TZINFOS) + self._max_value_dt = dateutil.parser.parse( + self.max_value, tzinfos=self.TZINFOS + ) else: - self._max_value_dt = dateutil.parser.parse(self.MAX_VALUE, tzinfos=self.TZINFOS) + self._max_value_dt = dateutil.parser.parse( + self.MAX_VALUE, tzinfos=self.TZINFOS + ) def _validate(self, name, value, path=_NO_ARG): try: @@ -557,18 +575,25 @@ def _validate(self, name, value, path=_NO_ARG): raise DataValidationException(m, path=path) dt = DateTime._ensure_tz_aware(dt) if dt > self._max_value_dt: - m = "%s: %s is greater than allowed maximum (%s)" % (name, - str(value), str(self._max_value_dt)) + m = "%s: %s is greater than allowed maximum (%s)" % ( + name, + str(value), + str(self._max_value_dt), + ) raise DataValidationException(m, path=path) if dt < self._min_value_dt: - m = "%s: %s is less than allowed minimum (%s)" % (name, - str(value), str(self._min_value_dt)) + m = "%s: %s is less than allowed minimum (%s)" % ( + name, + str(value), + str(self._min_value_dt), + ) raise DataValidationException(m, path=path) @staticmethod def _ensure_tz_aware(dt): """Ensures that the given datetime is timezone-aware. If it is not timezone-aware as - given, this function localizes it to UTC. Returns a timezone-aware datetime instance.""" + given, this function localizes it to UTC. Returns a timezone-aware datetime instance. + """ if dt.tzinfo: return dt return pytz.utc.localize(dt) @@ -665,4 +690,3 @@ class URI(URL): URI, EmailAddress, ] - diff --git a/zschema/registry.py b/zschema/registry.py index c50fcd8..4d44fed 100644 --- a/zschema/registry.py +++ b/zschema/registry.py @@ -8,14 +8,17 @@ except NameError: __zschema_schemas = {} + def register_schema(name, schema): global __zschema_schemas __zschema_schemas[name] = schema + def get_schema(name): global __zschema_schemas return __zschema_schemas[name] + def all_schemas(): global __zschema_schemas return __zschema_schemas.copy() @@ -25,4 +28,5 @@ def __register(self, name): register_schema(name, self) return self + Record.register = __register diff --git a/zschema/test_zschema.py b/zschema/test_zschema.py index 1f930bc..c15c1e2 100644 --- a/zschema/test_zschema.py +++ b/zschema/test_zschema.py @@ -7,13 +7,24 @@ from zschema import registry from zschema.compounds import ListOf, NestedListOf, Record, SubRecord, SubRecordType from zschema.keys import Keyable, Port, MergeConflictException -from zschema.leaves import Boolean, DateTime, Enum, IPv4Address, String, Unsigned8BitInteger, Unsigned32BitInteger, VALID_LEAVES +from zschema.leaves import ( + Boolean, + DateTime, + Enum, + IPv4Address, + String, + Unsigned8BitInteger, + Unsigned32BitInteger, + VALID_LEAVES, +) from zschema.leaves import DataValidationException def json_fixture(name): filename = name + ".json" - fixture_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fixtures', filename) + fixture_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "fixtures", filename + ) with open(fixture_path) as fixture_file: fixture = json.load(fixture_file) return fixture @@ -29,8 +40,7 @@ def test_invalid(self): for leaf in VALID_LEAVES: try: leaf(validation_policy="error").validate(leaf.__name__, leaf.INVALID) - raise Exception("invalid value did not fail for %s", - leaf.__name__) + raise Exception("invalid value did not fail for %s", leaf.__name__) except DataValidationException: continue @@ -60,33 +70,19 @@ def test_docs_bq(self): "properties": { "443": { "properties": { - "tls": { - "type": "keyword" - }, + "tls": {"type": "keyword"}, "heartbleed": { "properties": { - "heartbeat_support": { - "type": "boolean" - }, - "heartbleed_vulnerable": { - "type": "boolean" - }, - "timestamp": { - "type": "date" - } + "heartbeat_support": {"type": "boolean"}, + "heartbleed_vulnerable": {"type": "boolean"}, + "timestamp": {"type": "date"}, } - } + }, } }, - "ipstr": { - "type": "ip" - }, - "ip": { - "type": "long" - }, - "tags": { - "type": "keyword" - } + "ipstr": {"type": "ip"}, + "ip": {"type": "long"}, + "tags": {"type": "keyword"}, } } } @@ -110,7 +106,7 @@ def test_docs_bq(self): "doc": None, "examples": [], "required": False, - "type": "boolean" + "type": "boolean", }, "heartbleed_vulnerable": { "category": "Vulnerabilities", @@ -118,7 +114,7 @@ def test_docs_bq(self): "doc": None, "examples": [], "required": False, - "type": "boolean" + "type": "boolean", }, "timestamp": { "category": "heartbleed", @@ -126,11 +122,11 @@ def test_docs_bq(self): "doc": None, "examples": [], "required": False, - "type": "date" - } + "type": "date", + }, }, "required": False, - "type": "SubRecord" + "type": "SubRecord", }, "tls": { "category": "heartbleed", @@ -138,11 +134,11 @@ def test_docs_bq(self): "doc": None, "examples": [], "required": False, - "type": "keyword" - } + "type": "keyword", + }, }, "required": False, - "type": "SubRecord" + "type": "SubRecord", }, "ip": { "category": None, @@ -150,17 +146,15 @@ def test_docs_bq(self): "doc": "The IP Address of the host", "examples": [], "required": False, - "type": "long" + "type": "long", }, "ipstr": { "category": None, "detail_type": "IPv4Address", "doc": None, - "examples": [ - "8.8.8.8" - ], + "examples": ["8.8.8.8"], "required": True, - "type": "ip" + "type": "ip", }, "tags": { "category": None, @@ -169,11 +163,11 @@ def test_docs_bq(self): "examples": [], "repeated": True, "required": False, - "type": "keyword" - } + "type": "keyword", + }, }, "required": False, - "type": "Record" + "type": "Record", } } @@ -188,17 +182,15 @@ def test_docs_bq(self): "doc": "The IP Address of the host", "examples": [], "required": False, - "type": "INTEGER" + "type": "INTEGER", }, "ipstr": { "category": None, "detail_type": "IPv4Address", "doc": None, - "examples": [ - "8.8.8.8" - ], + "examples": ["8.8.8.8"], "required": True, - "type": "STRING" + "type": "STRING", }, "p443": { "category": "heartbleed", @@ -214,7 +206,7 @@ def test_docs_bq(self): "doc": None, "examples": [], "required": False, - "type": "BOOLEAN" + "type": "BOOLEAN", }, "heartbleed_vulnerable": { "category": "Vulnerabilities", @@ -222,7 +214,7 @@ def test_docs_bq(self): "doc": None, "examples": [], "required": False, - "type": "BOOLEAN" + "type": "BOOLEAN", }, "timestamp": { "category": "heartbleed", @@ -230,11 +222,11 @@ def test_docs_bq(self): "doc": None, "examples": [], "required": False, - "type": "DATETIME" - } + "type": "DATETIME", + }, }, "required": False, - "type": "SubRecord" + "type": "SubRecord", }, "tls": { "category": "heartbleed", @@ -242,11 +234,11 @@ def test_docs_bq(self): "doc": None, "examples": [], "required": False, - "type": "STRING" - } + "type": "STRING", + }, }, "required": False, - "type": "SubRecord" + "type": "SubRecord", }, "tags": { "category": None, @@ -255,64 +247,48 @@ def test_docs_bq(self): "examples": [], "repeated": True, "required": False, - "type": "STRING" - } + "type": "STRING", + }, }, "required": False, - "type": "Record" + "type": "Record", } } VALID_BIG_QUERY = [ { "fields": [ - { - "type": "STRING", - "name": "tls", - "mode": "NULLABLE" - }, + {"type": "STRING", "name": "tls", "mode": "NULLABLE"}, { "fields": [ { "type": "BOOLEAN", "name": "heartbeat_support", - "mode": "NULLABLE" + "mode": "NULLABLE", }, { "type": "BOOLEAN", "name": "heartbleed_vulnerable", - "mode": "NULLABLE" + "mode": "NULLABLE", }, - { - "type": "DATETIME", - "name": "timestamp", - "mode": "NULLABLE" - } + {"type": "DATETIME", "name": "timestamp", "mode": "NULLABLE"}, ], "type": "RECORD", "name": "heartbleed", - "mode": "NULLABLE" - } + "mode": "NULLABLE", + }, ], "type": "RECORD", "name": "p443", - "mode": "NULLABLE" - }, - { - "type": "STRING", - "name": "ipstr", - "mode": "REQUIRED" - }, - { - "type": "STRING", - "name": "tags", - "mode": "REPEATED" + "mode": "NULLABLE", }, + {"type": "STRING", "name": "ipstr", "mode": "REQUIRED"}, + {"type": "STRING", "name": "tags", "mode": "REPEATED"}, { "type": "INTEGER", "name": "ip", "doc": "The IP Address of the host", - "mode": "NULLABLE" + "mode": "NULLABLE", }, ] @@ -353,12 +329,11 @@ def assertBigQuerySchemaEqual(self, a, b): return else: self.assertEqual(type(a), type(b)) - if isinstance(a, Sized) \ - and isinstance(b, Sized): + if isinstance(a, Sized) and isinstance(b, Sized): self.assertEqual(len(a), len(b)) if isinstance(a, list) and isinstance(b, list): - name_ordered_a = sorted(a, key=lambda x: x['name']) - name_ordered_b = sorted(b, key=lambda x: x['name']) + name_ordered_a = sorted(a, key=lambda x: x["name"]) + name_ordered_b = sorted(b, key=lambda x: x["name"]) for x, y in zip(name_ordered_a, name_ordered_b): self.assertBigQuerySchemaEqual(x, y) elif isinstance(a, dict): @@ -369,22 +344,32 @@ def assertBigQuerySchemaEqual(self, a, b): self.assertEqual(a, b) def setUp(self): - self.maxDiff=10000 - - heartbleed = SubRecord({ # with explicit proto field indices - "heartbeat_support":Boolean(pr_index=11), - "heartbleed_vulnerable":Boolean(category="Vulnerabilities", pr_ignore=True), - "timestamp":DateTime(pr_index=10) - }, pr_index=77) - self.host = Record({ - "ipstr":IPv4Address(required=True, examples=["8.8.8.8"], pr_index=1), - "ip":Unsigned32BitInteger(doc="The IP Address of the host", pr_index=2), - Port(443):SubRecord({ - "tls":String(pr_index=1), - "heartbleed":heartbleed - }, category="heartbleed", pr_index=3), - "tags":ListOf(String(), pr_index=47) - }) + self.maxDiff = 10000 + + heartbleed = SubRecord( + { # with explicit proto field indices + "heartbeat_support": Boolean(pr_index=11), + "heartbleed_vulnerable": Boolean( + category="Vulnerabilities", pr_ignore=True + ), + "timestamp": DateTime(pr_index=10), + }, + pr_index=77, + ) + self.host = Record( + { + "ipstr": IPv4Address(required=True, examples=["8.8.8.8"], pr_index=1), + "ip": Unsigned32BitInteger( + doc="The IP Address of the host", pr_index=2 + ), + Port(443): SubRecord( + {"tls": String(pr_index=1), "heartbleed": heartbleed}, + category="heartbleed", + pr_index=3, + ), + "tags": ListOf(String(), pr_index=47), + } + ) def test_bigquery(self): global VALID_BIG_QUERY @@ -411,23 +396,11 @@ def test_docs_output(self): self.assertEqual(r, VALID_DOCS_OUTPUT_FOR_BIG_QUERY_FIELDS) def test_validation_known_good(self): - test = { - "ipstr":"141.212.120.1", - "ip":2379511809, - "443":{ - "tls":"test" - } - } + test = {"ipstr": "141.212.120.1", "ip": 2379511809, "443": {"tls": "test"}} self.host.validate(test) def test_validation_bad_key(self): - test = { - "keydne":"141.212.120.1asdf", - "ip":2379511809, - "443":{ - "tls":"test" - } - } + test = {"keydne": "141.212.120.1asdf", "ip": 2379511809, "443": {"tls": "test"}} try: self.host.validate(test) raise Exception("validation did not fail") @@ -435,13 +408,7 @@ def test_validation_bad_key(self): pass def test_validation_bad_value(self): - test = { - "ipstr":10, - "ip":2379511809, - "443":{ - "tls":"test" - } - } + test = {"ipstr": 10, "ip": 2379511809, "443": {"tls": "test"}} try: self.host.validate(test) raise Exception("validation did not fail") @@ -449,32 +416,28 @@ def test_validation_bad_value(self): pass def test_merge_no_conflict(self): - a = SubRecord({ - "a":String(), - "b":SubRecord({ - "c":String() - }) - }) - b = SubRecord({ - "d":String(), - }) - valid = SubRecord({ - "a":String(), - "b":SubRecord({ - "c":String() - }), - "d":String(), - - }) + a = SubRecord({"a": String(), "b": SubRecord({"c": String()})}) + b = SubRecord( + { + "d": String(), + } + ) + valid = SubRecord( + { + "a": String(), + "b": SubRecord({"c": String()}), + "d": String(), + } + ) self.assertEqual(a.merge(b).to_dict(), valid.to_dict()) def test_merge_different_types(self): - a = SubRecord({ - "a":String(), - }) - b = SubRecord({ - "a":SubRecord({}) - }) + a = SubRecord( + { + "a": String(), + } + ) + b = SubRecord({"a": SubRecord({})}) try: a.merge(b) raise Exception("validation did not fail") @@ -482,12 +445,16 @@ def test_merge_different_types(self): pass def test_merge_unmergable_types(self): - a = SubRecord({ - "a":String(), - }) - b = SubRecord({ - "a":String(), - }) + a = SubRecord( + { + "a": String(), + } + ) + b = SubRecord( + { + "a": String(), + } + ) try: a.merge(b) raise Exception("validation did not fail") @@ -495,65 +462,43 @@ def test_merge_unmergable_types(self): pass def test_merge_recursive(self): - a = SubRecord({ - "m":SubRecord({ - "a":String() - }) - }) - b = SubRecord({ - "a":String(), - "m":SubRecord({ - "b":String() - }) - - }) - c = SubRecord({ - "a":String(), - "m":SubRecord({ - "a":String(), - "b":String() - }) - }) + a = SubRecord({"m": SubRecord({"a": String()})}) + b = SubRecord({"a": String(), "m": SubRecord({"b": String()})}) + c = SubRecord({"a": String(), "m": SubRecord({"a": String(), "b": String()})}) self.assertEqual(a.merge(b).to_dict(), c.to_dict()) def test_extends(self): - host = Record({ - "host":IPv4Address(required=True), - "time":DateTime(required=True), - "data":SubRecord({}), - "error":String() - }) - banner_grab = Record({ - "data":SubRecord({ - "banner":String() - }) - }, extends=host) - tls_banner_grab = Record({ - "data":SubRecord({ - "tls":SubRecord({}) - }) - }, extends=banner_grab) - smtp_starttls = Record({ - "data":SubRecord({ - "ehlo":String() - }) - }, extends=tls_banner_grab) - - valid = Record({ - "host":IPv4Address(required=True), - "time":DateTime(required=True), - "data":SubRecord({ - "banner":String(), - "tls":SubRecord({}), - "ehlo":String() - }), - "error":String() - }) + host = Record( + { + "host": IPv4Address(required=True), + "time": DateTime(required=True), + "data": SubRecord({}), + "error": String(), + } + ) + banner_grab = Record({"data": SubRecord({"banner": String()})}, extends=host) + tls_banner_grab = Record( + {"data": SubRecord({"tls": SubRecord({})})}, extends=banner_grab + ) + smtp_starttls = Record( + {"data": SubRecord({"ehlo": String()})}, extends=tls_banner_grab + ) + + valid = Record( + { + "host": IPv4Address(required=True), + "time": DateTime(required=True), + "data": SubRecord( + {"banner": String(), "tls": SubRecord({}), "ehlo": String()} + ), + "error": String(), + } + ) self.assertEqual(smtp_starttls.to_dict(), valid.to_dict()) def test_null_required(self): test = { - "ipstr":None, + "ipstr": None, } try: self.host.validate(test) @@ -561,7 +506,7 @@ def test_null_required(self): except DataValidationException: pass - #def test_missing_required(self): + # def test_missing_required(self): # # ipstr is not set # test = { # "443":{ @@ -579,7 +524,7 @@ def test_null_subkey(self): "ipstr": "1.2.3.4", "443": { "heartbleed": None, - } + }, } try: self.host.validate(test) @@ -599,82 +544,66 @@ def test_null_port(self): pass def test_null_notrequired(self): - test = { - "ip":None, - "443":{ - "tls":"None" - } - } + test = {"ip": None, "443": {"tls": "None"}} self.host.validate(test) def test_parses_ipv4_records(self): - ipv4_host_ssh = Record({ - Port(22):SubRecord({ - "ssh":SubRecord({ - "banner": SubRecord({ - "comment":String(), - "timestamp":DateTime() - }) - }) - }) - }) - ipv4_host_ssh.validate(json_fixture('ipv4-ssh-record')) + ipv4_host_ssh = Record( + { + Port(22): SubRecord( + { + "ssh": SubRecord( + { + "banner": SubRecord( + {"comment": String(), "timestamp": DateTime()} + ) + } + ) + } + ) + } + ) + ipv4_host_ssh.validate(json_fixture("ipv4-ssh-record")) def test_es_dynamic_record(self): ipv4_host_with_dynamic_strict = Record() es = ipv4_host_with_dynamic_strict.to_es("strict-record") self.assertFalse("dynamic" in es["strict-record"]) - ipv4_host_with_dynamic_strict = Record( - es_dynamic_policy="strict" - ) + ipv4_host_with_dynamic_strict = Record(es_dynamic_policy="strict") es = ipv4_host_with_dynamic_strict.to_es("strict-record") self.assertEqual("strict", es["strict-record"]["dynamic"]) def test_subrecord_types(self): - SSH = SubRecordType({ - "banner":SubRecord({ - "comment":String(), - "timestamp":DateTime() - }) - }, + SSH = SubRecordType( + {"banner": SubRecord({"comment": String(), "timestamp": DateTime()})}, doc="class doc", - required=False) + required=False, + ) self.assertEqual(SSH.DOC, "class doc") self.assertEqual(SSH.REQUIRED, False) ssh = SSH(doc="instance doc") - ipv4_host_ssh = Record({ - Port(22):SubRecord({ - "ssh":ssh - }) - }) + ipv4_host_ssh = Record({Port(22): SubRecord({"ssh": ssh})}) self.assertEqual(ssh.doc, "instance doc") self.assertEqual(ssh.required, False) - ipv4_host_ssh.validate(json_fixture('ipv4-ssh-record')) + ipv4_host_ssh.validate(json_fixture("ipv4-ssh-record")) # class unchanged self.assertEqual(SSH.DOC, "class doc") self.assertEqual(SSH.REQUIRED, False) def test_subrecord_type_override(self): - SSH = SubRecordType({ - "banner": SubRecord({ - "comment": String(), - "timestamp": DateTime() - }) - }, + SSH = SubRecordType( + {"banner": SubRecord({"comment": String(), "timestamp": DateTime()})}, doc="class doc", - required=False) + required=False, + ) self.assertEqual(SSH.DOC, "class doc") self.assertEqual(SSH.REQUIRED, False) ssh = SSH(doc="instance doc", required=True) - ipv4_host_ssh = Record({ - Port(22):SubRecord({ - "ssh":ssh - }) - }) + ipv4_host_ssh = Record({Port(22): SubRecord({"ssh": ssh})}) self.assertEqual(ssh.doc, "instance doc") self.assertEqual(ssh.required, True) - ipv4_host_ssh.validate(json_fixture('ipv4-ssh-record')) + ipv4_host_ssh.validate(json_fixture("ipv4-ssh-record")) # class unchanged self.assertEqual(SSH.DOC, "class doc") self.assertEqual(SSH.REQUIRED, False) @@ -683,13 +612,17 @@ def test_subrecord_type_override(self): class RegistryTests(unittest.TestCase): def setUp(self): - self.host = Record({ - "ipstr":IPv4Address(required=True), - "ip":Unsigned32BitInteger(), - }) - self.domain = Record({ - "domain":String(required=True), - }) + self.host = Record( + { + "ipstr": IPv4Address(required=True), + "ip": Unsigned32BitInteger(), + } + ) + self.domain = Record( + { + "domain": String(required=True), + } + ) def test_get_registered(self): try: @@ -707,7 +640,7 @@ def test_get_registered(self): self.fail("registered schema should not throw") all_schemas = registry.all_schemas() self.assertEqual(1, len(all_schemas)) - all_schemas['domain'] = self.domain + all_schemas["domain"] = self.domain all_schemas = registry.all_schemas() self.assertEqual(1, len(all_schemas)) @@ -727,11 +660,14 @@ class SubRecordTests(unittest.TestCase): def test_subrecord_child_types_can_override_parent_attributes(self): Certificate = SubRecordType({}, doc="A parsed certificate.") c = Certificate(doc="The CA certificate.") - OtherType = SubRecord({ - "ca": c, - "host": Certificate(doc="The host certificate."), - }, doc="hello") - self.assertEqual("A parsed certificate." , Certificate().doc) + OtherType = SubRecord( + { + "ca": c, + "host": Certificate(doc="The host certificate."), + }, + doc="hello", + ) + self.assertEqual("A parsed certificate.", Certificate().doc) self.assertEqual("The CA certificate.", OtherType.definition["ca"].doc) self.assertEqual("The host certificate.", OtherType.definition["host"].doc) @@ -744,15 +680,21 @@ class WithArgsTests(unittest.TestCase): def test_with_args(self): Certificate = SubRecord.with_args({}, doc="A parsed certificate.") CertificateChain = ListOf.with_args(Certificate()) - AlgorithmType = String.with_args(doc="An algorithm identifier", examples=["a", "b", "c"]) - OtherType = SubRecord({ - "ca": Certificate(doc="The CA certificate."), - "host": Certificate(doc="The host certificate."), - "chain": CertificateChain(doc="The certificate chain."), - "host_alg": AlgorithmType(doc="The host algorithm", examples=["x", "y"]), - "client_alg": AlgorithmType(doc="The client algorithm"), - "sig_alg": AlgorithmType(examples=["p", "q"]), - }) + AlgorithmType = String.with_args( + doc="An algorithm identifier", examples=["a", "b", "c"] + ) + OtherType = SubRecord( + { + "ca": Certificate(doc="The CA certificate."), + "host": Certificate(doc="The host certificate."), + "chain": CertificateChain(doc="The certificate chain."), + "host_alg": AlgorithmType( + doc="The host algorithm", examples=["x", "y"] + ), + "client_alg": AlgorithmType(doc="The client algorithm"), + "sig_alg": AlgorithmType(examples=["p", "q"]), + } + ) # Check default self.assertEqual("A parsed certificate.", Certificate().doc) @@ -764,7 +706,9 @@ def test_with_args(self): self.assertEqual("The certificate chain.", OtherType.definition["chain"].doc) # Check that instance default is used in child - self.assertEqual("A parsed certificate.", OtherType.definition["chain"].object_.doc) + self.assertEqual( + "A parsed certificate.", OtherType.definition["chain"].object_.doc + ) # Check Leaf type doc overrides self.assertEqual("The host algorithm", OtherType.definition["host_alg"].doc) @@ -812,6 +756,7 @@ def __init__(self, *args, **kwargs): self.assertEqual(("a",), p0doc.args) self.assertEqual("some docs", p0doc.doc) + class DatetimeTest(unittest.TestCase): def test_datetime_DateTime(self): @@ -824,48 +769,65 @@ def test_datetime_DateTime(self): class ValidationPolicies(unittest.TestCase): def setUp(self): - self.maxDiff=10000 - - Child = SubRecordType({ - "foo":Boolean(), - "bar":Boolean(validation_policy="error"), - }, validation_policy="error") - self.record = Record({ - "a":Child(validation_policy="error"), - "b":Child(validation_policy="warn"), - "c":Child(validation_policy="ignore"), - "d":Child(validation_policy="inherit"), - }) + self.maxDiff = 10000 + + Child = SubRecordType( + { + "foo": Boolean(), + "bar": Boolean(validation_policy="error"), + }, + validation_policy="error", + ) + self.record = Record( + { + "a": Child(validation_policy="error"), + "b": Child(validation_policy="warn"), + "c": Child(validation_policy="ignore"), + "d": Child(validation_policy="inherit"), + } + ) def test_policy_setting_warn(self): - self.record.validate({"b":{"foo":"string value"}}) + self.record.validate({"b": {"foo": "string value"}}) def test_policy_setting_ignore(self): - self.record.validate({"c":{"foo":"string value"}}) + self.record.validate({"c": {"foo": "string value"}}) def test_policy_setting_error(self): - self.assertRaises(DataValidationException, lambda: self.record.validate({"c":{"bar":"string value"}})) + self.assertRaises( + DataValidationException, + lambda: self.record.validate({"c": {"bar": "string value"}}), + ) def test_policy_setting_inherit(self): - self.assertRaises(DataValidationException, lambda: self.record.validate({"a":{"foo":"string value"}})) + self.assertRaises( + DataValidationException, + lambda: self.record.validate({"a": {"foo": "string value"}}), + ) def test_policy_setting_multi_level_inherit(self): - self.assertRaises(DataValidationException, lambda: self.record.validate({"a":{"bar":"string value"}})) + self.assertRaises( + DataValidationException, + lambda: self.record.validate({"a": {"bar": "string value"}}), + ) def test_explicit_policy(self): - self.record.validate({"a":{"foo":"string value"}}, - policy="ignore") + self.record.validate({"a": {"foo": "string value"}}, policy="ignore") def test_child_subtree_overrides_and_inherits(self): - schema = Record({ - Port(445): SubRecord({ - "smb": SubRecord({ - "banner": SubRecord({ - "smb_v1": Boolean() - }) - }, validation_policy="error") - }) - }, validation_policy="warn") + schema = Record( + { + Port(445): SubRecord( + { + "smb": SubRecord( + {"banner": SubRecord({"smb_v1": Boolean()})}, + validation_policy="error", + ) + } + ) + }, + validation_policy="warn", + ) doc = { "445": { @@ -879,6 +841,7 @@ def test_child_subtree_overrides_and_inherits(self): } self.assertRaises(DataValidationException, lambda: schema.validate(doc)) + class ExcludeTests(unittest.TestCase): def test_ListOf_exclude(self): a = ListOf(String()) @@ -901,22 +864,32 @@ def test_ListOf_exclude(self): class PathLogUnitTests(unittest.TestCase): - sub_type = SubRecord({ - "sub1": String(), - "sub2": SubRecord({ - "sub2sub1": Unsigned8BitInteger(), - "sub2sub2": NestedListOf(String(), "sub2sub2.subrecord_name"), - }), - "sub3": Enum(values=["a", "b", "c"]) - }, validation_policy="error") - SCHEMA = Record({ - "a": SubRecord({ - "a1": String(), - "a2": ListOf(sub_type), - "a3": Unsigned8BitInteger(), - }), - "b": String(), - }, validation_policy="error") + sub_type = SubRecord( + { + "sub1": String(), + "sub2": SubRecord( + { + "sub2sub1": Unsigned8BitInteger(), + "sub2sub2": NestedListOf(String(), "sub2sub2.subrecord_name"), + } + ), + "sub3": Enum(values=["a", "b", "c"]), + }, + validation_policy="error", + ) + SCHEMA = Record( + { + "a": SubRecord( + { + "a1": String(), + "a2": ListOf(sub_type), + "a3": Unsigned8BitInteger(), + } + ), + "b": String(), + }, + validation_policy="error", + ) def test_good(self): good = { @@ -981,7 +954,7 @@ def test_bad_root(self): except DataValidationException as e: self.assertTrue(not e.path) - del(bad1["does_not_exist"]) + del bad1["does_not_exist"] bad1["b"] = 23 try: @@ -990,7 +963,6 @@ def test_bad_root(self): except DataValidationException as e: self.assertEqual(e.path, ["b"]) - def test_bad_a_key(self): bad = { "a": { @@ -1024,7 +996,7 @@ def test_bad_a_key(self): self.assertTrue(False, "bad failed to fail") except DataValidationException as e: self.assertEqual(e.path, ["a"]) - del(bad["a"]["does_not_exist"]) + del bad["a"]["does_not_exist"] bad["a"]["a3"] = "not an int" try: ret = self.SCHEMA.validate(bad, policy="error") @@ -1064,11 +1036,17 @@ def test_bad_deep_key(self): self.SCHEMA.validate(bad, policy="error") self.assertTrue(False, "failed to fail") except DataValidationException as e: - self.assertEqual(e.path, ["a", "a2", 0, "sub2", ]) - del(bad["a"]["a2"][0]["sub2"]["does_not_exist"]) - bad["a"]["a2"][0]["sub2"]["sub2sub2"][1] = { - "wrong type": "bad type" - } + self.assertEqual( + e.path, + [ + "a", + "a2", + 0, + "sub2", + ], + ) + del bad["a"]["a2"][0]["sub2"]["does_not_exist"] + bad["a"]["a2"][0]["sub2"]["sub2sub2"][1] = {"wrong type": "bad type"} try: self.SCHEMA.validate(bad, policy="error") self.assertTrue(False, "bad failed to fail") @@ -1079,10 +1057,12 @@ def test_bad_deep_key(self): class TestSubRecordType(unittest.TestCase): def test_subrecord_type(self): - A = SubRecordType({ - "string": String(), - "boolean": Boolean(), - }) + A = SubRecordType( + { + "string": String(), + "boolean": Boolean(), + } + ) first = A() second = A() @@ -1095,31 +1075,38 @@ def test_subrecord_type(self): self.assertIsInstance(second, A) # Check the properties aren't shared - self.assertIsNone(first.definition['string'].doc) - self.assertIsNone(second.definition['string'].doc) - first.definition['string'].doc = "hello" - self.assertIsNone(second.definition['string'].doc) + self.assertIsNone(first.definition["string"].doc) + self.assertIsNone(second.definition["string"].doc) + first.definition["string"].doc = "hello" + self.assertIsNone(second.definition["string"].doc) def test_subrecord_type_extends(self): - S = SubRecordType({ - "provided": Boolean(), - }) + S = SubRecordType( + { + "provided": Boolean(), + } + ) - extended_type = SubRecord({ - "property": String(), - "record": SubRecord({ - "another": String(), - }), - }, extends=S()) + extended_type = SubRecord( + { + "property": String(), + "record": SubRecord( + { + "another": String(), + } + ), + }, + extends=S(), + ) base = S() extends = extended_type self.assertNotIsInstance(extends, S) - self.assertFalse(base.definition['provided'].exclude) - self.assertFalse(extended_type.definition['provided'].exclude) - base.definition['provided'].exclude = ['bigquery'] - self.assertEqual(['bigquery'], base.definition['provided'].exclude) - self.assertFalse(extended_type.definition['provided'].exclude) + self.assertFalse(base.definition["provided"].exclude) + self.assertFalse(extended_type.definition["provided"].exclude) + base.definition["provided"].exclude = ["bigquery"] + self.assertEqual(["bigquery"], base.definition["provided"].exclude) + self.assertFalse(extended_type.definition["provided"].exclude) def test_indexing_works(self): definition = { @@ -1139,22 +1126,35 @@ def test_indexing_works(self): self.assertFalse(second.exclude) self.assertFalse(second["id"].exclude) - CertType = SubRecordType({ - "id": Unsigned32BitInteger(doc="The numerical certificate type value. 1 identifies user certificates, 2 identifies host certificates."), - "name": Enum(values=["USER", "HOST", "unknown"], doc="The human-readable name for the certificate type."), - }) + CertType = SubRecordType( + { + "id": Unsigned32BitInteger( + doc="The numerical certificate type value. 1 identifies user certificates, 2 identifies host certificates." + ), + "name": Enum( + values=["USER", "HOST", "unknown"], + doc="The human-readable name for the certificate type.", + ), + } + ) ssh_certkey_public_key_type = CertType(exclude={"bigquery"}) - ssh_certkey_public_key_type["id"].set("exclude", - ssh_certkey_public_key_type["id"].exclude | - {"elasticsearch"}) + ssh_certkey_public_key_type["id"].set( + "exclude", ssh_certkey_public_key_type["id"].exclude | {"elasticsearch"} + ) def test_multiple_subrecord_types(self): - A = SubRecordType({ - "first": String(), - }, type_name="A") - B = SubRecordType({ - "second": Boolean(), - }, type_name="B") + A = SubRecordType( + { + "first": String(), + }, + type_name="A", + ) + B = SubRecordType( + { + "second": Boolean(), + }, + type_name="B", + ) a = A() self.assertIn("first", a.definition) @@ -1162,5 +1162,3 @@ def test_multiple_subrecord_types(self): self.assertIn("second", b.definition) a = A() self.assertIn("first", a.definition) - -