Source code for mbodied.agents.backends.serializer

# Copyright 2024 mbodi ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from pydantic import ConfigDict, model_serializer, model_validator

from mbodied.types.message import Message
from mbodied.types.sample import Sample
from mbodied.types.sense.vision import Image


[docs] class Serializer(Sample): """A class to serialize messages and samples. This class provides a mechanism to serialize messages and samples into a dictionary format used by i.e. OpenAI, Anthropic, or other APIs. Attributes: wrapped: The message or sample to be serialized. model_config: The Pydantic configuration for the Serializer model. """ wrapped: Any | None = None model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) def __init__( self, wrapped: Message | Sample | list[Message] | None = None, *, message: Message | None = None, sample: Sample | None = None, **data, ): """Initializes the Serializer with various possible wrapped types. Args: wrapped: An instance of Message, Sample, a list of Messages, or None. message: An optional Message to be wrapped. sample: An optional Sample to be wrapped. **data: Additional data to initialize the Sample base class. """ if wrapped is not None: data["wrapped"] = wrapped elif message is not None: data["wrapped"] = message elif sample is not None: data["wrapped"] = sample super().__init__(**data)
[docs] @model_validator(mode="before") @classmethod def validate_model(cls, values: dict[str, Any]) -> dict[str, Any] | list[Any]: """Validates the 'wrapped' field of the model. Args: values: A dictionary of field values to validate. Returns: The validated values dictionary. Raises: ValueError: If the 'wrapped' field contains an invalid type. """ if ( "wrapped" in values and values["wrapped"] is not None and not isinstance( values["wrapped"], Message | Sample | list | str | Image, ) ): raise ValueError( f"Invalid wrapped type {type(values['wrapped'])}", ) return values
[docs] def serialize_sample(self, sample: Any) -> dict[str, Any]: """Serializes a given sample. Args: sample: The sample to be serialized. Returns: A dictionary representing the serialized sample. """ if isinstance(sample, Message): return self.serialize_msg(sample) if not isinstance(sample, Sample): sample = Sample(sample) if isinstance(sample, Image): return self.serialize_image(sample) if Image.supports(sample): return self.serialize_image(Image(sample)) if hasattr(sample, "datum") and isinstance(sample.datum, str): return self.serialize_text(sample.datum) return self.serialize_text(str(sample))
[docs] @model_serializer(when_used="always") def serialize(self) -> dict[str, Any] | list[Any]: """Serializes the wrapped content of the Serializer instance. Returns: A dictionary representing the serialized wrapped content. """ if isinstance(self.wrapped, Message): return self.serialize_msg(self.wrapped) if isinstance(self.wrapped, list): if all(isinstance(m, Message) for m in self.wrapped): return [self.serialize_msg(m) for m in self.wrapped] return [self.serialize_sample(m) for m in self.wrapped] return self.serialize_sample(self.wrapped)
[docs] def serialize_msg(self, message: Message) -> dict[str, Any]: """Serializes a Message instance. Args: message: The Message to be serialized. Returns: A dictionary representing the serialized Message. """ return { "role": message.role, "content": [self.serialize_sample(c) for c in message.content], }
[docs] @classmethod def serialize_image(cls, image: Image) -> dict[str, Any]: """Serializes an Image instance. Args: image: The Image to be serialized. Returns: A dictionary representing the serialized Image. """ return { "type": "image", "image_url": f"data:image/{image.encoding};base64," + image.base64, }
[docs] @classmethod def serialize_text(cls, text: str) -> dict[str, Any]: """Serializes a text string. Args: text: The text to be serialized. Returns: A dictionary representing the serialized text. """ return {"type": "text", "text": text}
def __call__(self) -> dict[str, Any] | list[Any]: """Calls the serialize method. Returns: A dictionary representing the serialized wrapped content. """ return self.model_dump()