Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Tests/test_file_spider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
TEST_FILE = "Tests/images/hopper.spider"


def teardown_module() -> None:
del Image.EXTENSION[".spider"]


def test_sanity() -> None:
with Image.open(TEST_FILE) as im:
im.load()
Expand Down
3 changes: 3 additions & 0 deletions Tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,9 @@ def test_registered_extensions_uninitialized(self) -> None:
# Assert
assert Image._initialized == 2

for extension in Image.EXTENSION:
assert extension in Image._EXTENSION_PLUGIN

def test_registered_extensions(self) -> None:
# Arrange
# Open an image to trigger plugin registration
Expand Down
148 changes: 133 additions & 15 deletions src/PIL/Image.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,108 @@ def getmodebands(mode: str) -> int:

_initialized = 0

# Mapping from file extension to plugin module name for lazy importing
_EXTENSION_PLUGIN: dict[str, str] = {
# Common formats (preinit)
".bmp": "BmpImagePlugin",
".dib": "BmpImagePlugin",
".gif": "GifImagePlugin",
".jfif": "JpegImagePlugin",
".jpe": "JpegImagePlugin",
".jpg": "JpegImagePlugin",
".jpeg": "JpegImagePlugin",
".pbm": "PpmImagePlugin",
".pgm": "PpmImagePlugin",
".pnm": "PpmImagePlugin",
".ppm": "PpmImagePlugin",
".pfm": "PpmImagePlugin",
".png": "PngImagePlugin",
".apng": "PngImagePlugin",
# Less common formats (init)
".avif": "AvifImagePlugin",
".avifs": "AvifImagePlugin",
".blp": "BlpImagePlugin",
".bufr": "BufrStubImagePlugin",
".cur": "CurImagePlugin",
".dcx": "DcxImagePlugin",
".dds": "DdsImagePlugin",
".ps": "EpsImagePlugin",
".eps": "EpsImagePlugin",
".fit": "FitsImagePlugin",
".fits": "FitsImagePlugin",
".fli": "FliImagePlugin",
".flc": "FliImagePlugin",
".fpx": "FpxImagePlugin",
".ftc": "FtexImagePlugin",
".ftu": "FtexImagePlugin",
".gbr": "GbrImagePlugin",
".grib": "GribStubImagePlugin",
".h5": "Hdf5StubImagePlugin",
".hdf": "Hdf5StubImagePlugin",
".icns": "IcnsImagePlugin",
".ico": "IcoImagePlugin",
".im": "ImImagePlugin",
".iim": "IptcImagePlugin",
".jp2": "Jpeg2KImagePlugin",
".j2k": "Jpeg2KImagePlugin",
".jpc": "Jpeg2KImagePlugin",
".jpf": "Jpeg2KImagePlugin",
".jpx": "Jpeg2KImagePlugin",
".j2c": "Jpeg2KImagePlugin",
".mic": "MicImagePlugin",
".mpg": "MpegImagePlugin",
".mpeg": "MpegImagePlugin",
".mpo": "MpoImagePlugin",
".msp": "MspImagePlugin",
".palm": "PalmImagePlugin",
".pcd": "PcdImagePlugin",
".pcx": "PcxImagePlugin",
".pdf": "PdfImagePlugin",
".pxr": "PixarImagePlugin",
".psd": "PsdImagePlugin",
".qoi": "QoiImagePlugin",
".bw": "SgiImagePlugin",
".rgb": "SgiImagePlugin",
".rgba": "SgiImagePlugin",
".sgi": "SgiImagePlugin",
".ras": "SunImagePlugin",
".tga": "TgaImagePlugin",
".icb": "TgaImagePlugin",
".vda": "TgaImagePlugin",
".vst": "TgaImagePlugin",
".tif": "TiffImagePlugin",
".tiff": "TiffImagePlugin",
".webp": "WebPImagePlugin",
".wmf": "WmfImagePlugin",
".emf": "WmfImagePlugin",
".xbm": "XbmImagePlugin",
".xpm": "XpmImagePlugin",
}


def _import_plugin_for_extension(ext: str | bytes) -> bool:
"""Import only the plugin needed for a specific file extension."""
if not ext:
return False

if isinstance(ext, bytes):
ext = ext.decode()
ext = ext.lower()
if ext in EXTENSION:
return True

plugin = _EXTENSION_PLUGIN.get(ext)
if plugin is None:
return False

try:
logger.debug("Importing %s", plugin)
__import__(f"{__spec__.parent}.{plugin}", globals(), locals(), [])
return True
except ImportError as e:
logger.debug("Image: failed to import %s: %s", plugin, e)
return False


def preinit() -> None:
"""
Expand Down Expand Up @@ -382,11 +484,10 @@ def init() -> bool:
if _initialized >= 2:
return False

parent_name = __name__.rpartition(".")[0]
for plugin in _plugins:
try:
logger.debug("Importing %s", plugin)
__import__(f"{parent_name}.{plugin}", globals(), locals(), [])
__import__(f"{__spec__.parent}.{plugin}", globals(), locals(), [])
except ImportError as e:
logger.debug("Image: failed to import %s: %s", plugin, e)

Expand Down Expand Up @@ -2535,12 +2636,20 @@ def save(
# only set the name for metadata purposes
filename = os.fspath(fp.name)

preinit()
if format:
preinit()
else:
filename_ext = os.path.splitext(filename)[1].lower()
ext = (
filename_ext.decode()
if isinstance(filename_ext, bytes)
else filename_ext
)

filename_ext = os.path.splitext(filename)[1].lower()
ext = filename_ext.decode() if isinstance(filename_ext, bytes) else filename_ext
# Try importing only the plugin for this extension first
if not _import_plugin_for_extension(ext):
preinit()

if not format:
if ext not in EXTENSION:
init()
try:
Expand Down Expand Up @@ -3524,7 +3633,11 @@ def open(

prefix = fp.read(16)

preinit()
# Try to import just the plugin needed for this file extension
# before falling back to preinit() which imports common plugins
ext = os.path.splitext(filename)[1] if filename else ""
if not _import_plugin_for_extension(ext):
preinit()

warning_messages: list[str] = []

Expand Down Expand Up @@ -3560,14 +3673,19 @@ def _open_core(
im = _open_core(fp, filename, prefix, formats)

if im is None and formats is ID:
checked_formats = ID.copy()
if init():
im = _open_core(
fp,
filename,
prefix,
tuple(format for format in formats if format not in checked_formats),
)
# Try preinit (few common plugins) then init (all plugins)
for loader in (preinit, init):
checked_formats = ID.copy()
loader()
if formats != checked_formats:
im = _open_core(
fp,
filename,
prefix,
tuple(f for f in formats if f not in checked_formats),
)
if im is not None:
break

if im:
im._exclusive_fp = exclusive_fp
Expand Down
6 changes: 3 additions & 3 deletions src/PIL/SpiderImagePlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:

def _save_spider(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
# get the filename extension and register it with Image
filename_ext = os.path.splitext(filename)[1]
ext = filename_ext.decode() if isinstance(filename_ext, bytes) else filename_ext
Image.register_extension(SpiderImageFile.format, ext)
if filename_ext := os.path.splitext(filename)[1]:
ext = filename_ext.decode() if isinstance(filename_ext, bytes) else filename_ext
Image.register_extension(SpiderImageFile.format, ext)
_save(im, fp, filename)


Expand Down
Loading