# Source code for mbodied.agents.backends.gradio_backend

# Copyright 2024 mbodi ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from gradio_client import Client
from gradio_client.client import Job

from mbodied.agents.backends.backend import Backend


class GradioBackend(Backend):
    """Gradio backend that handles connections to gradio servers.

    Thin wrapper around ``gradio_client.Client``: the client is created once
    in ``__init__`` and every call is forwarded to it unchanged.

    Attributes:
        model_src: The value that was passed to the client as ``src``.
        client: The underlying ``gradio_client.Client`` instance.
    """

    def __init__(
        self,
        model_src: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Create the gradio client for the given source.

        Args:
            model_src: Forwarded to ``gradio_client.Client`` as ``src``
                (per the gradio_client documentation, a Hugging Face Space
                name or the full URL of a running Gradio app).
            **kwargs: Additional keyword arguments forwarded verbatim to
                ``gradio_client.Client``.
        """
        # NOTE(review): Backend.__init__ is not invoked here -- confirm that
        # skipping the base-class initialisation is intentional.
        self.model_src = model_src
        self.client = Client(src=model_src, **kwargs)

    def predict(self, *args, **kwargs) -> str:
        """Forward queries to the gradio api endpoint `predict`.

        Args:
            *args: The arguments to pass to the gradio server.
            **kwargs: The keyword arguments to pass to the gradio server.

        Returns:
            Whatever ``Client.predict`` returns for the call.
        """
        return self.client.predict(*args, **kwargs)

    def submit(self, *args, api_name="/predict", result_callbacks=None, **kwargs) -> Job:
        """Submit queries asynchronously without need of asyncio.

        Args:
            *args: The arguments to pass to the gradio server.
            api_name: The name of the api endpoint to submit the job.
            result_callbacks: The callbacks to apply to the result.
            **kwargs: The keyword arguments to pass to the gradio server.

        Returns:
            Job: Gradio job object.
        """
        # Positional arguments are unpacked first so the call reads in the
        # order Python actually binds them (star-args after keywords is
        # legal but misleading).
        return self.client.submit(
            *args,
            api_name=api_name,
            result_callbacks=result_callbacks,
            **kwargs,
        )