# Copyright 2024 mbodi ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrap any common image representation in an Image class to convert to any other common format.
The following image representations are supported:
- NumPy array
- PIL Image
- Base64 encoded string
- File path
- URL
- Bytes object
The image can be resized to and from any size, compressed, and converted to and from any supported format:
```python
image = Image("path/to/image.png", size=new_size_tuple).save("path/to/new/image.jpg")
image.save("path/to/new/image.jpg", quality=5)
TODO: Implement Lazy attribute loading for the image data.
"""
import base64 as base64lib
import io
import logging
from pathlib import Path
from typing import Any, Tuple, Union
from urllib.parse import urlparse
import numpy as np
from datasets.features import Features
from datasets.features import Image as HFImage
from gymnasium import spaces
from PIL import Image as PILModule
from PIL.Image import Image as PILImage
from pydantic import (
AnyUrl,
Base64Str,
ConfigDict,
Field,
FilePath,
InstanceOf,
model_serializer,
model_validator,
)
from typing_extensions import Literal
from mbodied.types.ndarray import NumpyArray
from mbodied.types.sample import Sample
# Type alias used to annotate the primary positional ``arg`` of ``Image.__init__`` and the
# argument of ``Image.supports``.
# NOTE(review): ``Image.__init__`` additionally branches on ``bytes``, ``Path``, ``Image`` and
# 2-tuples, none of which are listed in this alias -- consider widening it.
SupportsImage = Union[np.ndarray, PILImage, Base64Str, AnyUrl, FilePath] # noqa: UP007
class Image(Sample):
    """An image sample that can be represented in various formats.

    The image can be represented as a NumPy array, a base64 encoded string, a file path, a PIL Image object,
    or a URL. The image can be resized to and from any size and converted to and from any supported format.

    Whatever source is supplied, the ``validate_kwargs`` validator converts it to a PIL image and
    derives the remaining representations (array, base64, data URL, size) from it.

    Attributes:
        array (np.ndarray): The image represented as a NumPy array.
        base64 (Optional[Base64Str]): The base64 encoded string of the image.
        path (Optional[FilePath]): The file path of the image.
        pil (Optional[PILImage]): The image represented as a PIL Image object.
        url (Optional[AnyUrl]): The URL of the image (a ``data:`` URL unless the source was a URL).
        size (tuple[int, int]): The size of the image as a (width, height) tuple.
        encoding (Literal["png", "jpeg", "jpg", "bmp", "gif"]): The encoding of the image.

    Example:
        >>> image = Image("https://example.com/image.jpg")
        >>> image = Image("/path/to/image.jpg")
        >>> image = Image("data:image/jpeg;base64,...")
        >>> jpeg_from_png = Image("path/to/image.png", encoding="jpeg")
        >>> resized_image = Image(image, size=(224, 224))
        >>> pil_image = Image(image).pil
        >>> array = Image(image).array
        >>> base64 = Image(image).base64
    """

    # NOTE(review): pydantic's documented config key is ``extra``, not ``extras``, so the
    # "forbid" setting below is presumably not applied. It is left untouched because turning
    # it on could reject inputs that are accepted today -- confirm the intent before renaming.
    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True, extras="forbid", validate_assignment=False)
    array: NumpyArray
    size: tuple[int, int]
    pil: InstanceOf[PILImage] | None = Field(
        None,
        repr=False,
        exclude=True,
        description="The image represented as a PIL Image object.",
    )
    encoding: Literal["png", "jpeg", "jpg", "bmp", "gif"]
    base64: InstanceOf[Base64Str] | None = None
    url: InstanceOf[AnyUrl] | str | None = None
    path: FilePath | None = None

    @staticmethod
    def _path_exists(candidate: str) -> bool:
        """Return True if ``candidate`` names an existing path, never raising for odd strings."""
        try:
            return Path(candidate).exists()
        except (OSError, ValueError):
            # Very long strings (e.g. base64 payloads) or embedded NULs are not valid paths.
            return False

    @staticmethod
    def _decode_base64(base64_str: str) -> bytes:
        """Decode a base64 payload, tolerating a leading ``data:...;base64,`` prefix."""
        if base64_str.startswith("data:"):
            base64_str = base64_str.partition(",")[2]
        return base64lib.b64decode(base64_str)

    @staticmethod
    def _encoding_from_url(url: str, default: str) -> str:
        """Guess the encoding from a data-URI MIME subtype or a URL file suffix.

        Falls back to ``default`` when the guess is not one of the supported encodings, so an
        arbitrary suffix such as ``.php`` can never end up in the ``encoding`` field.
        """
        if url.startswith("data:image/"):
            candidate = url[len("data:image/") :].split(";", 1)[0].split(",", 1)[0].lower()
        else:
            candidate = Path(urlparse(url).path).suffix[1:].lower()
        return candidate if candidate in ("png", "jpeg", "jpg", "bmp", "gif") else default

    @classmethod
    def supports(cls, arg: SupportsImage) -> bool:
        """Report whether ``arg`` looks like something this class can wrap.

        Args:
            arg (SupportsImage): The candidate value.

        Returns:
            bool: True for NumPy arrays, PIL images and ``AnyUrl`` objects; for strings, True only
            when the string is an ``data:image`` URI or names an existing path; False otherwise.
        """
        if isinstance(arg, (np.ndarray, PILImage, AnyUrl)):
            return True
        if isinstance(arg, str):
            # Cheap prefix test first so long data URIs never reach the filesystem check.
            return arg.startswith("data:image") or cls._path_exists(arg)
        return False

    def __init__(
        self,
        arg: SupportsImage = None,
        url: str | None = None,
        path: str | None = None,
        base64: str | None = None,
        array: np.ndarray | None = None,
        pil: PILImage | None = None,
        encoding: str | None = "jpeg",
        size: Tuple | None = None,
        bytes_obj: bytes | None = None,
        **kwargs,
    ):
        """Initializes an image. Either one source argument or size tuple must be provided.

        When ``arg`` is given it takes precedence and the keyword sources are ignored. A string
        ``arg`` is classified, in order, as a URL (``http://``, ``https://`` or ``data:image``
        prefix), an existing file path, or otherwise a base64 payload.

        Args:
            arg (SupportsImage, optional): The primary image source. May also be ``bytes``, a
                ``Path``, another ``Image`` or a 2-tuple interpreted as a size.
            url (Optional[str], optional): The URL of the image.
            path (Optional[str], optional): The file path of the image.
            base64 (Optional[str], optional): The base64 encoded string of the image.
            array (Optional[np.ndarray], optional): The numpy array of the image.
            pil (Optional[PILImage], optional): The PIL image object.
            encoding (Optional[str], optional): The encoding format of the image. Defaults to 'jpeg'.
            size (Optional[Tuple[int, int]], optional): The size of the image as a (width, height) tuple.
            bytes_obj (Optional[bytes], optional): Encoded image bytes.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If ``arg`` is of an unsupported type.
        """
        kwargs["encoding"] = encoding or "jpeg"
        kwargs["size"] = size
        if arg is not None:
            if isinstance(arg, bytes):
                kwargs["bytes"] = arg
            elif isinstance(arg, AnyUrl):
                kwargs["url"] = str(arg)
            elif isinstance(arg, str):
                if arg.startswith(("http://", "https://", "data:image")):
                    kwargs["url"] = arg
                elif Image._path_exists(arg):
                    kwargs["path"] = arg
                else:
                    kwargs["base64"] = arg
            elif isinstance(arg, Path):
                kwargs["path"] = str(arg)
            elif isinstance(arg, np.ndarray):
                kwargs["array"] = arg
            elif isinstance(arg, PILImage):
                kwargs["pil"] = arg
            elif isinstance(arg, Image):
                # Re-wrap an existing Image: reuse its pixels, apply the new encoding/size.
                kwargs["array"] = arg.array
            elif isinstance(arg, tuple) and len(arg) == 2:
                kwargs["size"] = arg
            else:
                raise ValueError(f"Unsupported argument type '{type(arg)}'.")
        elif url is not None:
            kwargs["url"] = url
        elif path is not None:
            kwargs["path"] = path
        elif base64 is not None:
            kwargs["base64"] = base64
        elif array is not None:
            kwargs["array"] = array
        elif pil is not None:
            kwargs["pil"] = pil
        elif bytes_obj is not None:
            kwargs["bytes"] = bytes_obj
        super().__init__(**kwargs)

    def __repr__(self):
        """Return a string representation of the image."""
        if self.base64 is None:
            return f"Image(encoding={self.encoding}, size={self.size})"
        return f"Image(base64={self.base64[:10]}..., encoding={self.encoding}, size={self.size})"

    def __str__(self):
        """Return a string representation of the image (same as ``repr``, safe when base64 is None)."""
        return self.__repr__()

    @staticmethod
    def from_base64(base64_str: str, encoding: str, size=None) -> "Image":
        """Decodes a base64 string to create an Image instance.

        Args:
            base64_str (str): The base64 string to decode; a ``data:...;base64,`` prefix is accepted.
            encoding (str): The format used for encoding the image when converting to base64.
            size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple.

        Returns:
            Image: An instance of the Image class with populated fields.
        """
        image_data = Image._decode_base64(base64_str)
        image = PILModule.open(io.BytesIO(image_data)).convert("RGB")
        # Keywords are required: the 2nd/3rd positional parameters of __init__ are url/path.
        return Image(image, encoding=encoding, size=size)

    @staticmethod
    def open(path: str, encoding: str = "jpeg", size=None) -> "Image":
        """Opens an image from a file path.

        Args:
            path (str): The path to the image file.
            encoding (str): The format used for encoding the image when converting to base64.
            size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple.

        Returns:
            Image: An instance of the Image class with populated fields.
        """
        # The context manager releases the file handle once the pixels have been copied out.
        with PILModule.open(path) as source:
            image = source.convert("RGB")
        return Image(image, encoding=encoding, size=size)

    @staticmethod
    def pil_to_data(image: PILImage, encoding: str, size=None) -> dict:
        """Derives every stored representation from a PIL image.

        The image is resized first (when ``size`` is given) so that the base64 payload and data
        URL describe the same pixels as the returned array and PIL image.

        Args:
            image (PIL.Image.Image): The source PIL image.
            encoding (str): The format used for encoding the image when converting to base64.
            size (Optional[Tuple[int, int]]): The target size as a (width, height) tuple.

        Returns:
            dict: A mapping with the keys ``array``, ``base64``, ``pil``, ``size``, ``url``
            (a ``data:`` URL) and ``encoding``.
        """
        encoding = encoding.lower()
        if encoding == "jpg":
            # PIL only knows the format under the name "JPEG".
            encoding = "jpeg"
        if size is not None:
            size = tuple(size)
            image = image.resize(size)
        else:
            size = image.size
        buffer = io.BytesIO()
        # Encode an RGB copy: formats such as JPEG cannot store alpha/palette modes.
        image.convert("RGB").save(buffer, format=encoding.upper())
        base64_encoded = base64lib.b64encode(buffer.getvalue()).decode("utf-8")
        data_url = f"data:image/{encoding};base64,{base64_encoded}"
        return {
            "array": np.array(image),
            "base64": base64_encoded,
            "pil": image,
            "size": size,
            "url": data_url,
            "encoding": encoding,
        }

    @staticmethod
    def load_url(url: str, download=False) -> PILImage | None:
        """Downloads an image from a URL or decodes it from a base64 data URI.

        Args:
            url (str): The URL of the image to download, or a base64 data URI.
            download (bool): When True, interactively ask for confirmation before downloading.

        Returns:
            PIL.Image.Image | None: The decoded RGB image, or None when the download was
            declined or failed (failures are logged as warnings, not raised).

        Raises:
            ValueError: If a ``data:image`` URI carries no ``;base64,`` payload.
        """
        if url.startswith("data:image"):
            # Extract the base64 part of the data URI
            _, marker, base64_str = url.partition(";base64,")
            if not marker:
                raise ValueError("Data URI must contain a ';base64,' encoded payload.")
            image_data = base64lib.b64decode(base64_str)
            return PILModule.open(io.BytesIO(image_data)).convert("RGB")
        try:
            # Open the URL and read the image data
            import urllib.request

            user_agent = (
                "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.9.0.7) Gecko/2009021910 Firefox/3.0.7"
            )
            headers = {
                "User-Agent": user_agent,
            }
            if download:
                accept = input("Do you want to download the image? (y/n): ")
                if "y" not in accept.lower():
                    return None
            # Restricting the scheme also keeps urlopen away from file:// and similar handlers.
            if not url.startswith(("http://", "https://")):
                raise ValueError("URL must start with 'http' or 'https'.")
            request = urllib.request.Request(url, None, headers)
            with urllib.request.urlopen(request) as response:
                data = response.read()
            return PILModule.open(io.BytesIO(data)).convert("RGB")
        except Exception as e:
            # Deliberately best-effort: callers fall back to a placeholder image on None.
            logging.warning(f"Failed to load image from URL: {url}. {e}")
            logging.warning("Not validating the Image data")
            return None

    @classmethod
    def from_bytes(cls, bytes_data: bytes, encoding: str = "jpeg", size=None) -> "Image":
        """Creates an Image instance from a bytes object.

        Args:
            bytes_data (bytes): The bytes object to convert to an image.
            encoding (str): The format used for encoding the image when converting to base64.
            size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple.

        Returns:
            Image: An instance of the Image class with populated fields.
        """
        image = PILModule.open(io.BytesIO(bytes_data)).convert("RGB")
        # Keywords are required: the 2nd/3rd positional parameters of __init__ are url/path.
        return cls(image, encoding=encoding, size=size)

    @staticmethod
    def bytes_to_data(bytes_data: bytes, encoding: str = "jpeg", size=None) -> dict:
        """Decodes encoded image bytes and derives every stored representation from them.

        Args:
            bytes_data (bytes): The bytes object to convert to an image.
            encoding (str): The format used for encoding the image when converting to base64.
            size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple.

        Returns:
            dict: The same mapping that ``pil_to_data`` returns.
        """
        image = PILModule.open(io.BytesIO(bytes_data)).convert("RGB")
        return Image.pil_to_data(image, encoding, size)

    @model_validator(mode="before")
    @classmethod
    def validate_kwargs(cls, values) -> dict:
        """Normalizes the raw constructor input into the full set of model fields.

        Exactly one image source (or only a size) is expected. The source is converted to a PIL
        image and every other representation is derived from it via ``pil_to_data``.

        Args:
            values (dict): Raw keyword input; may contain a ``bytes`` entry besides the fields.

        Returns:
            dict: Values for ``array``, ``base64``, ``encoding``, ``path``, ``pil``, ``url``, ``size``.

        Raises:
            ValueError: If several sources are given or the encoding is not supported.
        """
        # Ensure that at most one image source is provided
        provided_fields = [
            k for k in values if values[k] is not None and k in ["array", "base64", "path", "pil", "url"]
        ]
        if len(provided_fields) > 1:
            raise ValueError(f"Multiple image sources provided; only one is allowed but got: {provided_fields}")

        # ``or`` also covers an explicit ``encoding=None``.
        encoding = (values.get("encoding") or "jpeg").lower()
        size = values.get("size")
        # Initialize all fields to None or their default values
        validated_values = {
            "array": None,
            "base64": None,
            "encoding": encoding,
            "path": None,
            "pil": None,
            "url": None,
            "size": size,
        }
        # Validate the encoding first
        if encoding not in ["png", "jpeg", "jpg", "bmp", "gif"]:
            raise ValueError("The 'encoding' must be a valid image format (png, jpeg, jpg, bmp, gif).")

        # Raw bytes win over every other source and skip the completeness warning below.
        raw_bytes = values.get("bytes")
        if raw_bytes is not None:
            validated_values.update(cls.bytes_to_data(raw_bytes, encoding, size))
            return validated_values
        if values.get("pil") is not None:
            validated_values.update(cls.pil_to_data(values["pil"], encoding, size))
            return validated_values

        # Process the provided image source
        if "path" in provided_fields:
            with PILModule.open(values["path"]) as source:
                image = source.convert("RGB")
            validated_values["path"] = values["path"]
            validated_values.update(cls.pil_to_data(image, encoding, size))
        elif "array" in provided_fields:
            image = PILModule.fromarray(np.asarray(values["array"])).convert("RGB")
            validated_values.update(cls.pil_to_data(image, encoding, size))
        elif "base64" in provided_fields:
            image_data = cls._decode_base64(values["base64"])
            validated_values.update(cls.bytes_to_data(image_data, encoding, size))
        elif "url" in provided_fields:
            url = str(values["url"])
            encoding = cls._encoding_from_url(url, encoding)
            validated_values["encoding"] = encoding
            image = cls.load_url(url)
            if image is None:
                # Download failed or was declined: fall back to a black 224x224 placeholder.
                # NOTE(review): the URL itself is not kept in this case -- confirm that is intended.
                validated_values["array"] = np.zeros((224, 224, 3), dtype=np.uint8)
                validated_values["size"] = (224, 224)
                return validated_values
            validated_values.update(cls.pil_to_data(image, encoding, size))
            # Keep the caller's URL rather than the generated data URL.
            validated_values["url"] = values["url"]
        elif size is not None:
            # Size only: build a black image. NumPy shapes are (height, width, channels).
            width, height = size
            blank = PILModule.fromarray(np.zeros((height, width, 3), dtype=np.uint8))
            validated_values.update(cls.pil_to_data(blank, encoding, size))

        if any(validated_values[k] is None for k in ["array", "base64", "pil", "url"]):
            logging.warning(
                f"Failed to validate image data. Could only fetch {[k for k in validated_values if validated_values[k] is not None]}",
            )
        return validated_values

    def save(self, path: str, encoding: str | None = None, quality: int = 10) -> None:
        """Save the image to the specified path.

        If the image is a JPEG, the quality parameter can be used to set the quality of the saved image.
        A ``quality`` below 10 forces JPEG output. The path attribute of the image is updated to
        the new file path.

        Args:
            path (str): The path to save the image to.
            encoding (Optional[str]): The encoding to use for saving the image. Defaults to the image's own.
            quality (int): The quality value forwarded to PIL when saving.

        Raises:
            ValueError: If ``encoding`` is "png" and ``quality`` is below 10.
        """
        if encoding == "png" and quality < 10:
            raise ValueError("Quality can only be set for JPEG images.")
        encoding = encoding or self.encoding
        if quality < 10:
            encoding = "jpeg"
        if encoding.lower() == "jpg":
            # "JPG" is not a format name PIL recognizes.
            encoding = "jpeg"
        pil_image = self.pil
        if pil_image is None or encoding != self.encoding:
            # Rebuild an RGB image from the array when no PIL image is cached or the target
            # format differs from the stored one.
            pil_image = PILModule.fromarray(self.array).convert("RGB")
        pil_image.save(path, encoding, quality=quality)
        self.path = path  # Update the path attribute to the new file path

    def show(self) -> None:
        """Draw the image array with matplotlib's ``imshow``.

        On macOS the "TkAgg" backend is selected first. Only ``plt.imshow`` is called here, so
        in non-interactive sessions the caller still has to call ``plt.show()``.
        """
        import platform

        import matplotlib

        if platform.system() == "Darwin":
            matplotlib.use("TkAgg")
        import matplotlib.pyplot as plt

        plt.imshow(self.array)

    def space(self) -> spaces.Box:
        """Returns the space of the image.

        Returns:
            spaces.Box: A uint8 box in [0, 255] shaped (height, width, 3), matching the
            row-major layout of the image array.

        Raises:
            ValueError: If the image size is not defined.
        """
        if self.size is None:
            raise ValueError("Image size is not defined.")
        # ``size`` is (width, height) but arrays are laid out as (height, width, channels).
        width, height = self.size
        return spaces.Box(low=0, high=255, shape=(height, width, 3), dtype=np.uint8)

    @model_serializer(mode="plain", when_used="json")
    def exclude_pil(self) -> dict:
        """Serialize the image for JSON, leaving out the PIL object and the array.

        The ``base64`` entry is omitted when the URL already embeds it (a data URL), so the
        payload is not duplicated.
        """
        payload = {"size": self.size, "url": self.url, "encoding": self.encoding}
        if self.base64 is not None and self.url is not None and self.base64 in str(self.url):
            return payload
        return {"base64": self.base64, **payload}

    def dump(self, *args, as_field: str | None = None, **kwargs) -> dict | Any:
        """Return a dict or a field of the image.

        Args:
            *args: Accepted and ignored.
            as_field (Optional[str]): When given, return only this attribute of the image.
            **kwargs: Accepted and ignored.

        Returns:
            dict | Any: The named attribute, or a dict with ``array``, ``base64``, ``size``,
            ``url`` and ``encoding``.
        """
        if as_field is not None:
            return getattr(self, as_field)
        return {
            "array": self.array,
            "base64": self.base64,
            "size": self.size,
            "url": self.url,
            "encoding": self.encoding,
        }

    def infer_features_dict(self) -> HFImage:
        """Infer features of the image: a Hugging Face ``datasets`` Image feature."""
        return HFImage()