# Source code for physiomotion4d.contour_tools

"""
Tools for creating and manipulating contours.
"""

from __future__ import annotations

import logging
from typing import cast

import itk
import numpy as np
import pyvista as pv
import trimesh

from physiomotion4d.image_tools import ImageTools
from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase
from physiomotion4d.transform_tools import TransformTools


class ContourTools(PhysioMotion4DBase):
    """
    Tools for creating and manipulating contours.

    Provides utilities to extract surface contours from label masks, transform
    and merge contour meshes, and rasterize meshes back into ITK images
    (binary masks, signed distance maps, and dense deformation fields).
    """
[docs] def __init__(self, log_level: int | str = logging.INFO): """Initialize ContourTools. Args: log_level: Logging level (default: logging.INFO) """ super().__init__(class_name=self.__class__.__name__, log_level=log_level)
[docs] def extract_contours( self, mask_image: itk.image, ) -> pv.PolyData: """ Make contours from a mask image. Args: mask_image (itk.image): The mask image to create contours from output_file (str, optional): If provided, save the contours to this VTP file Returns: pv.PolyData: The contours as a PyVista PolyData object """ labels = pv.wrap(itk.vtk_image_from_image(mask_image)) contours = cast( pv.PolyData, labels.contour_labels( boundary_style="all", pad_background=False, smoothing=True, smoothing_iterations=10, output_mesh_type="triangles", ), ) contours.smooth_taubin( inplace=True, n_iter=50, pass_band=0.05, ) # self.contours.decimate_pro( # inplace=True, # reduction=0.7, # feature_angle=45, # preserve_topology=True, # ) return contours
[docs] def transform_contours( self, contours: pv.PolyData, tfm: itk.Transform, with_deformation_magnitude: bool = False, ) -> pv.PolyData: """ Transform contours using a given transform. Args: tfm (itk.Transform): The transform to use Returns: pv.PolyData: The transformed contours with deformation magnitude """ new_contours = TransformTools().transform_pvcontour( contours, tfm, with_deformation_magnitude=with_deformation_magnitude ) return new_contours
[docs] def merge_meshes( self, meshes: list[pv.PolyData] ) -> tuple[pv.PolyData, list[pv.PolyData]]: """ Merge multiple fixed meshes into a single mesh. Returns ------- pv.PolyData Merged mesh """ self.log_info("Merging meshes...") trimesh_meshes: list[trimesh.Trimesh] = [] if hasattr(meshes[0], "n_faces_strict"): trimesh_meshes = [ trimesh.Trimesh( vertices=mesh.points, faces=mesh.faces.reshape((mesh.n_faces_strict, 4))[:, 1:], ) for mesh in meshes ] else: trimesh_meshes = [ trimesh.Trimesh( vertices=mesh.points, faces=mesh.faces.reshape(-1, 4)[:, 1:4] ) for mesh in meshes ] # Merge meshes merged_trimesh = trimesh.util.concatenate(trimesh_meshes) flip_matrix = np.array( [[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] ) merged_trimesh.apply_transform(flip_matrix) # Apply flip transformation for mesh in trimesh_meshes: mesh.apply_transform(flip_matrix) merged_mesh = pv.wrap(merged_trimesh) pv_meshes = [pv.wrap(mesh) for mesh in trimesh_meshes] return merged_mesh, pv_meshes
[docs] def create_reference_image( self, mesh: pv.DataSet, spatial_resolution: float = 0.5, buffer_factor: float = 0.25, ptype: type = itk.F, ) -> itk.Image: """ Create a reference image from a mesh. """ points = np.array(mesh.points) min_bounds = points.min(axis=0) max_bounds = points.max(axis=0) min_bounds = min_bounds - buffer_factor * (max_bounds - min_bounds) max_bounds = max_bounds + buffer_factor * (max_bounds - min_bounds) region = ( ((max_bounds - min_bounds) / spatial_resolution + 1) .astype(np.int32) .tolist() ) itk_region = itk.ImageRegion[3]() itk_region.SetSize(region) reference_image = itk.Image[ptype, 3].New() reference_image.SetRegions(itk_region) reference_image.SetSpacing([spatial_resolution] * 3) reference_image.SetOrigin(min_bounds.tolist()) reference_image.Allocate() return reference_image
    def create_mask_from_mesh(
        self,
        mesh: pv.DataSet | pv.UnstructuredGrid,
        reference_image: itk.Image,
    ) -> itk.Image:
        """
        Rasterize a closed surface mesh into a binary mask image.

        The mesh is voxelized with trimesh, converted to an ITK image with
        hole filling, and resampled onto the geometry (size/spacing/origin/
        direction) of ``reference_image``.

        Args:
            mesh: Surface mesh (an UnstructuredGrid is first converted to its
                surface). Assumes triangle faces — TODO confirm for callers
                passing non-PolyData input.
            reference_image: Image defining the output grid geometry.

        Returns:
            itk.Image: Binary (0/1, uint8) mask aligned with reference_image.
        """
        ref_spacing = np.array(reference_image.GetSpacing())
        # Create trimesh object with LPS coordinates
        if isinstance(mesh, pv.UnstructuredGrid):
            mesh = mesh.extract_surface(algorithm="dataset_surface")
        if hasattr(mesh, "n_faces_strict"):
            # PyVista PolyData: faces is a flat [count, i0, i1, ...] array;
            # derive the per-face stride, then drop the leading count column.
            num_points_per_face = len(mesh.faces) // mesh.n_faces_strict
            faces = mesh.faces.reshape((mesh.n_faces_strict, num_points_per_face))[
                :, 1:
            ]
        else:
            # Handle other mesh types (assumes triangles: 4 entries per face)
            faces = mesh.faces.reshape((-1, 4))[:, 1:]
        trimesh_mesh = trimesh.Trimesh(vertices=mesh.points, faces=faces)
        # Determine voxel spacing (use minimum spacing from reference)
        voxel_pitch = float(np.min(ref_spacing))
        # Voxelize the mesh
        # trimesh.voxelized() creates a grid aligned with the mesh's bounding box
        # The voxel grid origin is at the minimum corner of the bounding box
        vox = trimesh_mesh.voxelized(pitch=voxel_pitch)
        binary_array = vox.matrix.astype(np.uint8)
        # Get the physical origin of the voxel grid in LPS space
        # trimesh voxel grids use a transformation matrix, and the voxel grid starts
        # at the mesh's minimum bounds. The physical origin is where voxel [0,0,0]
        # center is located.
        # Get mesh bounds in LPS coordinates
        mesh_bounds_lps = (
            trimesh_mesh.bounds
        )  # shape (2, 3): [[x_min, y_min, z_min], [x_max, y_max, z_max]]
        # The voxel grid origin is at the minimum corner, but ITK origin is the CENTER
        # of voxel (0,0,0)
        # So we need to add half a voxel pitch to each dimension
        voxel_grid_origin_lps = mesh_bounds_lps[0] + voxel_pitch / 2.0
        # Shift the z origin to the far end of the grid because the z axis is
        # flipped below (array flip + negated direction component).
        voxel_grid_origin_lps[2] = (
            voxel_grid_origin_lps[2] + voxel_pitch * binary_array.shape[2]
        )
        # Create ITK image from the voxel array
        # ITK uses ZYX ordering (numpy array convention), trimesh uses XYZ
        # Need to transpose: (X, Y, Z) -> (Z, Y, X)
        binary_array_zyx = np.transpose(binary_array, (2, 1, 0))
        binary_array_flip = np.flip(binary_array_zyx, axis=0)
        binary_image = itk.GetImageFromArray(binary_array_flip)
        # Set ITK image metadata in LPS coordinates
        # Origin: where the center of voxel (0,0,0) is located in physical space
        binary_image.SetOrigin(voxel_grid_origin_lps)
        # Spacing: uniform voxel pitch in all directions
        binary_image.SetSpacing([voxel_pitch] * 3)
        # Direction: use identity for now (axis-aligned), will be handled by resampling
        # Flip Z axis to match ITK convention
        ref_dir = np.array(binary_image.GetDirection())
        ref_dir[2, 2] = -ref_dir[2, 2]
        binary_image.SetDirection(ref_dir)
        # Fill holes to create solid mask
        ImageType = type(binary_image)
        fill_filter = itk.BinaryFillholeImageFilter[ImageType].New()
        fill_filter.SetInput(binary_image)
        fill_filter.SetForegroundValue(1)
        fill_filter.Update()
        mask_image = fill_filter.GetOutput()
        # Resample onto the reference grid with nearest-neighbor interpolation
        # so the mask stays binary; voxels outside the voxelized region get 0.
        resampler = itk.ResampleImageFilter.New(Input=mask_image)
        resampler.SetReferenceImage(reference_image)
        resampler.SetUseReferenceImage(True)
        resampler.SetInterpolator(
            itk.NearestNeighborInterpolateImageFunction.New(mask_image)
        )
        resampler.SetDefaultPixelValue(0)
        resampler.Update()
        mask_image = resampler.GetOutput()
        return mask_image
    def create_distance_map(
        self,
        mesh: pv.DataSet | pv.UnstructuredGrid,
        reference_image: itk.Image,
        squared_distance: bool = False,
        negative_inside: bool = True,
        zero_inside: bool = False,
        norm_to_max_distance: float = 0.0,
    ) -> itk.Image:
        """
        Compute a signed distance map to a mesh's vertices on a reference grid.

        The mesh vertices are rasterized (nearest voxel) into a binary image on
        the reference grid, then a signed Maurer distance transform is applied.
        NOTE: distances are measured to the rasterized VERTICES, not to the
        continuous surface.

        Args:
            mesh: Mesh whose points seed the distance transform.
            reference_image: Defines the grid geometry of the output.
            squared_distance: If True, square the distances (sign preserved).
            negative_inside: If False, take absolute values (unsigned map).
            zero_inside: If True, clip negative (inside) distances to zero.
            norm_to_max_distance: If nonzero, divide by this value and clip the
                result to [-1, 1].

        Returns:
            itk.Image: Float32 distance image with reference_image's geometry.
        """
        self.log_info("Computing signed distance map...")
        # Convert mask to binary
        points = mesh.points
        size = reference_image.GetLargestPossibleRegion().GetSize()
        # numpy arrays use (z, y, x) ordering; ITK sizes are (x, y, z).
        size = (size[2], size[1], size[0])
        tmp_arr = np.zeros(size, dtype=np.int32)
        itk_point = itk.Point[itk.D, 3]()
        point_count = 0
        # Rasterize each mesh vertex into its containing voxel; vertices that
        # fall outside the reference grid are skipped.
        for point in points:
            itk_point[0] = float(point[0])
            itk_point[1] = float(point[1])
            itk_point[2] = float(point[2])
            indx = reference_image.TransformPhysicalPointToIndex(itk_point)
            if (
                indx[0] < 0
                or indx[1] < 0
                or indx[2] < 0
                or indx[0] >= size[0]
                or indx[1] >= size[1]
                or indx[2] >= size[2]
            ):
                continue
            tmp_arr[indx[2], indx[1], indx[0]] = 1
            point_count += 1
        tmp_binary_image = itk.GetImageFromArray(tmp_arr.astype(np.uint8))
        tmp_binary_image.CopyInformation(reference_image)
        assert (
            tmp_binary_image.GetLargestPossibleRegion().GetSize()
            == reference_image.GetLargestPossibleRegion().GetSize()
        )
        # Always compute the plain signed distance here; squaring (if
        # requested) is applied below in numpy so the sign can be preserved.
        distance_filter = itk.SignedMaurerDistanceMapImageFilter.New(
            Input=tmp_binary_image
        )
        distance_filter.SetSquaredDistance(False)
        distance_filter.SetUseImageSpacing(True)
        distance_filter.Update()
        distance_image = distance_filter.GetOutput()
        distance_arr = itk.GetArrayFromImage(distance_image).astype(np.float32)
        # Post-processing order: clip inside -> drop sign -> square -> normalize.
        if zero_inside:
            distance_arr = np.clip(distance_arr, 0.0, None)
        if not negative_inside:
            distance_arr = np.abs(distance_arr)
        if squared_distance:
            # Sign-preserving square.
            distance_arr = np.sign(distance_arr) * distance_arr**2
        if norm_to_max_distance != 0.0:
            distance_arr = distance_arr / norm_to_max_distance
            distance_arr = np.clip(distance_arr, -1.0, 1.0)
        distance_image = itk.GetImageFromArray(distance_arr)
        distance_image.CopyInformation(reference_image)
        return distance_image
    def create_deformation_field(
        self,
        points: np.ndarray,
        point_displacements: np.ndarray,
        reference_image: itk.Image,
        blur_sigma: float = 2.5,
        ptype: type = itk.D,
    ) -> itk.Image:
        """
        Create a displacement map from model points and displacements.

        Sparse per-point displacements are splatted onto the reference grid,
        Gaussian-blurred, and normalized by a blurred occupancy map so the
        field interpolates smoothly between sample points.

        Args:
            points: (N, 3) physical-space point coordinates.
            point_displacements: (N, 3) displacement vector per point.
            reference_image: Defines the output grid geometry.
            blur_sigma: Gaussian sigma used to diffuse the sparse samples.
            ptype: ITK component type of the output vector image.

        Returns:
            itk.Image: Dense 3-component deformation field on the reference grid.
        """
        size = reference_image.GetLargestPossibleRegion().GetSize()
        # numpy arrays are (z, y, x); ITK sizes are (x, y, z).
        norm_map = np.zeros((size[2], size[1], size[0])).astype(np.float32)
        displacement_map_x = np.zeros((size[2], size[1], size[0])).astype(np.float32)
        displacement_map_y = np.zeros((size[2], size[1], size[0])).astype(np.float32)
        displacement_map_z = np.zeros((size[2], size[1], size[0])).astype(np.float32)
        itk_point = itk.Point[itk.D, 3]()
        # Splat each displacement into its containing voxel. Points outside the
        # grid are skipped; if several points land in the same voxel the last
        # one wins (plain assignment, not accumulation).
        for i, point in enumerate(points):
            itk_point[0] = float(point[0])
            itk_point[1] = float(point[1])
            itk_point[2] = float(point[2])
            indx = reference_image.TransformPhysicalPointToIndex(itk_point)
            if (
                indx[0] < 0
                or indx[1] < 0
                or indx[2] < 0
                or indx[0] >= size[0]
                or indx[1] >= size[1]
                or indx[2] >= size[2]
            ):
                continue
            displacement_map_x[int(indx[2]), int(indx[1]), int(indx[0])] = (
                point_displacements[i, 0]
            )
            displacement_map_y[int(indx[2]), int(indx[1]), int(indx[0])] = (
                point_displacements[i, 1]
            )
            displacement_map_z[int(indx[2]), int(indx[1]), int(indx[0])] = (
                point_displacements[i, 2]
            )
            # Occupancy map: marks voxels that received a sample.
            norm_map[int(indx[2]), int(indx[1]), int(indx[0])] = 1
        norm_img = itk.GetImageFromArray(norm_map)
        norm_img.CopyInformation(reference_image)
        assert (
            norm_img.GetLargestPossibleRegion().GetSize()
            == reference_image.GetLargestPossibleRegion().GetSize()
        )
        # Blur the occupancy map; the floor avoids division by ~0 far from any
        # sample point.
        blurred_norm = itk.SmoothingRecursiveGaussianImageFilter(
            Input=norm_img, Sigma=blur_sigma
        )
        blurred_norm_arr = itk.GetArrayFromImage(blurred_norm)
        blurred_norm_arr = np.where(blurred_norm_arr < 1.0e-4, 1.0e-4, blurred_norm_arr)
        # Blur each displacement component with the same kernel.
        deformation_field_x_img = itk.GetImageFromArray(displacement_map_x)
        deformation_field_x_img.CopyInformation(reference_image)
        deformation_field_x_img = itk.SmoothingRecursiveGaussianImageFilter(
            Input=deformation_field_x_img, Sigma=blur_sigma
        )
        deformation_field_y_img = itk.GetImageFromArray(displacement_map_y)
        deformation_field_y_img.CopyInformation(reference_image)
        deformation_field_y_img = itk.SmoothingRecursiveGaussianImageFilter(
            Input=deformation_field_y_img, Sigma=blur_sigma
        )
        deformation_field_z_img = itk.GetImageFromArray(displacement_map_z)
        deformation_field_z_img.CopyInformation(reference_image)
        deformation_field_z_img = itk.SmoothingRecursiveGaussianImageFilter(
            Input=deformation_field_z_img, Sigma=blur_sigma
        )
        # Dividing blurred displacement by blurred occupancy gives a
        # normalized (weighted-average) interpolation of the sparse samples.
        deformation_field_x = (
            itk.GetArrayFromImage(deformation_field_x_img) / blurred_norm_arr
        )
        deformation_field_y = (
            itk.GetArrayFromImage(deformation_field_y_img) / blurred_norm_arr
        )
        deformation_field_z = (
            itk.GetArrayFromImage(deformation_field_z_img) / blurred_norm_arr
        )
        # Zero the field where the blurred occupancy is negligible (far from
        # any sample point).
        deformation_field_x = np.where(
            blurred_norm_arr > 1.0e-3, deformation_field_x, 0.0
        )
        deformation_field_y = np.where(
            blurred_norm_arr > 1.0e-3, deformation_field_y, 0.0
        )
        deformation_field_z = np.where(
            blurred_norm_arr > 1.0e-3, deformation_field_z, 0.0
        )
        deformation_field = np.stack(
            [deformation_field_x, deformation_field_y, deformation_field_z], axis=-1
        )
        image_tools = ImageTools()
        deformation_field_img = image_tools.convert_array_to_image_of_vectors(
            deformation_field, reference_image, ptype=ptype
        )
        return deformation_field_img