# Source code for mbodied.agents.backends.openai_backend
# Copyright 2024 mbodi ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
import backoff
import httpx
from anthropic import RateLimitError as AnthropicRateLimitError
from openai._exceptions import RateLimitError as OpenAIRateLimitError
from mbodied.agents.backends.serializer import Serializer
from mbodied.types.message import Message
from mbodied.types.sense.vision import Image
# Transient failures worth retrying: provider rate limits (OpenAI and
# Anthropic) plus HTTP/connection errors. Used by the backoff.on_exception
# decorator on OpenAIBackendMixin.predict below.
ERRORS = (
OpenAIRateLimitError,
AnthropicRateLimitError,
httpx.HTTPError,
ConnectionError,
)
class OpenAISerializer(Serializer):
    """Serializer for OpenAI-specific data formats."""

    @classmethod
    def serialize_image(cls, image: Image) -> dict[str, Any]:
        """Serialize an image into an OpenAI chat-content part.

        Args:
            image: The image to be serialized.

        Returns:
            A dictionary representing the serialized image.
        """
        # OpenAI expects an "image_url" part wrapping the image's URL.
        payload = {"url": image.url}
        return {"type": "image_url", "image_url": payload}

    @classmethod
    def serialize_text(cls, text: str) -> dict[str, Any]:
        """Serialize a text string into an OpenAI chat-content part.

        Args:
            text: The text to be serialized.

        Returns:
            A dictionary representing the serialized text.
        """
        return {
            "type": "text",
            "text": text,
        }
class OpenAIBackendMixin:
    """Backend for interacting with OpenAI's API.

    Attributes:
        api_key: The API key for the OpenAI service.
        client: The client for the OpenAI service.
        serialized: The serializer for the OpenAI backend.
        response_format: The format for the response.
    """

    INITIAL_CONTEXT = [
        Message(role="system", content="You are a robot with advanced spatial reasoning."),
    ]
    DEFAULT_MODEL = "gpt-4o"

    def __init__(
        self,
        api_key: str | None = None,
        client: Any | None = None,
        response_format: str | None = None,
        **kwargs,
    ):
        """Initializes the OpenAIBackend with the given API key and client.

        Args:
            api_key: The API key for the OpenAI service.
            client: An optional pre-built client; when omitted, an ``openai.OpenAI``
                client is constructed from ``api_key`` and ``**kwargs``.
            response_format: The format for the response, forwarded to the
                completions API on every call.
            **kwargs: Additional keyword arguments passed to the OpenAI client
                constructor when one is created here.
        """
        self.api_key = api_key
        self.client = client
        if self.client is None:
            # Lazy import: the openai package is only needed when no client
            # is injected by the caller.
            from openai import OpenAI

            # "model_src" is a project-level kwarg, not an OpenAI client
            # option; drop it before forwarding.
            kwargs.pop("model_src", None)
            self.client = OpenAI(api_key=self.api_key, **kwargs)
        self.serialized = OpenAISerializer
        self.response_format = response_format

    def _create_completion(
        self,
        messages: List[Message],
        model: str = DEFAULT_MODEL,
        stream: bool = False,
        **kwargs,
    ) -> str:
        """Creates a completion for the given messages using the OpenAI API standard.

        Args:
            messages: A list of messages to be sent to the completion API.
            model: The model to be used for the completion. Defaults to
                ``DEFAULT_MODEL`` rather than a duplicated literal.
            stream: Whether to stream the response. Defaults to False.
            **kwargs: Additional keyword arguments forwarded to
                ``chat.completions.create``.

        Returns:
            str: The content of the completion response.
        """
        serialized_messages = [self.serialized(msg) for msg in messages]
        completion = self.client.chat.completions.create(
            model=model,
            messages=serialized_messages,
            temperature=0,
            max_tokens=1000,
            stream=stream,
            response_format=self.response_format,
            **kwargs,
        )
        # Single (non-streamed) choice: return its message text.
        return completion.choices[0].message.content

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_tries=3,
    )
    def predict(
        self,
        message: Message,
        context: List[Message] | None = None,
        model: Any | None = None,
        **kwargs,
    ) -> str:
        """Create a completion based on the given message and context.

        Retries up to 3 times with exponential backoff on the transient
        errors listed in ``ERRORS``.

        Args:
            message (Message): The message to process.
            context (Optional[List[Message]]): The context of messages.
            model (Optional[Any]): The model used for processing the messages.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The result of the completion.
        """
        if context is None:
            context = []
        if model is not None:
            kwargs["model"] = model
        return self._create_completion(context + [message], **kwargs)