# Render utils for PyTorch3D
# Adapted and improved from: https://github.com/ThunderVVV/HaWoR/blob/main/lib/vis/renderer.py
import torch
import numpy as np
from typing import List, Optional, Tuple, Union

from pytorch3d.renderer import (
    PerspectiveCameras,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    RasterizationSettings,
    PointLights,
    TexturesVertex,
)
from pytorch3d.structures import Meshes
from pytorch3d.renderer.camera_conversions import _cameras_from_opencv_projection


def update_intrinsics_from_bbox(
    K_org: torch.Tensor, bbox: torch.Tensor
) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
    """
    Update intrinsic matrix K according to the given bounding box.

    The 3x3 intrinsics are copied into the top-left of a 4x4 matrix. The third
    row is then set to [.., .., 0, 1] and the fourth row to [0, 0, 1, 0], so
    the third and fourth output components swap roles (assuming K_org's last
    row is the usual [0, 0, 1] -- TODO confirm for non-standard K). The
    principal point is shifted into the bbox's frame and then mirrored
    (cx -> width - cx, cy -> height - cy); the mirroring is always applied.

    Args:
        K_org (torch.Tensor): Original intrinsic matrix of shape (B, 3, 3).
        bbox (torch.Tensor): Bounding boxes of shape (B, 4) in
            (left, top, right, bottom) format.

    Returns:
        K_new (torch.Tensor): Updated intrinsics with shape (B, 4, 4).
        image_sizes (List[Tuple[int, int]]): Image size (height, width) for
            each bbox; each side is clamped to be at least 1.
    """
    device, dtype = K_org.device, K_org.dtype

    # Initialize 4x4 intrinsic matrix from the 3x3 one
    K_new = torch.zeros((K_org.shape[0], 4, 4), device=device, dtype=dtype)
    K_new[:, :3, :3] = K_org.clone()
    K_new[:, 2, 2] = 0
    K_new[:, 2, -1] = 1
    K_new[:, -1, 2] = 1

    image_sizes = []
    for idx, box in enumerate(bbox):
        left, top, right, bottom = box
        cx, cy = K_new[idx, 0, 2], K_new[idx, 1, 2]

        # Express the principal point relative to the bbox's top-left corner
        new_cx = cx - left
        new_cy = cy - top

        # Bbox size, clamped so a degenerate box still yields a 1-pixel image
        new_height = max(bottom - top, 1)
        new_width = max(right - left, 1)

        # Mirror the principal point about the bbox center (unconditional)
        K_new[idx, 0, 2] = new_width - new_cx
        K_new[idx, 1, 2] = new_height - new_cy

        image_sizes.append((int(new_height), int(new_width)))

    return K_new, image_sizes


class Renderer():
    """
    Renderer class using PyTorch3D for mesh rendering with Phong shading.

    Attributes:
        width (int): Target image width.
        height (int): Target image height.
        focal_length (Union[float, Tuple[float, float]]): Camera focal length(s).
        device (torch.device): Device to run rendering on.
        renderer (MeshRenderer): PyTorch3D mesh renderer.
        cameras (PerspectiveCameras): Camera object.
        lights (PointLights): Lighting setup for rendering.
    """

    def __init__(
        self,
        width: int,
        height: int,
        focal_length: Union[float, Tuple[float, float]],
        device: torch.device,
        bin_size: int = 512,
        max_faces_per_bin: int = 200000,
    ):
        self.width = width
        self.height = height
        self.focal_length = focal_length
        self.device = device

        # Initialize camera parameters (also sets self.image_sizes, which the
        # renderer below needs, so this must run first)
        self._initialize_camera_params()

        # Set up lighting
        self.lights = PointLights(
            device=device,
            location=((0.0, -1.5, -1.5),),
            ambient_color=((0.75, 0.75, 0.75),),
            diffuse_color=((0.25, 0.25, 0.25),),
            specular_color=((0.02, 0.02, 0.02),),
        )

        # Initialize renderer
        self._create_renderer(bin_size, max_faces_per_bin)

    def _create_renderer(self, bin_size: int, max_faces_per_bin: int):
        """
        Create the PyTorch3D MeshRenderer with rasterizer and shader.

        The raster image size is the single (height, width) entry computed by
        `_initialize_camera_params`.
        """
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                raster_settings=RasterizationSettings(
                    image_size=self.image_sizes[0],
                    blur_radius=1e-5,
                    bin_size=bin_size,
                    max_faces_per_bin=max_faces_per_bin,
                )
            ),
            shader=SoftPhongShader(
                device=self.device,
                lights=self.lights,
            ),
        )

    def _initialize_camera_params(self):
        """
        Initialize camera intrinsics and extrinsics.

        Sets self.R, self.T, self.K (OpenCV-style, principal point at the
        image center), self.bboxes, self.K_full, self.image_sizes and
        self.cameras.
        """
        # Extrinsics (identity rotation and zero translation)
        self.R = torch.eye(3, device=self.device).unsqueeze(0)
        self.T = torch.zeros(1, 3, device=self.device)

        # Intrinsics: accept either a single focal length or an (fx, fy) pair
        if isinstance(self.focal_length, (list, tuple)):
            fx, fy = self.focal_length
        else:
            fx = fy = self.focal_length

        self.K = torch.tensor(
            [[fx, 0, self.width / 2],
             [0, fy, self.height / 2],
             [0, 0, 1]],
            device=self.device,
            dtype=torch.float32,
        ).unsqueeze(0)

        # A full-image bbox; only image_sizes is consumed by this class
        self.bboxes = torch.tensor([[0, 0, self.width, self.height]], dtype=torch.float32)
        self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)

        # Create PyTorch3D cameras
        self.cameras = self._create_camera_from_cv()

    def _create_camera_from_cv(
        self,
        R: Optional[torch.Tensor] = None,
        T: Optional[torch.Tensor] = None,
        K: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
    ) -> PerspectiveCameras:
        """
        Create a PyTorch3D camera from OpenCV-style intrinsics and extrinsics.

        Any argument left as None falls back to the corresponding attribute
        (self.R, self.T, self.K, self.image_sizes).
        """
        if R is None:
            R = self.R
        if T is None:
            T = self.T
        if K is None:
            K = self.K
        if image_size is None:
            image_size = torch.tensor(self.image_sizes, device=self.device)

        cameras = _cameras_from_opencv_projection(R, T, K, image_size)
        return cameras

    def render(
        self,
        verts_list: List[torch.Tensor],
        faces_list: List[torch.Tensor],
        colors_list: List[torch.Tensor],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Render a batch of meshes into an RGB image and mask.

        All meshes are merged into a single mesh (face indices are offset by
        the running vertex count) and rendered with the stored cameras and
        lights.

        Args:
            verts_list (List[torch.Tensor]): List of vertex tensors.
            faces_list (List[torch.Tensor]): List of face tensors.
            colors_list (List[torch.Tensor]): List of per-vertex color tensors.
                The three lists must have the same length; tensors are moved
                to the renderer's device.

        Returns:
            rend (np.ndarray): Rendered RGB image as uint8 array (H, W, 3).
            mask (np.ndarray): Boolean mask of pixels with positive alpha.

        Raises:
            ValueError: If the three input lists differ in length.
        """
        # Materialize so len() works even if iterables were passed
        verts_list = list(verts_list)
        faces_list = list(faces_list)
        colors_list = list(colors_list)
        # zip() would silently drop meshes on a length mismatch
        if not (len(verts_list) == len(faces_list) == len(colors_list)):
            raise ValueError(
                "verts_list, faces_list and colors_list must have the same length, "
                f"got {len(verts_list)}, {len(faces_list)} and {len(colors_list)}"
            )

        all_verts = []
        all_faces = []
        all_colors = []
        vertex_offset = 0

        for verts, faces, colors in zip(verts_list, faces_list, colors_list):
            # Cameras and lights live on self.device; keep the mesh there too
            all_verts.append(verts.to(self.device))
            all_colors.append(colors.to(self.device))
            all_faces.append(faces.to(self.device) + vertex_offset)  # Offset face indices
            vertex_offset += verts.shape[0]

        # Combine all meshes into a single mesh for rendering
        merged_verts = torch.cat(all_verts, dim=0)
        merged_faces = torch.cat(all_faces, dim=0)
        merged_colors = torch.cat(all_colors, dim=0)

        mesh = Meshes(
            verts=[merged_verts],  # batch_size=1
            faces=[merged_faces],
            textures=TexturesVertex(merged_colors.unsqueeze(0)),
        )

        # Render the image; detach so .numpy() also works when the inputs
        # require grad (e.g. vertices straight from a model forward pass)
        images = self.renderer(mesh, cameras=self.cameras, lights=self.lights).detach()

        rend = np.clip(images[0, ..., :3].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        mask = images[0, ..., -1].cpu().numpy() > 0

        return rend, mask