# Source code for mbodied.agents.backends.openai_backend
# Copyright 2024 mbodi ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
import backoff
import httpx
from anthropic import RateLimitError as AnthropicRateLimitError
from openai._exceptions import RateLimitError as OpenAIRateLimitError
from mbodied.agents.backends.serializer import Serializer
from mbodied.types.message import Message
from mbodied.types.sense.vision import Image
# Exception types on which ``OpenAIBackendMixin.predict`` is retried by its
# ``backoff.on_exception`` decorator: provider rate-limit errors (OpenAI and
# Anthropic) plus HTTP/connection-level failures.
ERRORS = (OpenAIRateLimitError, AnthropicRateLimitError, httpx.HTTPError, ConnectionError)
class OpenAISerializer(Serializer):
    """Serializer for OpenAI-specific data formats.

    Produces the content-part dictionaries used in OpenAI chat messages:
    ``{"type": "text", ...}`` for text and ``{"type": "image_url", ...}`` for images.
    """

    @classmethod
    def serialize_image(cls, image: Image) -> dict[str, Any]:
        """Serializes an image to the OpenAI ``image_url`` content part.

        Args:
            image: The image to be serialized. Only its ``url`` attribute is read.

        Returns:
            A dictionary of the form ``{"type": "image_url", "image_url": {"url": image.url}}``.
        """
        return {
            "type": "image_url",
            "image_url": {
                "url": image.url,
            },
        }

    @classmethod
    def serialize_text(cls, text: str) -> dict[str, Any]:
        """Serializes a text string to the OpenAI ``text`` content part.

        Args:
            text: The text to be serialized.

        Returns:
            A dictionary of the form ``{"type": "text", "text": text}``.
        """
        return {"type": "text", "text": text}
 
class OpenAIBackendMixin:
    """Backend for interacting with OpenAI's API.

    Attributes:
        api_key: The API key for the OpenAI service.
        client: The client for the OpenAI service.
        serialized: The serializer class applied to each message before it is sent.
        response_format: The ``response_format`` forwarded to the completion API.
    """

    INITIAL_CONTEXT = [
        Message(role="system", content="You are a robot with advanced spatial reasoning."),
    ]
    DEFAULT_MODEL = "gpt-4o"

    def __init__(
        self,
        api_key: str | None = None,
        client: Any | None = None,
        response_format: dict[str, Any] | str | None = None,
        **kwargs,
    ):
        """Initializes the OpenAIBackend with the given API key and client.

        Args:
            api_key: The API key for the OpenAI service.
            client: An optional pre-built client. When omitted, an ``openai.OpenAI``
                client is constructed from ``api_key`` and ``kwargs``.
            response_format: The format for the response; forwarded unchanged to
                ``client.chat.completions.create`` on every completion.
            **kwargs: Additional keyword arguments passed to the ``OpenAI`` constructor
                when no client is given. ``model_src`` is removed before that call.
        """
        self.api_key = api_key
        self.client = client
        if self.client is None:
            # Imported lazily so a caller supplying its own client does not need openai here.
            from openai import OpenAI

            kwargs.pop("model_src", None)
            self.client = OpenAI(api_key=self.api_key, **kwargs)
        self.serialized = OpenAISerializer
        self.response_format = response_format

    def _create_completion(
        self,
        messages: List[Message],
        model: str | None = None,
        stream: bool = False,
        **kwargs,
    ) -> str:
        """Creates a completion for the given messages using the OpenAI API standard.

        Args:
            messages: A list of messages to be sent to the completion API.
            model: The model to be used for the completion. Defaults to
                ``self.DEFAULT_MODEL`` so subclasses can override it in one place.
            stream: Whether to stream the response. When True, the streamed chunks
                are consumed and their text is concatenated. Defaults to False.
            **kwargs: Additional keyword arguments for ``chat.completions.create``.
                They override the defaults below (``temperature=0``,
                ``max_tokens=1000``, ``response_format=self.response_format``).

        Returns:
            str: The content of the completion response.
        """
        serialized_messages = [self.serialized(msg) for msg in messages]
        # Defaults are merged *under* caller kwargs; passing them as separate keyword
        # arguments would raise "got multiple values for keyword argument" whenever a
        # caller supplied e.g. ``temperature``.
        request_options = {
            "temperature": 0,
            "max_tokens": 1000,
            "response_format": self.response_format,
            **kwargs,
        }
        completion = self.client.chat.completions.create(
            model=model if model is not None else self.DEFAULT_MODEL,
            messages=serialized_messages,
            stream=stream,
            **request_options,
        )
        if stream:
            # A streamed response is an iterable of chunks carrying incremental
            # ``delta`` content rather than a single ``message``.
            return "".join(
                chunk.choices[0].delta.content or ""
                for chunk in completion
                if chunk.choices
            )
        return completion.choices[0].message.content

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_tries=3,
    )
    def predict(self, message: Message, context: List[Message] | None = None, model: Any | None = None, **kwargs) -> str:
        """Create a completion based on the given message and context.

        Retried with exponential backoff (up to 3 tries) on any exception in ``ERRORS``.

        Args:
            message (Message): The message to process; appended after ``context``.
            context (Optional[List[Message]]): The context of messages. Defaults to none.
            model (Optional[Any]): The model used for processing the messages. When
                None, ``_create_completion`` picks its default.
            **kwargs: Additional keyword arguments forwarded to ``_create_completion``.

        Returns:
            str: The result of the completion.
        """
        if context is None:
            context = []
        if model is not None:
            kwargs["model"] = model
        return self._create_completion(context + [message], **kwargs)