# Source code for mbodied.agents.sense.object_pose_estimator_3d

from typing import Dict, List

import numpy as np
from gradio_client import Client, handle_file
from PIL import Image as PILImage

from mbodied.agents.sense.sensory_agent import SensoryAgent
from mbodied.types.geometry import Pose6D
from mbodied.types.sample import Sample


class ObjectPoseEstimator3D(SensoryAgent):
    """3D object pose estimation agent that delegates processing to a remote Gradio server.

    Attributes:
        server_url (str): URL of the Gradio server.
        client (Client): Gradio client used to call the server.
    """

    def __init__(self, server_url: str = "https://api.mbodi.ai/3d-object-pose-detection") -> None:
        """Initialize the estimator and create a Gradio client for ``server_url``.

        Args:
            server_url (str): The URL of the Gradio server.
        """
        # NOTE(review): SensoryAgent.__init__ is not invoked here; confirm the base
        # class does not require initialization of its own.
        self.server_url = server_url
        self.client = Client(self.server_url)

    @staticmethod
    def save_data(
        color_image_array: np.ndarray,
        depth_image_array: np.ndarray,
        color_image_path: str,
        depth_image_path: str,
        intrinsic_matrix: np.ndarray,
        *,
        intrinsic_matrix_path: str = "resources/intrinsic_matrix.npy",
    ) -> None:
        """Save color/depth images as PNG files and the intrinsic matrix as a ``.npy`` file.

        Args:
            color_image_array (np.ndarray): The color image array, written in PIL ``"RGB"``
                mode (presumably an H x W x 3 ``uint8`` array -- confirm with callers).
            depth_image_array (np.ndarray): The depth image array. It is cast to ``uint16``
                and written as a 16-bit (``"I;16"``) PNG. The cast truncates fractional
                values and does not clip out-of-range ones.
            color_image_path (str): The path to save the color image.
            depth_image_path (str): The path to save the depth image.
            intrinsic_matrix (np.ndarray): The intrinsic matrix to persist.
            intrinsic_matrix_path (str): Where to save ``intrinsic_matrix`` via ``np.save``
                (``.npy`` is appended if missing). Defaults to the previously hard-coded
                ``"resources/intrinsic_matrix.npy"``.

        Example:
            >>> color_image = np.zeros((480, 640, 3), dtype=np.uint8)
            >>> depth_image = np.zeros((480, 640), dtype=np.uint16)
            >>> intrinsic_matrix = np.eye(3)
            >>> ObjectPoseEstimator3D.save_data(
            ...     color_image, depth_image, "color.png", "depth.png", intrinsic_matrix
            ... )  # doctest: +SKIP
        """
        color_image = PILImage.fromarray(color_image_array, mode="RGB")
        depth_image = PILImage.fromarray(depth_image_array.astype("uint16"), mode="I;16")
        color_image.save(color_image_path, format="PNG")
        depth_image.save(depth_image_path, format="PNG")
        # Previously always written to a fixed location regardless of the caller's paths.
        np.save(intrinsic_matrix_path, intrinsic_matrix)

    @staticmethod
    def _intrinsics_to_list(camera_intrinsics):
        """Return ``[fx, fy, cx, cy]`` for a 3x3 intrinsic matrix; otherwise return the input unchanged.

        A 3x3 matrix is read in the standard pinhole layout
        ``[[fx, s, cx], [0, fy, cy], [0, 0, 1]]``.
        """
        if np.shape(camera_intrinsics) == (3, 3):
            k = np.asarray(camera_intrinsics, dtype=float)
            return [float(k[0, 0]), float(k[1, 1]), float(k[0, 2]), float(k[1, 2])]
        return camera_intrinsics

    def act(
        self,
        rgb_image_path: str,
        depth_image_path: str,
        camera_intrinsics: List[float] | np.ndarray,
        distortion_coeffs: List[float] | None = None,
        aruco_pose_world_frame: Pose6D | None = None,
        object_classes: List[str] | None = None,
        confidence_threshold: float | None = None,
        using_realsense: bool = False,
    ) -> Dict:
        """Send an RGB/depth image pair and camera parameters to the server to estimate object poses.

        This method does not capture images itself; it uploads the files at the given paths.

        Args:
            rgb_image_path (str): Path to the RGB image.
            depth_image_path (str): Path to the depth image.
            camera_intrinsics (List[float] | np.ndarray): Either ``[fx, fy, cx, cy]`` or a
                3x3 intrinsic matrix, from which ``fx, fy, cx, cy`` are extracted.
            distortion_coeffs (Optional[List[float]]): Distortion coefficients, sent under
                the headers ``k1, k2, p1, p2, k3``.
            aruco_pose_world_frame (Optional[Pose6D]): Pose of the ArUco marker in the world
                frame, sent as the server's ``aruco_to_base_offset``.
            object_classes (Optional[List[str]]): List of object classes.
            confidence_threshold (Optional[float]): Confidence threshold for object detection.
            using_realsense (bool): Only selects the ``camera_source`` value sent to the
                server (``"realsense"`` if True, else ``"webcam"``).

        Returns:
            Dict: Result returned by ``Client.predict`` on the Gradio server.

        Example:
            >>> estimator = ObjectPoseEstimator3D()  # doctest: +SKIP
            >>> result = estimator.act(
            ...     "resources/color_image.png",
            ...     "resources/depth_image.png",
            ...     [911, 911, 653, 371],
            ...     [0.0, 0.0, 0.0, 0.0, 0.0],
            ...     [0.0, 0.2032, 0.0, -90, 0, -90],
            ...     ["Remote Control", "Basket", "Fork", "Spoon", "Red Marker"],
            ...     0.5,
            ...     False,
            ... )  # doctest: +SKIP
        """
        camera_source = "realsense" if using_realsense else "webcam"
        # A full 3x3 K matrix would not line up with the four fx/fy/cx/cy headers,
        # so reduce it to those four values first.
        intrinsics = self._intrinsics_to_list(camera_intrinsics)
        # NOTE(review): optional arguments left as None are still wrapped with
        # Sample(...).to("list"); confirm the server accepts the resulting payload.
        result = self.client.predict(
            image=handle_file(rgb_image_path),
            depth=handle_file(depth_image_path),
            camera_intrinsics={
                "headers": ["fx", "fy", "cx", "cy"],
                "data": [Sample(intrinsics).to("list")],
                "metadata": None,
            },
            distortion_coeffs={
                "headers": ["k1", "k2", "p1", "p2", "k3"],
                "data": [Sample(distortion_coeffs).to("list")],
                "metadata": None,
            },
            # NOTE(review): headers list Z before X; confirm the value order produced by
            # Sample(aruco_pose_world_frame).to("list") matches what the server expects.
            aruco_to_base_offset={
                "headers": ["Z(m)", "Y(m)", "X(m)", "Roll(degrees)", "Pitch(degrees)", "Yaw(degrees)"],
                "data": [Sample(aruco_pose_world_frame).to("list")],
                "metadata": None,
            },
            object_classes={"headers": ["1"], "data": [Sample(object_classes).to("list")], "metadata": None},
            confidence_threshold=confidence_threshold,
            camera_source=camera_source,
        )
        return result  # noqa: RET504