samgeo2 module¶
SamGeo2
¶
The main class for segmenting geospatial data with the Segment Anything Model 2 (SAM2). See https://github.com/facebookresearch/segment-anything-2 for details.
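A minimal usage sketch, assuming SamGeo2 is imported from the samgeo package; the input GeoTIFF and output filenames below are placeholders:

```python
# Hedged sketch: automatic mask generation on a GeoTIFF ("satellite.tif" is a placeholder path).
from samgeo import SamGeo2

sam = SamGeo2(model_id="sam2-hiera-large", automatic=True)
sam.generate(source="satellite.tif", output="masks.tif")  # segment the image and save the mask raster
sam.show_anns(alpha=0.5, output="annotations.tif")        # overlay objects in random colors
```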
Source code in samgeo/samgeo2.py
class SamGeo2:
"""The main class for segmenting geospatial data with the Segment Anything Model 2 (SAM2). See
https://github.com/facebookresearch/segment-anything-2 for details.
"""
def __init__(
self,
model_id: str = "sam2-hiera-large",
device: Optional[str] = None,
empty_cache: bool = True,
automatic: bool = True,
video: bool = False,
mode: str = "eval",
hydra_overrides_extra: Optional[List[str]] = None,
apply_postprocessing: bool = False,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.8,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
mask_threshold: float = 0.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
max_hole_area: float = 0.0,
max_sprinkle_area: float = 0.0,
**kwargs: Any,
) -> None:
"""
Initializes the SamGeo2 class.
Args:
model_id (str): The model ID to use. Can be one of the following: "sam2-hiera-tiny",
"sam2-hiera-small", "sam2-hiera-base-plus", "sam2-hiera-large".
Defaults to "sam2-hiera-large".
device (Optional[str]): The device to use (e.g., "cpu", "cuda", "mps"). Defaults to None.
empty_cache (bool): Whether to empty the cache. Defaults to True.
automatic (bool): Whether to use automatic mask generation. Defaults to True.
video (bool): Whether to use video prediction. Defaults to False.
mode (str): The mode to use. Defaults to "eval".
hydra_overrides_extra (Optional[List[str]]): Additional Hydra overrides. Defaults to None.
apply_postprocessing (bool): Whether to apply postprocessing. Defaults to False.
points_per_side (int or None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_per_batch (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
pred_iou_thresh (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculating the stability score.
mask_threshold (float): Threshold for binarizing the mask logits.
box_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks.
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray) or None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
min_mask_region_area (int): If >0, postprocessing will be applied
to remove disconnected regions and holes in masks with area smaller
than min_mask_region_area. Requires opencv.
output_mode (str): The form masks are returned in. Can be 'binary_mask',
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
multimask_output (bool): Whether to output multimask at each point of the grid.
max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
**kwargs (Any): Additional keyword arguments to pass to
SAM2AutomaticMaskGenerator.from_pretrained() or SAM2ImagePredictor.from_pretrained().
"""
if isinstance(model_id, str):
if not model_id.startswith("facebook/"):
model_id = f"facebook/{model_id}"
else:
raise ValueError("model_id must be a string")
allowed_models = [
"facebook/sam2-hiera-tiny",
"facebook/sam2-hiera-small",
"facebook/sam2-hiera-base-plus",
"facebook/sam2-hiera-large",
]
if model_id not in allowed_models:
raise ValueError(
f"model_id must be one of the following: {', '.join(allowed_models)}"
)
if device is None:
device = common.choose_device(empty_cache=empty_cache)
if hydra_overrides_extra is None:
hydra_overrides_extra = []
self.model_id = model_id
self.device = device
if video:
automatic = False
if automatic:
self.mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
model_id,
device=device,
mode=mode,
hydra_overrides_extra=hydra_overrides_extra,
apply_postprocessing=apply_postprocessing,
points_per_side=points_per_side,
points_per_batch=points_per_batch,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
stability_score_offset=stability_score_offset,
mask_threshold=mask_threshold,
box_nms_thresh=box_nms_thresh,
crop_n_layers=crop_n_layers,
crop_nms_thresh=crop_nms_thresh,
crop_overlap_ratio=crop_overlap_ratio,
crop_n_points_downscale_factor=crop_n_points_downscale_factor,
point_grids=point_grids,
min_mask_region_area=min_mask_region_area,
output_mode=output_mode,
use_m2m=use_m2m,
multimask_output=multimask_output,
**kwargs,
)
elif video:
self.predictor = SAM2VideoPredictor.from_pretrained(
model_id,
device=device,
mode=mode,
hydra_overrides_extra=hydra_overrides_extra,
apply_postprocessing=apply_postprocessing,
**kwargs,
)
else:
self.predictor = SAM2ImagePredictor.from_pretrained(
model_id,
device=device,
mode=mode,
hydra_overrides_extra=hydra_overrides_extra,
apply_postprocessing=apply_postprocessing,
mask_threshold=mask_threshold,
max_hole_area=max_hole_area,
max_sprinkle_area=max_sprinkle_area,
**kwargs,
)
def generate(
self,
source: Union[str, np.ndarray],
output: Optional[str] = None,
foreground: bool = True,
erosion_kernel: Optional[Tuple[int, int]] = None,
mask_multiplier: int = 255,
unique: bool = True,
min_size: int = 0,
max_size: Optional[int] = None,
**kwargs: Any,
) -> List[Dict[str, Any]]:
"""
Generate masks for the input image.
Args:
source (Union[str, np.ndarray]): The path to the input image or the
input image as a numpy array.
output (Optional[str]): The path to the output image. Defaults to None.
foreground (bool): Whether to generate the foreground mask. Defaults
to True.
erosion_kernel (Optional[Tuple[int, int]]): The erosion kernel for
filtering object masks and extracting borders.
Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
mask_multiplier (int): The mask multiplier for the output mask,
which is usually a binary mask [0, 1].
You can use this parameter to scale the mask to a larger range,
for example [0, 255]. Defaults to 255.
The parameter is ignored if unique is True.
unique (bool): Whether to assign a unique value to each object.
Defaults to True.
The unique value increases from 1 to the number of objects. The
larger the number, the larger the object area.
min_size (int): The minimum size of the object. Defaults to 0.
max_size (int): The maximum size of the object. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the generated masks.
"""
if isinstance(source, str):
if source.startswith("http"):
source = common.download_file(source)
if not os.path.exists(source):
raise ValueError(f"Input path {source} does not exist.")
image = cv2.imread(source)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif isinstance(source, np.ndarray):
image = source
source = None
else:
raise ValueError("Input source must be either a path or a numpy array.")
self.source = source # Store the input image path
self.image = image # Store the input image as a numpy array
mask_generator = self.mask_generator # The automatic mask generator
masks = mask_generator.generate(image) # Segment the input image
self.masks = masks # Store the masks as a list of dictionaries
self._min_size = min_size
self._max_size = max_size
if output is not None:
# Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
self.save_masks(
output,
foreground,
unique,
erosion_kernel,
mask_multiplier,
min_size,
max_size,
**kwargs,
)
def save_masks(
self,
output: Optional[str] = None,
foreground: bool = True,
unique: bool = True,
erosion_kernel: Optional[Tuple[int, int]] = None,
mask_multiplier: int = 255,
min_size: int = 0,
max_size: Optional[int] = None,
**kwargs: Any,
) -> None:
"""Save the masks to the output path. The output is either a binary mask
or a mask of objects with unique values.
Args:
output (str, optional): The path to the output image. Defaults to
None, saving the masks to SamGeo.objects.
foreground (bool, optional): Whether to generate the foreground mask.
Defaults to True.
unique (bool, optional): Whether to assign a unique value to each
object. Defaults to True.
erosion_kernel (tuple, optional): The erosion kernel for filtering
object masks and extracting borders.
Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to
None.
mask_multiplier (int, optional): The mask multiplier for the output
mask, which is usually a binary mask [0, 1]. You can use this
parameter to scale the mask to a larger range, for example
[0, 255]. Defaults to 255.
min_size (int, optional): The minimum size of the object. Defaults to 0.
max_size (int, optional): The maximum size of the object. Defaults to None.
**kwargs: Additional keyword arguments for common.array_to_image().
"""
if self.masks is None:
raise ValueError("No masks found. Please run generate() first.")
h, w, _ = self.image.shape
masks = self.masks
# Set output image data type based on the number of objects
if len(masks) < 255:
dtype = np.uint8
elif len(masks) < 65535:
dtype = np.uint16
else:
dtype = np.uint32
# Generate a mask of objects with unique values
if unique:
# Sort the masks by area in descending order
sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
# Create an output image with the same size as the input image
objects = np.zeros(
(
sorted_masks[0]["segmentation"].shape[0],
sorted_masks[0]["segmentation"].shape[1],
)
)
# Assign a unique value to each object
count = len(sorted_masks)
for index, ann in enumerate(sorted_masks):
m = ann["segmentation"]
if min_size > 0 and ann["area"] < min_size:
continue
if max_size is not None and ann["area"] > max_size:
continue
objects[m] = count - index
# Generate a binary mask
else:
if foreground: # Extract foreground objects only
resulting_mask = np.zeros((h, w), dtype=dtype)
else:
resulting_mask = np.ones((h, w), dtype=dtype)
resulting_borders = np.zeros((h, w), dtype=dtype)
for m in masks:
if min_size > 0 and m["area"] < min_size:
continue
if max_size is not None and m["area"] > max_size:
continue
mask = (m["segmentation"] > 0).astype(dtype)
resulting_mask += mask
# Apply erosion to the mask
if erosion_kernel is not None:
mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
mask_erode = (mask_erode > 0).astype(dtype)
edge_mask = mask - mask_erode
resulting_borders += edge_mask
resulting_mask = (resulting_mask > 0).astype(dtype)
resulting_borders = (resulting_borders > 0).astype(dtype)
objects = resulting_mask - resulting_borders
objects = objects * mask_multiplier
objects = objects.astype(dtype)
self.objects = objects
if output is not None: # Save the output image
common.array_to_image(self.objects, output, self.source, **kwargs)
def show_masks(
self,
figsize: Tuple[int, int] = (12, 10),
cmap: str = "binary_r",
axis: str = "off",
foreground: bool = True,
**kwargs: Any,
) -> None:
"""Show the binary mask or the mask of objects with unique values.
Args:
figsize (tuple, optional): The figure size. Defaults to (12, 10).
cmap (str, optional): The colormap. Defaults to "binary_r".
axis (str, optional): Whether to show the axis. Defaults to "off".
foreground (bool, optional): Whether to show the foreground mask only.
Defaults to True.
**kwargs: Other arguments for save_masks().
"""
import matplotlib.pyplot as plt
if self.objects is None:
self.save_masks(foreground=foreground, **kwargs)
plt.figure(figsize=figsize)
plt.imshow(self.objects, cmap=cmap)
plt.axis(axis)
plt.show()
def show_anns(
self,
figsize: Tuple[int, int] = (12, 10),
axis: str = "off",
alpha: float = 0.35,
output: Optional[str] = None,
blend: bool = True,
**kwargs: Any,
) -> None:
"""Show the annotations (objects with random color) on the input image.
Args:
figsize (tuple, optional): The figure size. Defaults to (12, 10).
axis (str, optional): Whether to show the axis. Defaults to "off".
alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
output (str, optional): The path to the output image. Defaults to None.
blend (bool, optional): Whether to show the input image. Defaults to True.
"""
import matplotlib.pyplot as plt
anns = self.masks
if self.image is None:
print("Please run generate() first.")
return
if anns is None or len(anns) == 0:
return
plt.figure(figsize=figsize)
plt.imshow(self.image)
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
img = np.ones(
(
sorted_anns[0]["segmentation"].shape[0],
sorted_anns[0]["segmentation"].shape[1],
4,
)
)
img[:, :, 3] = 0
for ann in sorted_anns:
if hasattr(self, "_min_size") and (ann["area"] < self._min_size):
continue
if (
hasattr(self, "_max_size")
and isinstance(self._max_size, int)
and ann["area"] > self._max_size
):
continue
m = ann["segmentation"]
color_mask = np.concatenate([np.random.random(3), [alpha]])
img[m] = color_mask
ax.imshow(img)
if "dpi" not in kwargs:
kwargs["dpi"] = 100
if "bbox_inches" not in kwargs:
kwargs["bbox_inches"] = "tight"
plt.axis(axis)
self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)
if output is not None:
if blend:
array = common.blend_images(
self.annotations, self.image, alpha=alpha, show=False
)
else:
array = self.annotations
common.array_to_image(array, output, self.source)
@torch.no_grad()
def set_image(
self,
image: Union[str, np.ndarray, Image],
) -> None:
"""Set the input image as a numpy array.
Args:
image (Union[str, np.ndarray, Image]): The input image as a path,
a numpy array, or an Image.
"""
if isinstance(image, str):
if image.startswith("http"):
image = common.download_file(image)
if not os.path.exists(image):
raise ValueError(f"Input path {image} does not exist.")
self.source = image
image = cv2.imread(image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
self.image = image
elif isinstance(image, np.ndarray) or isinstance(image, Image):
self.image = image
else:
raise ValueError("Input image must be a path, a numpy array, or a PIL Image.")
self.predictor.set_image(image)
@torch.no_grad()
def set_image_batch(
self,
image_list: List[Union[np.ndarray, str, Image]],
) -> None:
"""Set a batch of images for prediction.
Args:
image_list (List[Union[np.ndarray, str, Image]]): A list of images,
which can be numpy arrays, file paths, or PIL images.
Raises:
ValueError: If an input image path does not exist or if the input
image type is not supported.
"""
images = []
for image in image_list:
if isinstance(image, str):
if image.startswith("http"):
image = common.download_file(image)
if not os.path.exists(image):
raise ValueError(f"Input path {image} does not exist.")
image = cv2.imread(image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif isinstance(image, Image):
image = np.array(image)
elif isinstance(image, np.ndarray):
pass
else:
raise ValueError("Input image must be either a path or a numpy array.")
images.append(image)
self.predictor.set_image_batch(images)
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
boxes: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
normalize_coords: bool = True,
point_crs: Optional[str] = None,
output: Optional[str] = None,
index: Optional[int] = None,
mask_multiplier: int = 255,
dtype: str = "float32",
return_results: bool = False,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Predict the mask for the input image.
Args:
point_coords (np.ndarray, optional): The point coordinates. Defaults to None.
point_labels (np.ndarray, optional): The point labels. Defaults to None.
boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the
model, in XYXY format.
mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
multimask_output (bool, optional): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
Defaults to True.
return_logits (bool, optional): If true, returns un-thresholded masks logits
instead of a binary mask.
normalize_coords (bool, optional): Whether to normalize the coordinates.
Defaults to True.
point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
output (str, optional): The path to the output image. Defaults to None.
index (int, optional): The index of the mask to save. Defaults to None,
which will save the mask with the highest score.
mask_multiplier (int, optional): The mask multiplier for the output mask,
which is usually a binary mask [0, 1]. Defaults to 255.
dtype (str, optional): The data type of the output image. Defaults to "float32".
return_results (bool, optional): Whether to return the predicted masks,
scores, and logits. Defaults to False.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: The mask, the multimask,
and the logits.
"""
import geopandas as gpd
out_of_bounds = []
if isinstance(boxes, str):
gdf = gpd.read_file(boxes)
if gdf.crs is not None:
gdf = gdf.to_crs("epsg:4326")
boxes = gdf.geometry.bounds.values.tolist()
elif isinstance(boxes, dict):
import json
geojson = json.dumps(boxes)
gdf = gpd.read_file(geojson, driver="GeoJSON")
boxes = gdf.geometry.bounds.values.tolist()
if isinstance(point_coords, str):
point_coords = common.vector_to_geojson(point_coords)
if isinstance(point_coords, dict):
point_coords = common.geojson_to_coords(point_coords)
if hasattr(self, "point_coords"):
point_coords = self.point_coords
if hasattr(self, "point_labels"):
point_labels = self.point_labels
if (point_crs is not None) and (point_coords is not None):
point_coords, out_of_bounds = common.coords_to_xy(
self.source, point_coords, point_crs, return_out_of_bounds=True
)
if isinstance(point_coords, list):
point_coords = np.array(point_coords)
if point_coords is not None:
if point_labels is None:
point_labels = [1] * len(point_coords)
elif isinstance(point_labels, int):
point_labels = [point_labels] * len(point_coords)
if isinstance(point_labels, list):
if len(point_labels) != len(point_coords):
if len(point_labels) == 1:
point_labels = point_labels * len(point_coords)
elif len(out_of_bounds) > 0:
print(f"Removing {len(out_of_bounds)} out-of-bound points.")
point_labels_new = []
for i, p in enumerate(point_labels):
if i not in out_of_bounds:
point_labels_new.append(p)
point_labels = point_labels_new
else:
raise ValueError(
"The length of point_labels must be equal to the length of point_coords."
)
point_labels = np.array(point_labels)
predictor = self.predictor
input_boxes = None
if isinstance(boxes, list) and (point_crs is not None):
coords = common.bbox_to_xy(self.source, boxes, point_crs)
input_boxes = np.array(coords)
if isinstance(coords[0], int):
input_boxes = input_boxes[None, :]
else:
input_boxes = torch.tensor(input_boxes, device=self.device)
input_boxes = predictor.transform.apply_boxes_torch(
input_boxes, self.image.shape[:2]
)
elif isinstance(boxes, list) and (point_crs is None):
input_boxes = np.array(boxes)
if isinstance(boxes[0], int):
input_boxes = input_boxes[None, :]
self.boxes = input_boxes
if (
boxes is None
or (len(boxes) == 1)
or (len(boxes) == 4 and isinstance(boxes[0], float))
):
if isinstance(boxes, list) and isinstance(boxes[0], list):
boxes = boxes[0]
masks, scores, logits = predictor.predict(
point_coords,
point_labels,
input_boxes,
mask_input,
multimask_output,
return_logits,
)
else:
masks, scores, logits = predictor.predict_torch(
point_coords=point_coords,
point_labels=point_labels,
boxes=input_boxes,
multimask_output=multimask_output,
)
self.masks = masks
self.scores = scores
self.logits = logits
if output is not None:
if boxes is None or (not isinstance(boxes[0], list)):
self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
else:
self.tensor_to_numpy(
index, output, mask_multiplier, dtype, save_args=kwargs
)
if return_results:
return masks, scores, logits
return self.predictor.predict(
point_coords=point_coords,
point_labels=point_labels,
box=boxes,
mask_input=mask_input,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)
def predict_batch(
self,
point_coords_batch: Optional[List[np.ndarray]] = None,
point_labels_batch: Optional[List[np.ndarray]] = None,
box_batch: Optional[List[np.ndarray]] = None,
mask_input_batch: Optional[List[np.ndarray]] = None,
multimask_output: bool = True,
return_logits: bool = False,
normalize_coords: bool = True,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Predict masks for a batch of images.
Args:
point_coords_batch (Optional[List[np.ndarray]]): A batch of point
coordinates. Defaults to None.
point_labels_batch (Optional[List[np.ndarray]]): A batch of point
labels. Defaults to None.
box_batch (Optional[List[np.ndarray]]): A batch of bounding boxes.
Defaults to None.
mask_input_batch (Optional[List[np.ndarray]]): A batch of mask inputs.
Defaults to None.
multimask_output (bool): Whether to output multimask at each point
of the grid. Defaults to True.
return_logits (bool): Whether to return the logits. Defaults to False.
normalize_coords (bool): Whether to normalize the coordinates.
Defaults to True.
Returns:
Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: Lists
of masks, multimasks, and logits.
"""
return self.predictor.predict_batch(
point_coords_batch=point_coords_batch,
point_labels_batch=point_labels_batch,
box_batch=box_batch,
mask_input_batch=mask_input_batch,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)
@torch.inference_mode()
def init_state(
self,
video_path: str,
offload_video_to_cpu: bool = False,
offload_state_to_cpu: bool = False,
async_loading_frames: bool = False,
) -> Any:
"""Initialize an inference state.
Args:
video_path (str): The path to the video file.
offload_video_to_cpu (bool): Whether to offload the video to CPU.
Defaults to False.
offload_state_to_cpu (bool): Whether to offload the state to CPU.
Defaults to False.
async_loading_frames (bool): Whether to load frames asynchronously.
Defaults to False.
Returns:
Any: The initialized inference state.
"""
return self.predictor.init_state(
video_path,
offload_video_to_cpu=offload_video_to_cpu,
offload_state_to_cpu=offload_state_to_cpu,
async_loading_frames=async_loading_frames,
)
@torch.inference_mode()
def reset_state(self, inference_state: Any) -> None:
"""Remove all input points or masks in all frames throughout the video.
Args:
inference_state (Any): The current inference state.
"""
self.predictor.reset_state(inference_state)
@torch.inference_mode()
def add_new_points_or_box(
self,
inference_state: Any,
frame_idx: int,
obj_id: int,
points: Optional[np.ndarray] = None,
labels: Optional[np.ndarray] = None,
clear_old_points: bool = True,
normalize_coords: bool = True,
box: Optional[np.ndarray] = None,
) -> Any:
"""Add new points or a box to the inference state.
Args:
inference_state (Any): The current inference state.
frame_idx (int): The frame index.
obj_id (int): The object ID.
points (Optional[np.ndarray]): The points to add. Defaults to None.
labels (Optional[np.ndarray]): The labels for the points. Defaults to None.
clear_old_points (bool): Whether to clear old points. Defaults to True.
normalize_coords (bool): Whether to normalize the coordinates. Defaults to True.
box (Optional[np.ndarray]): The bounding box to add. Defaults to None.
Returns:
Any: The updated inference state.
"""
return self.predictor.add_new_points_or_box(
inference_state,
frame_idx,
obj_id,
points=points,
labels=labels,
clear_old_points=clear_old_points,
normalize_coords=normalize_coords,
box=box,
)
@torch.inference_mode()
def add_new_mask(
self,
inference_state: Any,
frame_idx: int,
obj_id: int,
mask: np.ndarray,
) -> Any:
"""Add a new mask to the inference state.
Args:
inference_state (Any): The current inference state.
frame_idx (int): The frame index.
obj_id (int): The object ID.
mask (np.ndarray): The mask to add.
Returns:
Any: The updated inference state.
"""
return self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state: Any) -> Any:
"""Propagate the inference state in video preflight.
Args:
inference_state (Any): The current inference state.
Returns:
Any: The propagated inference state.
"""
return self.predictor.propagate_in_video_preflight(inference_state)
@torch.inference_mode()
def propagate_in_video(
self,
inference_state: Any,
start_frame_idx: Optional[int] = None,
max_frame_num_to_track: Optional[int] = None,
reverse: bool = False,
) -> Any:
"""Propagate the inference state in video.
Args:
inference_state (Any): The current inference state.
start_frame_idx (Optional[int]): The starting frame index. Defaults to None.
max_frame_num_to_track (Optional[int]): The maximum number of frames
to track. Defaults to None.
reverse (bool): Whether to propagate in reverse. Defaults to False.
Returns:
Any: The propagated inference state.
"""
return self.predictor.propagate_in_video(
inference_state,
start_frame_idx=start_frame_idx,
max_frame_num_to_track=max_frame_num_to_track,
reverse=reverse,
)
def tensor_to_numpy(
self,
index: Optional[int] = None,
output: Optional[str] = None,
mask_multiplier: int = 255,
dtype: str = "uint8",
save_args: Optional[Dict[str, Any]] = None,
) -> Optional[np.ndarray]:
"""Convert the predicted masks from tensors to numpy arrays.
Args:
index (Optional[int], optional): The index of the mask to save.
Defaults to None, which will save the mask with the highest score.
output (Optional[str], optional): The path to the output image.
Defaults to None.
mask_multiplier (int, optional): The mask multiplier for the output
mask, which is usually a binary mask [0, 1].
dtype (str, optional): The data type of the output image. Defaults
to "uint8".
save_args (Optional[Dict[str, Any]], optional): Optional arguments
for saving the output image. Defaults to None.
Returns:
Optional[np.ndarray]: The predicted mask as a numpy array, or None
if output is specified.
"""
if save_args is None:
save_args = {}
boxes = self.boxes
masks = self.masks
image_pil = self.image
image_np = np.array(image_pil)
if index is None:
index = 1
masks = masks[:, index, :, :]
masks = masks.squeeze(1)
if boxes is None or (len(boxes) == 0): # No "object" instances found
print("No objects found in the image.")
return
else:
# Create an empty image to store the mask overlays
mask_overlay = np.zeros_like(
image_np[..., 0], dtype=dtype
) # Adjusted for single channel
for i, (_, mask) in enumerate(zip(boxes, masks)):
# Convert tensor to numpy array if necessary and ensure it contains integers
if isinstance(mask, torch.Tensor):
mask = (
mask.cpu().numpy().astype(dtype)
) # If mask is on GPU, use .cpu() before .numpy()
mask_overlay += ((mask > 0) * (i + 1)).astype(
dtype
) # Assign a unique value for each mask
# Normalize mask_overlay to be in [0, 255]
mask_overlay = (
mask_overlay > 0
) * mask_multiplier # Binary mask in [0, 255]
if output is not None:
common.array_to_image(
mask_overlay, output, self.source, dtype=dtype, **save_args
)
else:
return mask_overlay
def save_prediction(
self,
output: str,
index: Optional[int] = None,
mask_multiplier: int = 255,
dtype: str = "float32",
vector: Optional[str] = None,
simplify_tolerance: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Save the predicted mask to the output path.
Args:
output (str): The path to the output image.
index (Optional[int], optional): The index of the mask to save.
Defaults to None, which will save the mask with the highest score.
mask_multiplier (int, optional): The mask multiplier for the output
mask, which is usually a binary mask [0, 1].
dtype (str, optional): The data type of the output image. Defaults
to "float32".
vector (Optional[str], optional): The path to the output vector file.
Defaults to None.
simplify_tolerance (Optional[float], optional): The maximum allowed
geometry displacement. The higher this value, the smaller the
number of vertices in the resulting geometry.
**kwargs (Any): Additional keyword arguments.
"""
if self.scores is None:
raise ValueError("No predictions found. Please run predict() first.")
if index is None:
index = self.scores.argmax(axis=0)
array = self.masks[index] * mask_multiplier
self.prediction = array
common.array_to_image(array, output, self.source, dtype=dtype, **kwargs)
if vector is not None:
common.raster_to_vector(
output, vector, simplify_tolerance=simplify_tolerance
)
def show_map(
self,
basemap: str = "SATELLITE",
repeat_mode: bool = True,
out_dir: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Show the interactive map.
Args:
basemap (str, optional): The basemap. It can be one of the following:
SATELLITE, ROADMAP, TERRAIN, HYBRID. Defaults to SATELLITE.
repeat_mode (bool, optional): Whether to use the repeat mode for
draw control. Defaults to True.
out_dir (Optional[str], optional): The path to the output directory.
Defaults to None.
Returns:
Any: The map object.
"""
return common.sam_map_gui(
self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs
)
def show_canvas(
self,
fg_color: Tuple[int, int, int] = (0, 255, 0),
bg_color: Tuple[int, int, int] = (0, 0, 255),
radius: int = 5,
) -> Tuple[list, list]:
"""Show a canvas to collect foreground and background points.
Args:
fg_color (Tuple[int, int, int], optional): The color for the foreground points.
Defaults to (0, 255, 0).
bg_color (Tuple[int, int, int], optional): The color for the background points.
Defaults to (0, 0, 255).
radius (int, optional): The radius of the points. Defaults to 5.
Returns:
Tuple[list, list]: A tuple of two lists of foreground and background points.
"""
if self.image is None:
raise ValueError("Please run set_image() first.")
image = self.image
fg_points, bg_points = common.show_canvas(image, fg_color, bg_color, radius)
self.fg_points = fg_points
self.bg_points = bg_points
point_coords = fg_points + bg_points
point_labels = [1] * len(fg_points) + [0] * len(bg_points)
self.point_coords = point_coords
self.point_labels = point_labels
def _convert_prompts(self, prompts: Dict[int, Any]) -> Dict[int, Any]:
"""Convert the points and labels in the prompts to numpy arrays with specific data types.
Args:
prompts (Dict[int, Any]): A dictionary containing the prompts with points and labels.
Returns:
Dict[int, Any]: The updated dictionary with points and labels converted to numpy arrays.
"""
for _, value in prompts.items():
# Convert points to np.float32 array
if "points" in value:
value["points"] = np.array(value["points"], dtype=np.float32)
# Convert labels to np.int32 array
if "labels" in value:
value["labels"] = np.array(value["labels"], dtype=np.int32)
# Convert box to np.float32 array
if "box" in value:
value["box"] = np.array(value["box"], dtype=np.float32)
return prompts
def set_video(
self,
video_path: str,
output_dir: Optional[str] = None,
frame_rate: Optional[int] = None,
prefix: str = "",
) -> None:
"""Set the video path and parameters.
Args:
video_path (str): The path to a video file or a directory of video frames.
output_dir (str, optional): The directory to save the extracted frames.
Defaults to None, which creates a temporary directory.
frame_rate (Optional[int], optional): The frame rate to use when extracting
frames from the video. Defaults to None.
prefix (str, optional): The filename prefix for the extracted frames.
Defaults to "".
"""
if isinstance(video_path, str):
if video_path.startswith("http"):
video_path = common.download_file(video_path)
if os.path.isfile(video_path):
if output_dir is None:
output_dir = common.make_temp_dir()
if not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Output directory: {output_dir}")
common.video_to_images(
video_path, output_dir, frame_rate=frame_rate, prefix=prefix
)
elif os.path.isdir(video_path):
files = sorted(os.listdir(video_path))
if len(files) == 0:
raise ValueError(f"No files found in {video_path}.")
elif files[0].endswith(".tif"):
self._tif_source = os.path.join(video_path, files[0])
self._tif_dir = video_path
self._tif_names = files
video_path = common.geotiff_to_jpg_batch(video_path)
output_dir = video_path
if not os.path.exists(video_path):
raise ValueError(f"Input path {video_path} does not exist.")
else:
raise ValueError("Input video_path must be a string.")
self.video_path = output_dir
self._num_images = len(os.listdir(output_dir))
self._frame_names = sorted(os.listdir(output_dir))
self.inference_state = self.predictor.init_state(video_path=output_dir)
def predict_video(
self,
prompts: Optional[Dict[int, Any]] = None,
point_crs: Optional[str] = None,
output_dir: Optional[str] = None,
img_ext: str = "png",
) -> None:
"""Predict masks for the video.
Args:
prompts (Dict[int, Any]): A dictionary containing the prompts with points and labels.
point_crs (Optional[str]): The coordinate reference system (CRS) of the point prompts.
output_dir (Optional[str]): The directory to save the output images. Defaults to None.
img_ext (str): The file extension for the output images. Defaults to "png".
"""
from PIL import Image
def save_image_from_dict(data, output_path="output_image.png"):
# Find the shape of the first array in the dictionary (assuming all arrays have the same shape)
array_shape = next(iter(data.values())).shape[1:]
# Initialize an empty array with the same shape as the arrays in the dictionary, filled with zeros
output_array = np.zeros(array_shape, dtype=np.uint8)
# Iterate over each key and array in the dictionary
for key, array in data.items():
# Assign the key value wherever the boolean array is True
output_array[array[0]] = key
# Convert the output array to a PIL image
image = Image.fromarray(output_array)
# Save the image
image.save(output_path)
if prompts is None:
if hasattr(self, "prompts"):
prompts = self.prompts
else:
raise ValueError("Please provide prompts.")
if point_crs is not None and getattr(self, "_tif_source", None) is not None:
for prompt in prompts.values():
points = prompt.get("points", None)
if points is not None:
points = common.coords_to_xy(self._tif_source, points, point_crs)
prompt["points"] = points
box = prompt.get("box", None)
if box is not None:
box = common.bbox_to_xy(self._tif_source, box, point_crs)
prompt["box"] = box
prompts = self._convert_prompts(prompts)
predictor = self.predictor
inference_state = self.inference_state
for obj_id, prompt in prompts.items():
points = prompt.get("points", None)
labels = prompt.get("labels", None)
box = prompt.get("box", None)
frame_idx = prompt.get("frame_idx", None)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=obj_id,
points=points,
labels=labels,
box=box,
)
video_segments = {}
num_frames = self._num_images
num_digits = len(str(num_frames))
if output_dir is not None:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
if output_dir is not None:
output_path = os.path.join(
output_dir, f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
)
save_image_from_dict(video_segments[out_frame_idx], output_path)
self.video_segments = video_segments
# if output_dir is not None:
# self.save_video_segments(output_dir, img_ext)
def save_video_segments(self, output_dir: str, img_ext: str = "png") -> None:
"""Save the video segments to the output directory.
Args:
output_dir (str): The path to the output directory.
img_ext (str): The file extension for the output images. Defaults to "png".
"""
from PIL import Image
def save_image_from_dict(
data, output_path="output_image.png", crs_source=None, **kwargs
):
# Find the shape of the first array in the dictionary (assuming all arrays have the same shape)
array_shape = next(iter(data.values())).shape[1:]
# Initialize an empty array with the same shape as the arrays in the dictionary, filled with zeros
output_array = np.zeros(array_shape, dtype=np.uint8)
# Iterate over each key and array in the dictionary
for key, array in data.items():
# Assign the key value wherever the boolean array is True
output_array[array[0]] = key
if crs_source is None:
# Convert the output array to a PIL image
image = Image.fromarray(output_array)
# Save the image
image.save(output_path)
else:
output_path = output_path.replace(".png", ".tif")
common.array_to_image(output_array, output_path, crs_source, **kwargs)
num_frames = len(self.video_segments)
num_digits = len(str(num_frames))
if hasattr(self, "_tif_source") and self._tif_source.endswith(".tif"):
crs_source = self._tif_source
filenames = self._tif_names
else:
crs_source = None
filenames = None
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Initialize the tqdm progress bar
for frame_idx, video_segment in tqdm(
self.video_segments.items(), desc="Rendering frames", total=num_frames
):
if filenames is None:
output_path = os.path.join(
output_dir, f"{str(frame_idx).zfill(num_digits)}.{img_ext}"
)
else:
output_path = os.path.join(output_dir, filenames[frame_idx])
save_image_from_dict(video_segment, output_path, crs_source)
def save_video_segments_blended(
self,
output_dir: str,
img_ext: str = "png",
alpha: float = 0.6,
dpi: int = 200,
frame_stride: int = 1,
output_video: Optional[str] = None,
fps: int = 30,
) -> None:
"""Save blended video segments to the output directory and optionally create a video.
Args:
output_dir (str): The directory to save the output images.
img_ext (str): The file extension for the output images. Defaults to "png".
alpha (float): The alpha value for the blended masks. Defaults to 0.6.
dpi (int): The DPI (dots per inch) for the output images. Defaults to 200.
frame_stride (int): The stride for selecting frames to save. Defaults to 1.
output_video (Optional[str]): The path to the output video file. Defaults to None.
fps (int): The frames per second for the output video. Defaults to 30.
"""
from PIL import Image
def show_mask(mask, ax, obj_id=None, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], alpha])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
plt.close("all")
video_segments = self.video_segments
video_dir = self.video_path
frame_names = self._frame_names
num_frames = len(frame_names)
num_digits = len(str(num_frames))
# Initialize the tqdm progress bar
for out_frame_idx in tqdm(
range(0, len(frame_names), frame_stride), desc="Rendering frames"
):
image = Image.open(os.path.join(video_dir, frame_names[out_frame_idx]))
# Get original image dimensions
w, h = image.size
# Set DPI and calculate figure size based on the original image dimensions
figsize = (
w / dpi,
h / dpi,
)
figsize = (
figsize[0] * 1.3,
figsize[1] * 1.3,
)
# Create a figure with the exact size and DPI
fig = plt.figure(figsize=figsize, dpi=dpi)
# Disable axis to prevent whitespace
plt.axis("off")
# Display the original image
plt.imshow(image)
# Overlay masks for each object ID
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
# Save the figure with no borders or extra padding
filename = f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
filepath = os.path.join(output_dir, filename)
plt.savefig(filepath, dpi=dpi, pad_inches=0, bbox_inches="tight")
plt.close(fig)
if output_video is not None:
common.images_to_video(output_dir, output_video, fps=fps)
def show_images(self, path: str = None) -> None:
"""Show the images in the video.
Args:
path (str, optional): The path to the images. Defaults to None.
"""
if path is None:
path = self.video_path
if path is not None:
common.show_image_gui(path)
def show_prompts(
self,
prompts: Dict[int, Any],
frame_idx: int = 0,
mask: Any = None,
random_color: bool = False,
point_crs: Optional[str] = None,
figsize: Tuple[int, int] = (9, 6),
) -> None:
"""Show the prompts on the image.
Args:
prompts (Dict[int, Any]): A dictionary containing the prompts with
points and labels.
frame_idx (int, optional): The frame index. Defaults to 0.
mask (Any, optional): The mask. Defaults to None.
random_color (bool, optional): Whether to use random colors for the
masks. Defaults to False.
point_crs (Optional[str], optional): The coordinate reference system (CRS) of the point prompts. Defaults to None.
figsize (Tuple[int, int], optional): The figure size. Defaults to (9, 6).
"""
from PIL import Image
def show_mask(mask, ax, obj_id=None, random_color=random_color):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=200):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle(
(x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2
)
)
if point_crs is not None and getattr(self, "_tif_source", None) is not None:
for prompt in prompts.values():
points = prompt.get("points", None)
if points is not None:
points = common.coords_to_xy(self._tif_source, points, point_crs)
prompt["points"] = points
box = prompt.get("box", None)
if box is not None:
box = common.bbox_to_xy(self._tif_source, box, point_crs)
prompt["box"] = box
prompts = self._convert_prompts(prompts)
self.prompts = prompts
video_dir = self.video_path
frame_names = self._frame_names
fig = plt.figure(figsize=figsize)
fig.canvas.toolbar_visible = True
fig.canvas.header_visible = False
fig.canvas.footer_visible = True
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
for obj_id, prompt in prompts.items():
points = prompt.get("points", None)
labels = prompt.get("labels", None)
box = prompt.get("box", None)
anno_frame_idx = prompt.get("frame_idx", None)
if anno_frame_idx == frame_idx:
if points is not None:
show_points(points, labels, plt.gca())
if box is not None:
show_box(box, plt.gca())
if mask is not None:
show_mask(mask, plt.gca(), obj_id=obj_id)
plt.show()
def raster_to_vector(self, raster, vector, simplify_tolerance=None, **kwargs):
"""Convert a raster image file to a vector dataset.
Args:
raster (str): The path to the raster image.
vector (str): The path to the output vector file.
simplify_tolerance (float, optional): The maximum allowed geometry displacement.
The higher this value, the smaller the number of vertices in the resulting geometry.
"""
common.raster_to_vector(
raster, vector, simplify_tolerance=simplify_tolerance, **kwargs
)
__init__(self, model_id='sam2-hiera-large', device=None, empty_cache=True, automatic=True, video=False, mode='eval', hydra_overrides_extra=None, apply_postprocessing=False, points_per_side=32, points_per_batch=64, pred_iou_thresh=0.8, stability_score_thresh=0.95, stability_score_offset=1.0, mask_threshold=0.0, box_nms_thresh=0.7, crop_n_layers=0, crop_nms_thresh=0.7, crop_overlap_ratio=0.3413333333333333, crop_n_points_downscale_factor=1, point_grids=None, min_mask_region_area=0, output_mode='binary_mask', use_m2m=False, multimask_output=True, max_hole_area=0.0, max_sprinkle_area=0.0, **kwargs)
special
¶
Initializes the SamGeo2 class.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_id | str | The model ID to use. Can be one of the following: "sam2-hiera-tiny", "sam2-hiera-small", "sam2-hiera-base-plus", "sam2-hiera-large". Defaults to "sam2-hiera-large". | 'sam2-hiera-large' |
| device | Optional[str] | The device to use (e.g., "cpu", "cuda", "mps"). Defaults to None. | None |
| empty_cache | bool | Whether to empty the cache. Defaults to True. | True |
| automatic | bool | Whether to use automatic mask generation. Defaults to True. | True |
| video | bool | Whether to use video prediction. Defaults to False. | False |
| mode | str | The mode to use. Defaults to "eval". | 'eval' |
| hydra_overrides_extra | Optional[List[str]] | Additional Hydra overrides. Defaults to None. | None |
| apply_postprocessing | bool | Whether to apply postprocessing. Defaults to False. | False |
| points_per_side | int or None | The number of points to be sampled along one side of the image. The total number of points is points_per_side**2. If None, 'point_grids' must provide explicit point sampling. | 32 |
| points_per_batch | int | Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. | 64 |
| pred_iou_thresh | float | A filtering threshold in [0,1], using the model's predicted mask quality. | 0.8 |
| stability_score_thresh | float | A filtering threshold in [0,1], using the stability of the mask under changes to the cutoff used to binarize the model's mask predictions. | 0.95 |
| stability_score_offset | float | The amount to shift the cutoff when calculating the stability score. | 1.0 |
| mask_threshold | float | Threshold for binarizing the mask logits. | 0.0 |
| box_nms_thresh | float | The box IoU cutoff used by non-maximal suppression to filter duplicate masks. | 0.7 |
| crop_n_layers | int | If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where each layer has 2**i_layer number of image crops. | 0 |
| crop_nms_thresh | float | The box IoU cutoff used by non-maximal suppression to filter duplicate masks between different crops. | 0.7 |
| crop_overlap_ratio | float | Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the image length. Later layers with more crops scale down this overlap. | 0.3413333333333333 |
| crop_n_points_downscale_factor | int | The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. | 1 |
| point_grids | list(np.ndarray) or None | A list over explicit grids of points used for sampling, normalized to [0,1]. The nth grid in the list is used in the nth crop layer. Exclusive with points_per_side. | None |
| min_mask_region_area | int | If >0, postprocessing will be applied to remove disconnected regions and holes in masks with area smaller than min_mask_region_area. Requires opencv. | 0 |
| output_mode | str | The form masks are returned in. Can be 'binary_mask', 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. For large resolutions, 'binary_mask' may consume large amounts of memory. | 'binary_mask' |
| use_m2m | bool | Whether to add a one step refinement using previous mask predictions. | False |
| multimask_output | bool | Whether to output multimask at each point of the grid. | True |
| max_hole_area | float | If max_hole_area > 0, we fill small holes in up to the maximum area of max_hole_area in low_res_masks. | 0.0 |
| max_sprinkle_area | float | If max_sprinkle_area > 0, we remove small sprinkles up to the maximum area of max_sprinkle_area in low_res_masks. | 0.0 |
| **kwargs | Any | Additional keyword arguments to pass to SAM2AutomaticMaskGenerator.from_pretrained() or SAM2ImagePredictor.from_pretrained(). | {} |
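As a sketch of how these knobs interact (the values below are illustrative assumptions, not recommendations):

```python
# Illustrative settings for the automatic mask generator.
sam = SamGeo2(
    automatic=True,
    points_per_side=64,           # 64**2 = 4096 prompt points over the image
    pred_iou_thresh=0.86,         # keep masks the model itself rates highly
    stability_score_thresh=0.92,  # keep masks stable under threshold shifts
    min_mask_region_area=100,     # drop tiny regions and holes (requires opencv)
)
```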
Source code in samgeo/samgeo2.py
def __init__(
self,
model_id: str = "sam2-hiera-large",
device: Optional[str] = None,
empty_cache: bool = True,
automatic: bool = True,
video: bool = False,
mode: str = "eval",
hydra_overrides_extra: Optional[List[str]] = None,
apply_postprocessing: bool = False,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.8,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
mask_threshold: float = 0.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
max_hole_area: float = 0.0,
max_sprinkle_area: float = 0.0,
**kwargs: Any,
) -> None:
"""
Initializes the SamGeo2 class.
Args:
model_id (str): The model ID to use. Can be one of the following: "sam2-hiera-tiny",
"sam2-hiera-small", "sam2-hiera-base-plus", "sam2-hiera-large".
Defaults to "sam2-hiera-large".
device (Optional[str]): The device to use (e.g., "cpu", "cuda", "mps"). Defaults to None.
empty_cache (bool): Whether to empty the cache. Defaults to True.
automatic (bool): Whether to use automatic mask generation. Defaults to True.
video (bool): Whether to use video prediction. Defaults to False.
mode (str): The mode to use. Defaults to "eval".
hydra_overrides_extra (Optional[List[str]]): Additional Hydra overrides. Defaults to None.
apply_postprocessing (bool): Whether to apply postprocessing. Defaults to False.
points_per_side (int or None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_per_batch (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
pred_iou_thresh (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculating the stability score.
mask_threshold (float): Threshold for binarizing the mask logits.
box_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks.
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray) or None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
min_mask_region_area (int): If >0, postprocessing will be applied
to remove disconnected regions and holes in masks with area smaller
than min_mask_region_area. Requires opencv.
output_mode (str): The form masks are returned in. Can be 'binary_mask',
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
multimask_output (bool): Whether to output multimask at each point of the grid.
max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
**kwargs (Any): Additional keyword arguments to pass to
SAM2AutomaticMaskGenerator.from_pretrained() or SAM2ImagePredictor.from_pretrained().
"""
if isinstance(model_id, str):
if not model_id.startswith("facebook/"):
model_id = f"facebook/{model_id}"
else:
raise ValueError("model_id must be a string")
allowed_models = [
"facebook/sam2-hiera-tiny",
"facebook/sam2-hiera-small",
"facebook/sam2-hiera-base-plus",
"facebook/sam2-hiera-large",
]
if model_id not in allowed_models:
raise ValueError(
f"model_id must be one of the following: {', '.join(allowed_models)}"
)
if device is None:
device = common.choose_device(empty_cache=empty_cache)
if hydra_overrides_extra is None:
hydra_overrides_extra = []
self.model_id = model_id
self.device = device
if video:
automatic = False
if automatic:
self.mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
model_id,
device=device,
mode=mode,
hydra_overrides_extra=hydra_overrides_extra,
apply_postprocessing=apply_postprocessing,
points_per_side=points_per_side,
points_per_batch=points_per_batch,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
stability_score_offset=stability_score_offset,
mask_threshold=mask_threshold,
box_nms_thresh=box_nms_thresh,
crop_n_layers=crop_n_layers,
crop_nms_thresh=crop_nms_thresh,
crop_overlap_ratio=crop_overlap_ratio,
crop_n_points_downscale_factor=crop_n_points_downscale_factor,
point_grids=point_grids,
min_mask_region_area=min_mask_region_area,
output_mode=output_mode,
use_m2m=use_m2m,
multimask_output=multimask_output,
**kwargs,
)
elif video:
self.predictor = SAM2VideoPredictor.from_pretrained(
model_id,
device=device,
mode=mode,
hydra_overrides_extra=hydra_overrides_extra,
apply_postprocessing=apply_postprocessing,
**kwargs,
)
else:
self.predictor = SAM2ImagePredictor.from_pretrained(
model_id,
device=device,
mode=mode,
hydra_overrides_extra=hydra_overrides_extra,
apply_postprocessing=apply_postprocessing,
mask_threshold=mask_threshold,
max_hole_area=max_hole_area,
max_sprinkle_area=max_sprinkle_area,
**kwargs,
)
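The constructor builds one of three backends depending on the flags; a short sketch:

```python
# automatic=True  -> SAM2AutomaticMaskGenerator (grid-prompted segmentation)
# automatic=False -> SAM2ImagePredictor (interactive point/box prompts)
# video=True      -> SAM2VideoPredictor (automatic is forced off)
sam_auto = SamGeo2(automatic=True)
sam_prompt = SamGeo2(automatic=False)
sam_video = SamGeo2(video=True)
```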
add_new_mask(self, inference_state, frame_idx, obj_id, mask)
¶
Add a new mask to the inference state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inference_state | Any | The current inference state. | required |
| frame_idx | int | The frame index. | required |
| obj_id | int | The object ID. | required |
| mask | np.ndarray | The mask to add. | required |

Returns:

| Type | Description |
|---|---|
| Any | The updated inference state. |
Source code in samgeo/samgeo2.py
@torch.inference_mode()
def add_new_mask(
self,
inference_state: Any,
frame_idx: int,
obj_id: int,
mask: np.ndarray,
) -> Any:
"""Add a new mask to the inference state.
Args:
inference_state (Any): The current inference state.
frame_idx (int): The frame index.
obj_id (int): The object ID.
mask (np.ndarray): The mask to add.
Returns:
Any: The updated inference state.
"""
return self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)
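A sketch of seeding a video object with a known mask, assuming a SamGeo2 instance in video mode and a placeholder frame directory; the mask shape must match the frame size:

```python
import numpy as np

sam = SamGeo2(video=True)
sam.set_video("frames_dir")  # placeholder directory of video frames

# Hypothetical 640x480 frames: mark a rectangular region as object 1 on frame 0.
mask = np.zeros((480, 640), dtype=bool)
mask[100:200, 150:300] = True
sam.add_new_mask(sam.inference_state, frame_idx=0, obj_id=1, mask=mask)
```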
add_new_points_or_box(self, inference_state, frame_idx, obj_id, points=None, labels=None, clear_old_points=True, normalize_coords=True, box=None)
¶
Add new points or a box to the inference state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inference_state | Any | The current inference state. | required |
| frame_idx | int | The frame index. | required |
| obj_id | int | The object ID. | required |
| points | Optional[np.ndarray] | The points to add. Defaults to None. | None |
| labels | Optional[np.ndarray] | The labels for the points. Defaults to None. | None |
| clear_old_points | bool | Whether to clear old points. Defaults to True. | True |
| normalize_coords | bool | Whether to normalize the coordinates. Defaults to True. | True |
| box | Optional[np.ndarray] | The bounding box to add. Defaults to None. | None |

Returns:

| Type | Description |
|---|---|
| Any | The updated inference state. |
Source code in samgeo/samgeo2.py
@torch.inference_mode()
def add_new_points_or_box(
self,
inference_state: Any,
frame_idx: int,
obj_id: int,
points: Optional[np.ndarray] = None,
labels: Optional[np.ndarray] = None,
clear_old_points: bool = True,
normalize_coords: bool = True,
box: Optional[np.ndarray] = None,
) -> Any:
"""Add new points or a box to the inference state.
Args:
inference_state (Any): The current inference state.
frame_idx (int): The frame index.
obj_id (int): The object ID.
points (Optional[np.ndarray]): The points to add. Defaults to None.
labels (Optional[np.ndarray]): The labels for the points. Defaults to None.
clear_old_points (bool): Whether to clear old points. Defaults to True.
normalize_coords (bool): Whether to normalize the coordinates. Defaults to True.
box (Optional[np.ndarray]): The bounding box to add. Defaults to None.
Returns:
Any: The updated inference state.
"""
return self.predictor.add_new_points_or_box(
inference_state,
frame_idx,
obj_id,
points=points,
labels=labels,
clear_old_points=clear_old_points,
normalize_coords=normalize_coords,
box=box,
)
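A short sketch of adding a click prompt to a video inference state. This assumes the `tracker` instance from the earlier sketch and a `state` obtained from `init_state()`; the pixel coordinates are placeholders:

```python
import numpy as np

# One positive click (label 1) on frame 0 for object ID 1.
points = np.array([[210.0, 350.0]], dtype=np.float32)  # (x, y) in pixels
labels = np.array([1], dtype=np.int32)  # 1 = foreground, 0 = background
result = tracker.add_new_points_or_box(
    state, frame_idx=0, obj_id=1, points=points, labels=labels
)
```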
generate(self, source, output=None, foreground=True, erosion_kernel=None, mask_multiplier=255, unique=True, min_size=0, max_size=None, **kwargs)
¶
Generate masks for the input image.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| source | Union[str, np.ndarray] | The path to the input image or the input image as a numpy array. | required |
| output | Optional[str] | The path to the output image. Defaults to None. | None |
| foreground | bool | Whether to generate the foreground mask. Defaults to True. | True |
| erosion_kernel | Optional[Tuple[int, int]] | The erosion kernel for filtering object masks and extracting borders, such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None. | None |
| mask_multiplier | int | The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. The parameter is ignored if unique is True. | 255 |
| unique | bool | Whether to assign a unique value to each object. Defaults to True. The unique value increases from 1 to the number of objects. The larger the number, the larger the object area. | True |
| min_size | int | The minimum size of the object. Defaults to 0. | 0 |
| max_size | Optional[int] | The maximum size of the object. Defaults to None. | None |
| **kwargs | Any | Additional keyword arguments. | {} |

Returns:

| Type | Description |
|---|---|
| List[Dict[str, Any]] | A list of dictionaries containing the generated masks. |
Source code in samgeo/samgeo2.py
def generate(
self,
source: Union[str, np.ndarray],
output: Optional[str] = None,
foreground: bool = True,
erosion_kernel: Optional[Tuple[int, int]] = None,
mask_multiplier: int = 255,
unique: bool = True,
min_size: int = 0,
max_size: Optional[int] = None,
**kwargs: Any,
) -> List[Dict[str, Any]]:
"""
Generate masks for the input image.
Args:
source (Union[str, np.ndarray]): The path to the input image or the
input image as a numpy array.
output (Optional[str]): The path to the output image. Defaults to None.
foreground (bool): Whether to generate the foreground mask. Defaults
to True.
erosion_kernel (Optional[Tuple[int, int]]): The erosion kernel for
filtering object masks and extracting borders,
such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
mask_multiplier (int): The mask multiplier for the output mask,
which is usually a binary mask [0, 1].
You can use this parameter to scale the mask to a larger range,
for example [0, 255]. Defaults to 255.
The parameter is ignored if unique is True.
unique (bool): Whether to assign a unique value to each object.
Defaults to True.
The unique value increases from 1 to the number of objects. The
larger the number, the larger the object area.
min_size (int): The minimum size of the object. Defaults to 0.
max_size (int): The maximum size of the object. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the generated masks.
"""
if isinstance(source, str):
if source.startswith("http"):
source = common.download_file(source)
if not os.path.exists(source):
raise ValueError(f"Input path {source} does not exist.")
image = cv2.imread(source)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif isinstance(source, np.ndarray):
image = source
source = None
else:
raise ValueError("Input source must be either a path or a numpy array.")
self.source = source # Store the input image path
self.image = image # Store the input image as a numpy array
mask_generator = self.mask_generator # The automatic mask generator
masks = mask_generator.generate(image) # Segment the input image
self.masks = masks # Store the masks as a list of dictionaries
self._min_size = min_size
self._max_size = max_size
if output is not None:
# Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
self.save_masks(
output,
foreground,
unique,
erosion_kernel,
mask_multiplier,
min_size,
max_size,
**kwargs,
)
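A usage sketch for automatic segmentation, assuming the `sam` instance from the earlier sketch; `satellite.tif` is a hypothetical georeferenced input raster:

```python
# Segment the raster and write each object with a unique value,
# dropping objects smaller than 100 pixels.
sam.generate("satellite.tif", output="objects.tif", unique=True, min_size=100)
```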
init_state(self, video_path, offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False)
¶
Initialize an inference state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| video_path | str | The path to the video file. | required |
| offload_video_to_cpu | bool | Whether to offload the video to CPU. Defaults to False. | False |
| offload_state_to_cpu | bool | Whether to offload the state to CPU. Defaults to False. | False |
| async_loading_frames | bool | Whether to load frames asynchronously. Defaults to False. | False |

Returns:

| Type | Description |
|---|---|
| Any | The initialized inference state. |
Source code in samgeo/samgeo2.py
@torch.inference_mode()
def init_state(
self,
video_path: str,
offload_video_to_cpu: bool = False,
offload_state_to_cpu: bool = False,
async_loading_frames: bool = False,
) -> Any:
"""Initialize an inference state.
Args:
video_path (str): The path to the video file.
offload_video_to_cpu (bool): Whether to offload the video to CPU.
Defaults to False.
offload_state_to_cpu (bool): Whether to offload the state to CPU.
Defaults to False.
async_loading_frames (bool): Whether to load frames asynchronously.
Defaults to False.
Returns:
Any: The initialized inference state.
"""
return self.predictor.init_state(
video_path,
offload_video_to_cpu=offload_video_to_cpu,
offload_state_to_cpu=offload_state_to_cpu,
async_loading_frames=async_loading_frames,
)
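A sketch of initializing a video inference state directly, assuming the `tracker` instance from the earlier sketch; `frames/` is a placeholder (SAM2's video predictor accepts a video file or a directory of frames):

```python
# Build the inference state once per video before adding prompts.
state = tracker.init_state("frames/", offload_video_to_cpu=False)
```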
predict(self, point_coords=None, point_labels=None, boxes=None, mask_input=None, multimask_output=True, return_logits=False, normalize_coords=True, point_crs=None, output=None, index=None, mask_multiplier=255, dtype='float32', return_results=False, **kwargs)
¶
Predict the mask for the input image.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| point_coords | np.ndarray | The point coordinates. Defaults to None. | None |
| point_labels | np.ndarray | The point labels. Defaults to None. | None |
| boxes | list or np.ndarray | A length-4 array giving a box prompt to the model, in XYXY format. | None |
| mask_input | np.ndarray | A low-resolution mask input to the model, typically coming from a previous prediction iteration, with shape 1xHxW (for SAM, H=W=256). | None |
| multimask_output | bool | Whether to output multimask at each point of the grid. Defaults to True. | True |
| return_logits | bool | If true, returns un-thresholded mask logits instead of a binary mask. | False |
| normalize_coords | bool | Whether to normalize the coordinates. Defaults to True. | True |
| point_crs | str | The coordinate reference system (CRS) of the point prompts. | None |
| output | str | The path to the output image. Defaults to None. | None |
| index | Optional[int] | The index of the mask to save. Defaults to None, which will save the mask with the highest score. | None |
| mask_multiplier | int | The mask multiplier for the output mask, which is usually a binary mask [0, 1]. | 255 |
| dtype | str | The data type of the output image. Defaults to "float32". | 'float32' |
| return_results | bool | Whether to return the predicted masks, scores, and logits. Defaults to False. | False |

Returns:

| Type | Description |
|---|---|
| Tuple[np.ndarray, np.ndarray, np.ndarray] | The mask, the multimask, and the logits. |
Source code in samgeo/samgeo2.py
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
boxes: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
normalize_coords: bool = True,
point_crs: Optional[str] = None,
output: Optional[str] = None,
index: Optional[int] = None,
mask_multiplier: int = 255,
dtype: str = "float32",
return_results: bool = False,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Predict the mask for the input image.
Args:
point_coords (np.ndarray, optional): The point coordinates. Defaults to None.
point_labels (np.ndarray, optional): The point labels. Defaults to None.
boxes (list | np.ndarray, optional): A length-4 array giving a box prompt to the
model, in XYXY format. Defaults to None.
mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
multimask_output (bool, optional): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results. Defaults to True.
return_logits (bool, optional): If true, returns un-thresholded masks logits
instead of a binary mask.
normalize_coords (bool, optional): Whether to normalize the coordinates.
Defaults to True.
point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
output (str, optional): The path to the output image. Defaults to None.
index (Optional[int], optional): The index of the mask to save. Defaults to None,
which will save the mask with the highest score.
mask_multiplier (int, optional): The mask multiplier for the output mask,
which is usually a binary mask [0, 1].
dtype (str, optional): The data type of the output image. Defaults to "float32".
return_results (bool, optional): Whether to return the predicted masks,
scores, and logits. Defaults to False.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: The mask, the multimask,
and the logits.
"""
import geopandas as gpd
out_of_bounds = []
if isinstance(boxes, str):
gdf = gpd.read_file(boxes)
if gdf.crs is not None:
gdf = gdf.to_crs("epsg:4326")
boxes = gdf.geometry.bounds.values.tolist()
elif isinstance(boxes, dict):
import json
geojson = json.dumps(boxes)
gdf = gpd.read_file(geojson, driver="GeoJSON")
boxes = gdf.geometry.bounds.values.tolist()
if isinstance(point_coords, str):
point_coords = common.vector_to_geojson(point_coords)
if isinstance(point_coords, dict):
point_coords = common.geojson_to_coords(point_coords)
if hasattr(self, "point_coords"):
point_coords = self.point_coords
if hasattr(self, "point_labels"):
point_labels = self.point_labels
if (point_crs is not None) and (point_coords is not None):
point_coords, out_of_bounds = common.coords_to_xy(
self.source, point_coords, point_crs, return_out_of_bounds=True
)
if isinstance(point_coords, list):
point_coords = np.array(point_coords)
if point_coords is not None:
if point_labels is None:
point_labels = [1] * len(point_coords)
elif isinstance(point_labels, int):
point_labels = [point_labels] * len(point_coords)
if isinstance(point_labels, list):
if len(point_labels) != len(point_coords):
if len(point_labels) == 1:
point_labels = point_labels * len(point_coords)
elif len(out_of_bounds) > 0:
print(f"Removing {len(out_of_bounds)} out-of-bound points.")
point_labels_new = []
for i, p in enumerate(point_labels):
if i not in out_of_bounds:
point_labels_new.append(p)
point_labels = point_labels_new
else:
raise ValueError(
"The length of point_labels must be equal to the length of point_coords."
)
point_labels = np.array(point_labels)
predictor = self.predictor
input_boxes = None
if isinstance(boxes, list) and (point_crs is not None):
coords = common.bbox_to_xy(self.source, boxes, point_crs)
input_boxes = np.array(coords)
if isinstance(coords[0], int):
input_boxes = input_boxes[None, :]
else:
input_boxes = torch.tensor(input_boxes, device=self.device)
input_boxes = predictor.transform.apply_boxes_torch(
input_boxes, self.image.shape[:2]
)
elif isinstance(boxes, list) and (point_crs is None):
input_boxes = np.array(boxes)
if isinstance(boxes[0], int):
input_boxes = input_boxes[None, :]
self.boxes = input_boxes
if (
boxes is None
or (len(boxes) == 1)
or (len(boxes) == 4 and isinstance(boxes[0], float))
):
if isinstance(boxes, list) and isinstance(boxes[0], list):
boxes = boxes[0]
masks, scores, logits = predictor.predict(
point_coords,
point_labels,
input_boxes,
mask_input,
multimask_output,
return_logits,
)
else:
masks, scores, logits = predictor.predict_torch(
point_coords=point_coords,
point_labels=point_labels,
boxes=input_boxes,
multimask_output=True,
)
self.masks = masks
self.scores = scores
self.logits = logits
if output is not None:
if boxes is None or (not isinstance(boxes[0], list)):
self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
else:
self.tensor_to_numpy(
index, output, mask_multiplier, dtype, save_args=kwargs
)
if return_results:
return masks, scores, logits
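A sketch of prompt-based prediction with a georeferenced point, assuming the `predictor` instance from the earlier sketch; the image path and coordinates are placeholders:

```python
predictor.set_image("satellite.tif")

# A single foreground point given in lon/lat; point_crs triggers conversion
# to pixel coordinates against the source raster.
predictor.predict(
    point_coords=[[-122.1497, 37.6311]],
    point_labels=1,
    point_crs="EPSG:4326",
    output="mask.tif",
)
```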
predict_batch(self, point_coords_batch=None, point_labels_batch=None, box_batch=None, mask_input_batch=None, multimask_output=True, return_logits=False, normalize_coords=True)
¶
Predict masks for a batch of images.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| point_coords_batch | Optional[List[np.ndarray]] | A batch of point coordinates. Defaults to None. | None |
| point_labels_batch | Optional[List[np.ndarray]] | A batch of point labels. Defaults to None. | None |
| box_batch | Optional[List[np.ndarray]] | A batch of bounding boxes. Defaults to None. | None |
| mask_input_batch | Optional[List[np.ndarray]] | A batch of mask inputs. Defaults to None. | None |
| multimask_output | bool | Whether to output multimask at each point of the grid. Defaults to True. | True |
| return_logits | bool | Whether to return the logits. Defaults to False. | False |
| normalize_coords | bool | Whether to normalize the coordinates. Defaults to True. | True |

Returns:

| Type | Description |
|---|---|
| Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]] | Lists of masks, multimasks, and logits. |
Source code in samgeo/samgeo2.py
def predict_batch(
self,
point_coords_batch: List[np.ndarray] = None,
point_labels_batch: List[np.ndarray] = None,
box_batch: List[np.ndarray] = None,
mask_input_batch: List[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
normalize_coords: bool = True,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Predict masks for a batch of images.
Args:
point_coords_batch (Optional[List[np.ndarray]]): A batch of point
coordinates. Defaults to None.
point_labels_batch (Optional[List[np.ndarray]]): A batch of point
labels. Defaults to None.
box_batch (Optional[List[np.ndarray]]): A batch of bounding boxes.
Defaults to None.
mask_input_batch (Optional[List[np.ndarray]]): A batch of mask inputs.
Defaults to None.
multimask_output (bool): Whether to output multimask at each point
of the grid. Defaults to True.
return_logits (bool): Whether to return the logits. Defaults to False.
normalize_coords (bool): Whether to normalize the coordinates.
Defaults to True.
Returns:
Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: Lists
of masks, multimasks, and logits.
"""
return self.predictor.predict_batch(
point_coords_batch=point_coords_batch,
point_labels_batch=point_labels_batch,
box_batch=box_batch,
mask_input_batch=mask_input_batch,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)
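A batch-prediction sketch, assuming the images were previously registered with `set_image_batch()` on the `predictor` instance; the coordinates are placeholders:

```python
import numpy as np

# One entry per image in the batch: a single foreground point for each.
masks_list, scores_list, logits_list = predictor.predict_batch(
    point_coords_batch=[np.array([[450, 600]]), np.array([[300, 200]])],
    point_labels_batch=[np.array([1]), np.array([1])],
)
```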
predict_video(self, prompts=None, point_crs=None, output_dir=None, img_ext='png')
¶
Predict masks for the video.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prompts | Dict[int, Any] | A dictionary containing the prompts with points and labels. | None |
| point_crs | Optional[str] | The coordinate reference system (CRS) of the point prompts. | None |
| output_dir | Optional[str] | The directory to save the output images. Defaults to None. | None |
| img_ext | str | The file extension for the output images. Defaults to "png". | 'png' |
Source code in samgeo/samgeo2.py
def predict_video(
self,
prompts: Dict[int, Any] = None,
point_crs: Optional[str] = None,
output_dir: Optional[str] = None,
img_ext: str = "png",
) -> None:
"""Predict masks for the video.
Args:
prompts (Dict[int, Any]): A dictionary containing the prompts with points and labels.
point_crs (Optional[str]): The coordinate reference system (CRS) of the point prompts.
output_dir (Optional[str]): The directory to save the output images. Defaults to None.
img_ext (str): The file extension for the output images. Defaults to "png".
"""
from PIL import Image
def save_image_from_dict(data, output_path="output_image.png"):
# Find the shape of the first array in the dictionary (assuming all arrays have the same shape)
array_shape = next(iter(data.values())).shape[1:]
# Initialize an empty array with the same shape as the arrays in the dictionary, filled with zeros
output_array = np.zeros(array_shape, dtype=np.uint8)
# Iterate over each key and array in the dictionary
for key, array in data.items():
# Assign the key value wherever the boolean array is True
output_array[array[0]] = key
# Convert the output array to a PIL image
image = Image.fromarray(output_array)
# Save the image
image.save(output_path)
if prompts is None:
if hasattr(self, "prompts"):
prompts = self.prompts
else:
raise ValueError("Please provide prompts.")
if point_crs is not None and self._tif_source is not None:
for prompt in prompts.values():
points = prompt.get("points", None)
if points is not None:
points = common.coords_to_xy(self._tif_source, points, point_crs)
prompt["points"] = points
box = prompt.get("box", None)
if box is not None:
box = common.bbox_to_xy(self._tif_source, box, point_crs)
prompt["box"] = box
prompts = self._convert_prompts(prompts)
predictor = self.predictor
inference_state = self.inference_state
for obj_id, prompt in prompts.items():
points = prompt.get("points", None)
labels = prompt.get("labels", None)
box = prompt.get("box", None)
frame_idx = prompt.get("frame_idx", None)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=obj_id,
points=points,
labels=labels,
box=box,
)
video_segments = {}
num_frames = self._num_images
num_digits = len(str(num_frames))
if output_dir is not None:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
if output_dir is not None:
output_path = os.path.join(
output_dir, f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
)
save_image_from_dict(video_segments[out_frame_idx], output_path)
self.video_segments = video_segments
# if output_dir is not None:
# self.save_video_segments(output_dir, img_ext)
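A video-prediction sketch, assuming the `tracker` instance from the earlier sketch after `set_video()` has been called; the prompt dictionary is keyed by object ID, and the pixel coordinates are placeholders:

```python
# Track one object from frame 0 through the rest of the video and write
# one mask image per frame.
prompts = {1: {"points": [[980, 620]], "labels": [1], "frame_idx": 0}}
tracker.predict_video(prompts=prompts, output_dir="masks")
```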
propagate_in_video(self, inference_state, start_frame_idx=None, max_frame_num_to_track=None, reverse=False)
¶
Propagate the inference state in video.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inference_state | Any | The current inference state. | required |
| start_frame_idx | Optional[int] | The starting frame index. Defaults to None. | None |
| max_frame_num_to_track | Optional[int] | The maximum number of frames to track. Defaults to None. | None |
| reverse | bool | Whether to propagate in reverse. Defaults to False. | False |

Returns:

| Type | Description |
|---|---|
| Any | The propagated inference state. |
Source code in samgeo/samgeo2.py
@torch.inference_mode()
def propagate_in_video(
self,
inference_state: Any,
start_frame_idx: Optional[int] = None,
max_frame_num_to_track: Optional[int] = None,
reverse: bool = False,
) -> Any:
"""Propagate the inference state in video.
Args:
inference_state (Any): The current inference state.
start_frame_idx (Optional[int]): The starting frame index. Defaults to None.
max_frame_num_to_track (Optional[int]): The maximum number of frames
to track. Defaults to None.
reverse (bool): Whether to propagate in reverse. Defaults to False.
Returns:
Any: The propagated inference state.
"""
return self.predictor.propagate_in_video(
inference_state,
start_frame_idx=start_frame_idx,
max_frame_num_to_track=max_frame_num_to_track,
reverse=reverse,
)
propagate_in_video_preflight(self, inference_state)
¶
Propagate the inference state in video preflight.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inference_state | Any | The current inference state. | required |

Returns:

| Type | Description |
|---|---|
| Any | The propagated inference state. |
Source code in samgeo/samgeo2.py
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state: Any) -> Any:
"""Propagate the inference state in video preflight.
Args:
inference_state (Any): The current inference state.
Returns:
Any: The propagated inference state.
"""
return self.predictor.propagate_in_video_preflight(inference_state)
raster_to_vector(self, raster, vector, simplify_tolerance=None, **kwargs)
¶
Convert a raster image file to a vector dataset.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| raster | str | The path to the raster image. | required |
| vector | str | The path to the output vector file. | required |
| simplify_tolerance | float | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |
Source code in samgeo/samgeo2.py
def raster_to_vector(self, raster, vector, simplify_tolerance=None, **kwargs):
"""Convert a raster image file to a vector dataset.
Args:
raster (str): The path to the raster image.
vector (str): The path to the output vector file.
simplify_tolerance (float, optional): The maximum allowed geometry displacement.
The higher this value, the smaller the number of vertices in the resulting geometry.
"""
common.raster_to_vector(
raster, vector, simplify_tolerance=simplify_tolerance, **kwargs
)
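A vectorization sketch; the file names are placeholders, and the GeoPackage output format is an assumption about what common.raster_to_vector() supports:

```python
# Convert the segmented raster to polygons, simplifying geometries slightly.
sam.raster_to_vector("objects.tif", "objects.gpkg", simplify_tolerance=0.5)
```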
reset_state(self, inference_state)
¶
Remove all input points or masks in all frames throughout the video.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inference_state | Any | The current inference state. | required |
Source code in samgeo/samgeo2.py
@torch.inference_mode()
def reset_state(self, inference_state: Any) -> None:
"""Remove all input points or masks in all frames throughout the video.
Args:
inference_state (Any): The current inference state.
"""
self.predictor.reset_state(inference_state)
save_masks(self, output=None, foreground=True, unique=True, erosion_kernel=None, mask_multiplier=255, min_size=0, max_size=None, **kwargs)
¶
Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output | str | The path to the output image. Defaults to None, saving the masks to SamGeo2.objects. | None |
| foreground | bool | Whether to generate the foreground mask. Defaults to True. | True |
| unique | bool | Whether to assign a unique value to each object. Defaults to True. | True |
| erosion_kernel | tuple | The erosion kernel for filtering object masks and extracting borders, such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None. | None |
| mask_multiplier | int | The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. | 255 |
| min_size | int | The minimum size of the object. Defaults to 0. | 0 |
| max_size | Optional[int] | The maximum size of the object. Defaults to None. | None |
| **kwargs | Any | Additional keyword arguments for common.array_to_image(). | {} |
Source code in samgeo/samgeo2.py
def save_masks(
self,
output: Optional[str] = None,
foreground: bool = True,
unique: bool = True,
erosion_kernel: Optional[Tuple[int, int]] = None,
mask_multiplier: int = 255,
min_size: int = 0,
max_size: Optional[int] = None,
**kwargs: Any,
) -> None:
"""Save the masks to the output path. The output is either a binary mask
or a mask of objects with unique values.
Args:
output (str, optional): The path to the output image. Defaults to
None, saving the masks to SamGeo2.objects.
foreground (bool, optional): Whether to generate the foreground mask.
Defaults to True.
unique (bool, optional): Whether to assign a unique value to each
object. Defaults to True.
erosion_kernel (tuple, optional): The erosion kernel for filtering
object masks and extracting borders,
such as (3, 3) or (5, 5). Set to None to disable it. Defaults to
None.
mask_multiplier (int, optional): The mask multiplier for the output
mask, which is usually a binary mask [0, 1]. You can use this
parameter to scale the mask to a larger range, for example
[0, 255]. Defaults to 255.
min_size (int, optional): The minimum size of the object. Defaults to 0.
max_size (int, optional): The maximum size of the object. Defaults to None.
**kwargs: Additional keyword arguments for common.array_to_image().
"""
if self.masks is None:
raise ValueError("No masks found. Please run generate() first.")
h, w, _ = self.image.shape
masks = self.masks
# Set output image data type based on the number of objects
if len(masks) < 255:
dtype = np.uint8
elif len(masks) < 65535:
dtype = np.uint16
else:
dtype = np.uint32
# Generate a mask of objects with unique values
if unique:
# Sort the masks by area in descending order
sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
# Create an output image with the same size as the input image
objects = np.zeros(
(
sorted_masks[0]["segmentation"].shape[0],
sorted_masks[0]["segmentation"].shape[1],
)
)
# Assign a unique value to each object
count = len(sorted_masks)
for index, ann in enumerate(sorted_masks):
m = ann["segmentation"]
if min_size > 0 and ann["area"] < min_size:
continue
if max_size is not None and ann["area"] > max_size:
continue
objects[m] = count - index
# Generate a binary mask
else:
if foreground: # Extract foreground objects only
resulting_mask = np.zeros((h, w), dtype=dtype)
else:
resulting_mask = np.ones((h, w), dtype=dtype)
resulting_borders = np.zeros((h, w), dtype=dtype)
for m in masks:
if min_size > 0 and m["area"] < min_size:
continue
if max_size is not None and m["area"] > max_size:
continue
mask = (m["segmentation"] > 0).astype(dtype)
resulting_mask += mask
# Apply erosion to the mask
if erosion_kernel is not None:
mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
mask_erode = (mask_erode > 0).astype(dtype)
edge_mask = mask - mask_erode
resulting_borders += edge_mask
resulting_mask = (resulting_mask > 0).astype(dtype)
resulting_borders = (resulting_borders > 0).astype(dtype)
objects = resulting_mask - resulting_borders
objects = objects * mask_multiplier
objects = objects.astype(dtype)
self.objects = objects
if output is not None: # Save the output image
common.array_to_image(self.objects, output, self.source, **kwargs)
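A sketch of re-exporting previously generated masks as a plain binary mask, assuming `generate()` has already been run on the `sam` instance:

```python
# Write a foreground/background mask scaled to [0, 255].
sam.save_masks("binary_mask.tif", unique=False, mask_multiplier=255)
```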
save_prediction(self, output, index=None, mask_multiplier=255, dtype='float32', vector=None, simplify_tolerance=None, **kwargs)
¶
Save the predicted mask to the output path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output | str | The path to the output image. | required |
| index | Optional[int] | The index of the mask to save. Defaults to None, which will save the mask with the highest score. | None |
| mask_multiplier | int | The mask multiplier for the output mask, which is usually a binary mask [0, 1]. | 255 |
| dtype | str | The data type of the output image. Defaults to "float32". | 'float32' |
| vector | Optional[str] | The path to the output vector file. Defaults to None. | None |
| simplify_tolerance | Optional[float] | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in samgeo/samgeo2.py
def save_prediction(
self,
output: str,
index: Optional[int] = None,
mask_multiplier: int = 255,
dtype: str = "float32",
vector: Optional[str] = None,
simplify_tolerance: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Save the predicted mask to the output path.
Args:
output (str): The path to the output image.
index (Optional[int], optional): The index of the mask to save.
Defaults to None, which will save the mask with the highest score.
mask_multiplier (int, optional): The mask multiplier for the output
mask, which is usually a binary mask [0, 1].
dtype (str, optional): The data type of the output image. Defaults
to "float32".
vector (Optional[str], optional): The path to the output vector file.
Defaults to None.
simplify_tolerance (Optional[float], optional): The maximum allowed
geometry displacement. The higher this value, the smaller the
number of vertices in the resulting geometry.
**kwargs (Any): Additional keyword arguments.
"""
if self.scores is None:
raise ValueError("No predictions found. Please run predict() first.")
if index is None:
index = self.scores.argmax(axis=0)
array = self.masks[index] * mask_multiplier
self.prediction = array
common.array_to_image(array, output, self.source, dtype=dtype, **kwargs)
if vector is not None:
common.raster_to_vector(
output, vector, simplify_tolerance=simplify_tolerance
)
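A sketch of saving the best-scoring predicted mask along with a vector version, assuming `predict()` has already been run on the `predictor` instance; the file names are placeholders:

```python
# Save the highest-scoring mask and vectorize it in one call.
predictor.save_prediction("mask.tif", vector="mask.gpkg")
```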
save_video_segments(self, output_dir, img_ext='png')
¶
Save the video segments to the output directory.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output_dir | str | The path to the output directory. | required |
| img_ext | str | The file extension for the output images. Defaults to "png". | 'png' |
Source code in samgeo/samgeo2.py
def save_video_segments(self, output_dir: str, img_ext: str = "png") -> None:
"""Save the video segments to the output directory.
Args:
output_dir (str): The path to the output directory.
img_ext (str): The file extension for the output images. Defaults to "png".
"""
from PIL import Image
def save_image_from_dict(
data, output_path="output_image.png", crs_source=None, **kwargs
):
# Find the shape of the first array in the dictionary (assuming all arrays have the same shape)
array_shape = next(iter(data.values())).shape[1:]
# Initialize an empty array with the same shape as the arrays in the dictionary, filled with zeros
output_array = np.zeros(array_shape, dtype=np.uint8)
# Iterate over each key and array in the dictionary
for key, array in data.items():
# Assign the key value wherever the boolean array is True
output_array[array[0]] = key
if crs_source is None:
# Convert the output array to a PIL image
image = Image.fromarray(output_array)
# Save the image
image.save(output_path)
else:
output_path = output_path.replace(".png", ".tif")
common.array_to_image(output_array, output_path, crs_source, **kwargs)
num_frames = len(self.video_segments)
num_digits = len(str(num_frames))
if hasattr(self, "_tif_source") and self._tif_source.endswith(".tif"):
crs_source = self._tif_source
filenames = self._tif_names
else:
crs_source = None
filenames = None
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Initialize the tqdm progress bar
for frame_idx, video_segment in tqdm(
self.video_segments.items(), desc="Rendering frames", total=num_frames
):
if filenames is None:
output_path = os.path.join(
output_dir, f"{str(frame_idx).zfill(num_digits)}.{img_ext}"
)
else:
output_path = os.path.join(output_dir, filenames[frame_idx])
save_image_from_dict(video_segment, output_path, crs_source)
save_video_segments_blended(self, output_dir, img_ext='png', alpha=0.6, dpi=200, frame_stride=1, output_video=None, fps=30)
¶
Save blended video segments to the output directory and optionally create a video.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output_dir | str | The directory to save the output images. | required |
| img_ext | str | The file extension for the output images. Defaults to "png". | 'png' |
| alpha | float | The alpha value for the blended masks. Defaults to 0.6. | 0.6 |
| dpi | int | The DPI (dots per inch) for the output images. Defaults to 200. | 200 |
| frame_stride | int | The stride for selecting frames to save. Defaults to 1. | 1 |
| output_video | Optional[str] | The path to the output video file. Defaults to None. | None |
| fps | int | The frames per second for the output video. Defaults to 30. | 30 |
Source code in samgeo/samgeo2.py
def save_video_segments_blended(
self,
output_dir: str,
img_ext: str = "png",
alpha: float = 0.6,
dpi: int = 200,
frame_stride: int = 1,
output_video: Optional[str] = None,
fps: int = 30,
) -> None:
"""Save blended video segments to the output directory and optionally create a video.
Args:
output_dir (str): The directory to save the output images.
img_ext (str): The file extension for the output images. Defaults to "png".
alpha (float): The alpha value for the blended masks. Defaults to 0.6.
dpi (int): The DPI (dots per inch) for the output images. Defaults to 200.
frame_stride (int): The stride for selecting frames to save. Defaults to 1.
output_video (Optional[str]): The path to the output video file. Defaults to None.
fps (int): The frames per second for the output video. Defaults to 30.
"""
from PIL import Image
def show_mask(mask, ax, obj_id=None, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], alpha])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
plt.close("all")
video_segments = self.video_segments
video_dir = self.video_path
frame_names = self._frame_names
num_frames = len(frame_names)
num_digits = len(str(num_frames))
# Initialize the tqdm progress bar
for out_frame_idx in tqdm(
range(0, len(frame_names), frame_stride), desc="Rendering frames"
):
image = Image.open(os.path.join(video_dir, frame_names[out_frame_idx]))
# Get original image dimensions
w, h = image.size
# Set DPI and calculate figure size based on the original image dimensions
figsize = (
w / dpi,
h / dpi,
)
figsize = (
figsize[0] * 1.3,
figsize[1] * 1.3,
)
# Create a figure with the exact size and DPI
fig = plt.figure(figsize=figsize, dpi=dpi)
# Disable axis to prevent whitespace
plt.axis("off")
# Display the original image
plt.imshow(image)
# Overlay masks for each object ID
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
# Save the figure with no borders or extra padding
filename = f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
filepath = os.path.join(output_dir, filename)
plt.savefig(filepath, dpi=dpi, pad_inches=0, bbox_inches="tight")
plt.close(fig)
if output_video is not None:
common.images_to_video(output_dir, output_video, fps=fps)
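A rendering sketch, assuming `predict_video()` has populated `video_segments` on the `tracker` instance; the output paths are placeholders:

```python
# Render blended overlays for every other frame and stitch them into a video.
tracker.save_video_segments_blended(
    "blended", alpha=0.5, frame_stride=2, output_video="segments.mp4", fps=25
)
```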
set_image(self, image)
¶
Set the input image as a numpy array.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| image | Union[str, np.ndarray, Image] | The input image as a path, a numpy array, or a PIL Image. | required |
Source code in samgeo/samgeo2.py
@torch.no_grad()
def set_image(
self,
image: Union[str, np.ndarray, Image],
) -> None:
"""Set the input image as a numpy array.
Args:
image (Union[str, np.ndarray, Image]): The input image as a path,
a numpy array, or an Image.
"""
if isinstance(image, str):
if image.startswith("http"):
image = common.download_file(image)
if not os.path.exists(image):
raise ValueError(f"Input path {image} does not exist.")
self.source = image
image = cv2.imread(image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
self.image = image
elif isinstance(image, np.ndarray) or isinstance(image, Image):
self.image = image
else:
raise ValueError(
"Input image must be a path, a numpy array, or a PIL Image."
)
self.predictor.set_image(image)
set_image_batch(self, image_list)
¶
Set a batch of images for prediction.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| image_list | List[Union[np.ndarray, str, Image]] | A list of images, which can be numpy arrays, file paths, or PIL images. | required |

Exceptions:

| Type | Description |
|---|---|
| ValueError | If an input image path does not exist or if the input image type is not supported. |
Source code in samgeo/samgeo2.py
@torch.no_grad()
def set_image_batch(
self,
image_list: List[Union[np.ndarray, str, Image]],
) -> None:
"""Set a batch of images for prediction.
Args:
image_list (List[Union[np.ndarray, str, Image]]): A list of images,
which can be numpy arrays, file paths, or PIL images.
Raises:
ValueError: If an input image path does not exist or if the input
image type is not supported.
"""
images = []
for image in image_list:
if isinstance(image, str):
if image.startswith("http"):
image = common.download_file(image)
if not os.path.exists(image):
raise ValueError(f"Input path {image} does not exist.")
image = cv2.imread(image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif isinstance(image, Image):
image = np.array(image)
elif isinstance(image, np.ndarray):
pass
else:
raise ValueError("Input image must be either a path or a numpy array.")
images.append(image)
self.predictor.set_image_batch(images)
set_video(self, video_path, output_dir=None, frame_rate=None, prefix='')
¶
Set the video path and parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| video_path | str | The path to the video file or a directory of frames. | required |
| output_dir | Optional[str] | The directory to save the extracted frames. Defaults to None, which creates a temporary directory. | None |
| frame_rate | Optional[int] | The frame rate for extracting frames. Defaults to None. | None |
| prefix | str | The filename prefix for the extracted frames. Defaults to "". | '' |
Source code in samgeo/samgeo2.py
def set_video(
self,
video_path: str,
output_dir: str = None,
frame_rate: Optional[int] = None,
prefix: str = "",
) -> None:
"""Set the video path and parameters.
Args:
video_path (str): The path to the video file.
output_dir (str, optional): The directory to save the extracted frames.
Defaults to None, which creates a temporary directory.
frame_rate (Optional[int], optional): The frame rate for extracting frames.
Defaults to None.
prefix (str, optional): The filename prefix for the extracted frames. Defaults to "".
"""
if isinstance(video_path, str):
if video_path.startswith("http"):
video_path = common.download_file(video_path)
if os.path.isfile(video_path):
if output_dir is None:
output_dir = common.make_temp_dir()
if not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Output directory: {output_dir}")
common.video_to_images(
video_path, output_dir, frame_rate=frame_rate, prefix=prefix
)
elif os.path.isdir(video_path):
files = sorted(os.listdir(video_path))
if len(files) == 0:
raise ValueError(f"No files found in {video_path}.")
elif files[0].endswith(".tif"):
self._tif_source = os.path.join(video_path, files[0])
self._tif_dir = video_path
self._tif_names = files
video_path = common.geotiff_to_jpg_batch(video_path)
output_dir = video_path
if not os.path.exists(video_path):
raise ValueError(f"Input path {video_path} does not exist.")
else:
raise ValueError("Input video_path must be a string.")
self.video_path = output_dir
self._num_images = len(os.listdir(output_dir))
self._frame_names = sorted(os.listdir(output_dir))
self.inference_state = self.predictor.init_state(video_path=output_dir)
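A sketch of registering video input before prompting, assuming the `tracker` instance from the earlier sketch; the paths are placeholders. A directory of GeoTIFF frames is also accepted, in which case the frames are converted to JPEG internally and georeferencing is kept for the outputs:

```python
# From a video file: frames are extracted into output_dir (or a temp dir).
tracker.set_video("timelapse.mp4", output_dir="frames")

# From a directory of GeoTIFF frames (e.g., a satellite time series).
tracker.set_video("tif_frames/")
```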
show_anns(self, figsize=(12, 10), axis='off', alpha=0.35, output=None, blend=True, **kwargs)
¶
Show the annotations (objects with random color) on the input image.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| figsize | tuple | The figure size. Defaults to (12, 10). | (12, 10) |
| axis | str | Whether to show the axis. Defaults to "off". | 'off' |
| alpha | float | The alpha value for the annotations. Defaults to 0.35. | 0.35 |
| output | str | The path to the output image. Defaults to None. | None |
| blend | bool | Whether to blend the annotations with the input image when saving. Defaults to True. | True |
Source code in samgeo/samgeo2.py
def show_anns(
self,
figsize: Tuple[int, int] = (12, 10),
axis: str = "off",
alpha: float = 0.35,
output: Optional[str] = None,
blend: bool = True,
**kwargs: Any,
) -> None:
"""Show the annotations (objects with random color) on the input image.
Args:
figsize (tuple, optional): The figure size. Defaults to (12, 10).
axis (str, optional): Whether to show the axis. Defaults to "off".
alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
output (str, optional): The path to the output image. Defaults to None.
blend (bool, optional): Whether to blend the annotations with the input
image when saving. Defaults to True.
"""
import matplotlib.pyplot as plt
anns = self.masks
if self.image is None:
print("Please run generate() first.")
return
if anns is None or len(anns) == 0:
return
plt.figure(figsize=figsize)
plt.imshow(self.image)
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
img = np.ones(
(
sorted_anns[0]["segmentation"].shape[0],
sorted_anns[0]["segmentation"].shape[1],
4,
)
)
img[:, :, 3] = 0
for ann in sorted_anns:
if hasattr(self, "_min_size") and (ann["area"] < self._min_size):
continue
if (
hasattr(self, "_max_size")
and isinstance(self._max_size, int)
and ann["area"] > self._max_size
):
continue
m = ann["segmentation"]
color_mask = np.concatenate([np.random.random(3), [alpha]])
img[m] = color_mask
ax.imshow(img)
if "dpi" not in kwargs:
kwargs["dpi"] = 100
if "bbox_inches" not in kwargs:
kwargs["bbox_inches"] = "tight"
plt.axis(axis)
self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)
if output is not None:
if blend:
array = common.blend_images(
self.annotations, self.image, alpha=alpha, show=False
)
else:
array = self.annotations
common.array_to_image(array, output, self.source)
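A display sketch, assuming `generate()` has already been run on the `sam` instance; the output path is a placeholder:

```python
# Overlay random-colored objects on the image and save a blended copy.
sam.show_anns(alpha=0.4, output="annotations.tif")
```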
show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)
¶
Show a canvas to collect foreground and background points.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| fg_color | Tuple[int, int, int] | The color for the foreground points. Defaults to (0, 255, 0). | (0, 255, 0) |
| bg_color | Tuple[int, int, int] | The color for the background points. Defaults to (0, 0, 255). | (0, 0, 255) |
| radius | int | The radius of the points. Defaults to 5. | 5 |

Returns:

| Type | Description |
|---|---|
| Tuple[list, list] | A tuple of two lists of foreground and background points. |
Source code in samgeo/samgeo2.py
def show_canvas(
self,
fg_color: Tuple[int, int, int] = (0, 255, 0),
bg_color: Tuple[int, int, int] = (0, 0, 255),
radius: int = 5,
) -> Tuple[list, list]:
"""Show a canvas to collect foreground and background points.
Args:
fg_color (Tuple[int, int, int], optional): The color for the foreground points.
Defaults to (0, 255, 0).
bg_color (Tuple[int, int, int], optional): The color for the background points.
Defaults to (0, 0, 255).
radius (int, optional): The radius of the points. Defaults to 5.
Returns:
Tuple[list, list]: A tuple of two lists of foreground and background points.
"""
if self.image is None:
raise ValueError("Please run set_image() first.")
image = self.image
fg_points, bg_points = common.show_canvas(image, fg_color, bg_color, radius)
self.fg_points = fg_points
self.bg_points = bg_points
point_coords = fg_points + bg_points
point_labels = [1] * len(fg_points) + [0] * len(bg_points)
self.point_coords = point_coords
self.point_labels = point_labels
return fg_points, bg_points
show_images(self, path=None)
¶
Show the images in the video.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str | The path to the images. Defaults to None. | None |
Source code in samgeo/samgeo2.py
def show_images(self, path: str = None) -> None:
"""Show the images in the video.
Args:
path (str, optional): The path to the images. Defaults to None.
"""
if path is None:
path = self.video_path
if path is not None:
common.show_image_gui(path)
show_map(self, basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)
¶
Show the interactive map.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| basemap | str | The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID. | 'SATELLITE' |
| repeat_mode | bool | Whether to use the repeat mode for draw control. Defaults to True. | True |
| out_dir | Optional[str] | The path to the output directory. Defaults to None. | None |

Returns:

| Type | Description |
|---|---|
| Any | The map object. |
Source code in samgeo/samgeo2.py
def show_map(
self,
basemap: str = "SATELLITE",
repeat_mode: bool = True,
out_dir: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Show the interactive map.
Args:
basemap (str, optional): The basemap. It can be one of the following:
SATELLITE, ROADMAP, TERRAIN, HYBRID.
repeat_mode (bool, optional): Whether to use the repeat mode for
draw control. Defaults to True.
out_dir (Optional[str], optional): The path to the output directory.
Defaults to None.
Returns:
Any: The map object.
"""
return common.sam_map_gui(
self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs
)
show_masks(self, figsize=(12, 10), cmap='binary_r', axis='off', foreground=True, **kwargs)
¶
Show the binary mask or the mask of objects with unique values.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| figsize | tuple | The figure size. Defaults to (12, 10). | (12, 10) |
| cmap | str | The colormap. Defaults to "binary_r". | 'binary_r' |
| axis | str | Whether to show the axis. Defaults to "off". | 'off' |
| foreground | bool | Whether to show the foreground mask only. Defaults to True. | True |
| **kwargs | Any | Other arguments for save_masks(). | {} |
Source code in samgeo/samgeo2.py
def show_masks(
self,
figsize: Tuple[int, int] = (12, 10),
cmap: str = "binary_r",
axis: str = "off",
foreground: bool = True,
**kwargs: Any,
) -> None:
"""Show the binary mask or the mask of objects with unique values.
Args:
figsize (tuple, optional): The figure size. Defaults to (12, 10).
cmap (str, optional): The colormap. Defaults to "binary_r".
axis (str, optional): Whether to show the axis. Defaults to "off".
foreground (bool, optional): Whether to show the foreground mask only.
Defaults to True.
**kwargs: Other arguments for save_masks().
"""
import matplotlib.pyplot as plt
if self.objects is None:
self.save_masks(foreground=foreground, **kwargs)
plt.figure(figsize=figsize)
plt.imshow(self.objects, cmap=cmap)
plt.axis(axis)
plt.show()
show_prompts(self, prompts, frame_idx=0, mask=None, random_color=False, point_crs=None, figsize=(9, 6))
¶
Show the prompts on the image.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prompts | Dict[int, Any] | A dictionary containing the prompts with points and labels. | required |
| frame_idx | int | The frame index. Defaults to 0. | 0 |
| mask | Any | The mask. Defaults to None. | None |
| random_color | bool | Whether to use random colors for the masks. Defaults to False. | False |
| point_crs | Optional[str] | The coordinate reference system (CRS) of the point prompts. Defaults to None. | None |
| figsize | Tuple[int, int] | The figure size. Defaults to (9, 6). | (9, 6) |
Source code in samgeo/samgeo2.py
def show_prompts(
self,
prompts: Dict[int, Any],
frame_idx: int = 0,
mask: Any = None,
random_color: bool = False,
point_crs: Optional[str] = None,
figsize: Tuple[int, int] = (9, 6),
) -> None:
"""Show the prompts on the image.
Args:
prompts (Dict[int, Any]): A dictionary containing the prompts with
points and labels.
frame_idx (int, optional): The frame index. Defaults to 0.
mask (Any, optional): The mask. Defaults to None.
random_color (bool, optional): Whether to use random colors for the
masks. Defaults to False.
point_crs (Optional[str], optional): The coordinate reference system
(CRS) of the point prompts. Defaults to None.
figsize (Tuple[int, int], optional): The figure size. Defaults to (9, 6).
"""
from PIL import Image
def show_mask(mask, ax, obj_id=None, random_color=random_color):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=200):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle(
(x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2
)
)
if point_crs is not None and self._tif_source is not None:
for prompt in prompts.values():
points = prompt.get("points", None)
if points is not None:
points = common.coords_to_xy(self._tif_source, points, point_crs)
prompt["points"] = points
box = prompt.get("box", None)
if box is not None:
box = common.bbox_to_xy(self._tif_source, box, point_crs)
prompt["box"] = box
prompts = self._convert_prompts(prompts)
self.prompts = prompts
video_dir = self.video_path
frame_names = self._frame_names
fig = plt.figure(figsize=figsize)
fig.canvas.toolbar_visible = True
fig.canvas.header_visible = False
fig.canvas.footer_visible = True
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
for obj_id, prompt in prompts.items():
points = prompt.get("points", None)
labels = prompt.get("labels", None)
box = prompt.get("box", None)
anno_frame_idx = prompt.get("frame_idx", None)
if anno_frame_idx == frame_idx:
if points is not None:
show_points(points, labels, plt.gca())
if box is not None:
show_box(box, plt.gca())
if mask is not None:
show_mask(mask, plt.gca(), obj_id=obj_id)
plt.show()
tensor_to_numpy(self, index=None, output=None, mask_multiplier=255, dtype='uint8', save_args=None)
¶
Convert the predicted masks from tensors to numpy arrays.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| index | Optional[int] | The index of the mask to save. Defaults to None, which will save the mask with the highest score. | None |
| output | Optional[str] | The path to the output image. Defaults to None. | None |
| mask_multiplier | int | The mask multiplier for the output mask, which is usually a binary mask [0, 1]. | 255 |
| dtype | str | The data type of the output image. Defaults to "uint8". | 'uint8' |
| save_args | Optional[Dict[str, Any]] | Optional arguments for saving the output image. Defaults to None. | None |

Returns:

| Type | Description |
|---|---|
| Optional[np.ndarray] | The predicted mask as a numpy array, or None if output is specified. |
Source code in samgeo/samgeo2.py
def tensor_to_numpy(
self,
index: Optional[int] = None,
output: Optional[str] = None,
mask_multiplier: int = 255,
dtype: str = "uint8",
save_args: Optional[Dict[str, Any]] = None,
) -> Optional[np.ndarray]:
"""Convert the predicted masks from tensors to numpy arrays.
Args:
index (Optional[int], optional): The index of the mask to save.
Defaults to None, which will save the mask with the highest score.
output (Optional[str], optional): The path to the output image.
Defaults to None.
mask_multiplier (int, optional): The mask multiplier for the output
mask, which is usually a binary mask [0, 1].
dtype (str, optional): The data type of the output image. Defaults
to "uint8".
save_args (Optional[Dict[str, Any]], optional): Optional arguments
for saving the output image. Defaults to None.
Returns:
Optional[np.ndarray]: The predicted mask as a numpy array, or None
if output is specified.
"""
if save_args is None:
save_args = {}
boxes = self.boxes
masks = self.masks
image_pil = self.image
image_np = np.array(image_pil)
if index is None:
index = 1
masks = masks[:, index, :, :]
masks = masks.squeeze(1)
if boxes is None or (len(boxes) == 0): # No "object" instances found
print("No objects found in the image.")
return
else:
# Create an empty image to store the mask overlays
mask_overlay = np.zeros_like(
image_np[..., 0], dtype=dtype
) # Adjusted for single channel
for i, (_, mask) in enumerate(zip(boxes, masks)):
# Convert tensor to numpy array if necessary and ensure it contains integers
if isinstance(mask, torch.Tensor):
mask = (
mask.cpu().numpy().astype(dtype)
) # If mask is on GPU, use .cpu() before .numpy()
mask_overlay += ((mask > 0) * (i + 1)).astype(
dtype
) # Assign a unique value for each mask
# Normalize mask_overlay to be in [0, 255]
mask_overlay = (
mask_overlay > 0
) * mask_multiplier # Binary mask in [0, 255]
if output is not None:
common.array_to_image(
mask_overlay, output, self.source, dtype=dtype, **save_args
)
else:
return mask_overlay