diff --git a/src/pv/binding_generator/__init__.py b/src/pv/binding_generator/__init__.py index 1866a24..8bf737e 100644 --- a/src/pv/binding_generator/__init__.py +++ b/src/pv/binding_generator/__init__.py @@ -26,7 +26,26 @@ def generate_binding(tar_archive: bytes, plugin: str) -> bytes: """ Decompress plugin's source code in the temporary directory """ with io.BytesIO(tar_archive) as tar_stream: with tarfile.open(fileobj=tar_stream, mode='r:bz2') as tar: - tar.extractall(path=tmp_dir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=tmp_dir) """ Compile plugin """ plugin_dir = os.path.join(tmp_dir, plugin) diff --git a/src/pv/verifiers/controller.py b/src/pv/verifiers/controller.py index 6750d54..6d76516 100644 --- a/src/pv/verifiers/controller.py +++ b/src/pv/verifiers/controller.py @@ -55,7 +55,26 @@ with io.BytesIO(plugin_code) as tar_stream: try: with tarfile.open(fileobj=tar_stream, mode='r:bz2') as tar: - tar.extractall(path=tmpdir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=tmpdir) logger.info('* Archive extracted.') except tarfile.TarError as error: logger.error(error)