Source code for mbodied.agents.backends.openai_backend
# Copyright 2024 mbodi ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List

import backoff
import httpx
from anthropic import RateLimitError as AnthropicRateLimitError
from openai._exceptions import RateLimitError as OpenAIRateLimitError

from mbodied.agents.backends.serializer import Serializer
from mbodied.types.message import Message
from mbodied.types.sense.vision import Image

# Exception types handed to ``backoff.on_exception`` on ``predict``:
# provider rate limits plus HTTP/connection failures.
ERRORS = (
    OpenAIRateLimitError,
    AnthropicRateLimitError,
    httpx.HTTPError,
    ConnectionError,
)
class OpenAISerializer(Serializer):
    """Serializer producing OpenAI-specific content parts."""

    @classmethod
    def serialize_image(cls, image: Image) -> dict[str, Any]:
        """Convert an image into an OpenAI ``image_url`` content part.

        Args:
            image: The image to serialize; its ``url`` attribute is embedded
                in the result.

        Returns:
            A dictionary of the form
            ``{"type": "image_url", "image_url": {"url": <image.url>}}``.
        """
        image_ref = {"url": image.url}
        return {"type": "image_url", "image_url": image_ref}

    @classmethod
    def serialize_text(cls, text: str) -> dict[str, Any]:
        """Convert a text string into an OpenAI ``text`` content part.

        Args:
            text: The text to serialize.

        Returns:
            A dictionary of the form ``{"type": "text", "text": <text>}``.
        """
        return dict(type="text", text=text)
class OpenAIBackendMixin:
    """Backend for interacting with OpenAI's API.

    Attributes:
        api_key: The API key for the OpenAI service.
        client: The client for the OpenAI service.
        serialized: The serializer for the OpenAI backend.
        response_format: The format for the response.
    """

    INITIAL_CONTEXT = [
        Message(role="system", content="You are a robot with advanced spatial reasoning."),
    ]
    DEFAULT_MODEL = "gpt-4o"

    def __init__(
        self,
        api_key: str | None = None,
        client: Any | None = None,
        response_format: str | None = None,
        **kwargs,
    ):
        """Initializes the OpenAIBackend with the given API key and client.

        Args:
            api_key: The API key for the OpenAI service.
            client: An optional, already-constructed client. When omitted, an
                ``openai.OpenAI`` client is created from ``api_key`` and ``kwargs``.
            response_format: The format for the response, forwarded to every
                completion request unless overridden per call.
            **kwargs: Additional keyword arguments passed to ``openai.OpenAI``
                when the client is created here.
        """
        self.api_key = api_key
        self.client = client
        if self.client is None:
            # Imported lazily so the OpenAI SDK is only needed when no client is injected.
            from openai import OpenAI

            # ``model_src`` is discarded before building the client; presumably it is
            # a selector used by callers and not an ``OpenAI`` option -- TODO confirm.
            kwargs.pop("model_src", None)
            self.client = OpenAI(api_key=self.api_key, **kwargs)

        self.serialized = OpenAISerializer
        self.response_format = response_format

    def _create_completion(
        self,
        messages: List[Message],
        model: str = DEFAULT_MODEL,
        stream: bool = False,
        **kwargs,
    ) -> str:
        """Creates a completion for the given messages using the OpenAI API standard.

        Args:
            messages: A list of messages to be sent to the completion API.
            model: The model to be used for the completion.
            stream: Whether to request a streamed response. The streamed chunks
                are consumed here and joined, so the return type is still ``str``.
            **kwargs: Additional keyword arguments for the completion call. These
                take precedence over the defaults ``temperature=0``,
                ``max_tokens=1000`` and ``response_format=self.response_format``.

        Returns:
            str: The content of the completion response.
        """
        serialized_messages = [self.serialized(msg) for msg in messages]

        # Defaults first, caller overrides second: passing e.g. ``temperature`` in
        # ``kwargs`` must replace the default instead of raising a duplicate-keyword
        # ``TypeError``.
        request_options = {
            "temperature": 0,
            "max_tokens": 1000,
            "response_format": self.response_format,
            **kwargs,
        }
        completion = self.client.chat.completions.create(
            model=model,
            messages=serialized_messages,
            stream=stream,
            **request_options,
        )

        if stream:
            # A streamed response is an iterator of chunks carrying incremental
            # ``delta.content`` (which may be ``None``); some chunks have no choices.
            return "".join(
                chunk.choices[0].delta.content or ""
                for chunk in completion
                if chunk.choices
            )
        return completion.choices[0].message.content

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_tries=3,
    )
    def predict(
        self,
        message: Message,
        context: List[Message] | None = None,
        model: Any | None = None,
        **kwargs,
    ) -> str:
        """Create a completion based on the given message and context.

        The call is retried with exponential backoff (up to 3 tries) when one of
        the exception types in ``ERRORS`` is raised.

        Args:
            message (Message): The message to process.
            context (Optional[List[Message]]): Messages sent before ``message``.
                Defaults to an empty list.
            model (Optional[Any]): The model used for processing the messages.
                When ``None``, the default of ``_create_completion`` applies.
            **kwargs: Additional keyword arguments forwarded to
                ``_create_completion``.

        Returns:
            str: The result of the completion.
        """
        if context is None:
            context = []
        if model is not None:
            kwargs["model"] = model
        # Build a new list so the caller's ``context`` is not mutated.
        return self._create_completion(context + [message], **kwargs)