diff --git a/scrapscript.py b/scrapscript.py index 805f8d1f..895aa34d 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -984,6 +984,8 @@ class Variant(Object): TYPE_ACCESS := b"@", TYPE_SPREAD := b"S", TYPE_NAMED_SPREAD := b"R", + TYPE_TRUE := b"T", + TYPE_FALSE := b"F", ] FLAG_REF = 0x80 @@ -1091,6 +1093,10 @@ def serialize(self, obj: Object) -> None: self.serialize(item) return if isinstance(obj, Variant): + if obj.tag == "true" and isinstance(obj.value, Hole): + return self.emit(TYPE_TRUE) + if obj.tag == "false" and isinstance(obj.value, Hole): + return self.emit(TYPE_FALSE) # TODO(max): Determine if this should be a ref self.emit(TYPE_VARIANT) # TODO(max): String pool (via refs) for strings longer than some length? @@ -1329,6 +1335,12 @@ def parse(self) -> Object: return Spread() if ty == TYPE_NAMED_SPREAD: return Spread(self._string()) + if ty == TYPE_TRUE: + assert not is_ref + return Variant("true", Hole()) + if ty == TYPE_FALSE: + assert not is_ref + return Variant("false", Hole()) raise NotImplementedError(bytes(ty)) @@ -2531,7 +2543,7 @@ def do_GET(self) -> None: if scrap is not None: self.send_response(200) self.send_header("Content-Type", "application/scrap; charset=binary") - self.send_header("Content-Disposition", f'attachment; filename={json.dumps(f"{path}.scrap")}') + self.send_header("Content-Disposition", f"attachment; filename={json.dumps(f'{path}.scrap')}") self.send_header("Content-Length", str(len(scrap))) self.end_headers() self.wfile.write(scrap) diff --git a/scrapscript_tests.py b/scrapscript_tests.py index ceecf58b..be6972b7 100644 --- a/scrapscript_tests.py +++ b/scrapscript_tests.py @@ -3778,6 +3778,22 @@ def test_access(self) -> None: def test_spread(self) -> None: self.assertEqual(self._serialize(Spread()), TYPE_SPREAD) self.assertEqual(self._serialize(Spread("rest")), TYPE_NAMED_SPREAD + b"\x08rest") + + def test_true_variant(self) -> None: + obj = Variant("true", Hole()) + self.assertEqual(self._serialize(obj), TYPE_TRUE) + + def test_false_variant(self) -> None: + obj = Variant("false", Hole()) + self.assertEqual(self._serialize(obj), TYPE_FALSE) + + def test_true_variant_with_non_hole_uses_regular_variant(self) -> None: + obj = Variant("true", Int(123)) + self.assertEqual(self._serialize(obj), TYPE_VARIANT + b"\x08truei\xf6\x01") + + def test_false_variant_with_non_hole_uses_regular_variant(self) -> None: + obj = Variant("false", Int(123)) + self.assertEqual(self._serialize(obj), TYPE_VARIANT + b"\x0afalsei\xf6\x01") class RoundTripSerializationTests(unittest.TestCase):