fast_sam module
Segmenting remote sensing images with the Fast Segment Anything Model (FastSAM, https://github.com/opengeos/FastSAM).
SamGeo (FastSAM)
Segmenting remote sensing images with the Fast Segment Anything Model (FastSAM).
Source code in samgeo/fast_sam.py
class SamGeo(FastSAM):
    """Segmenting remote sensing images with the Fast Segment Anything Model (FastSAM).

    On top of ``FastSAM`` this class downloads checkpoints on demand, keeps the
    annotations of the last prompt, exports masks through ``array_to_image`` and
    converts rasters to vector files.
    """

    def __init__(self, model="FastSAM-x.pt", **kwargs):
        """Initialize the FastSAM algorithm.

        Args:
            model (str, optional): Checkpoint name, either "FastSAM-x.pt" or
                "FastSAM-s.pt". Defaults to "FastSAM-x.pt".
            kwargs: ``checkpoint_dir`` (str) selects the directory holding the
                checkpoint (defaults to ``$TORCH_HOME`` or, when unset,
                ``~/.cache/torch/hub/checkpoints``). All other keyword arguments
                are forwarded to ``FastSAM``.

        Raises:
            ValueError: If ``model`` is not one of the known checkpoint names.
        """
        # An explicit checkpoint_dir=None falls back to the default location.
        checkpoint_dir = kwargs.pop("checkpoint_dir", None)
        if checkpoint_dir is None:
            checkpoint_dir = os.environ.get(
                "TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints")
            )

        models = {
            "FastSAM-x.pt": "https://github.com/opengeos/datasets/releases/download/models/FastSAM-x.pt",
            "FastSAM-s.pt": "https://github.com/opengeos/datasets/releases/download/models/FastSAM-s.pt",
        }

        if model not in models:
            raise ValueError(
                f"Model must be one of {list(models.keys())}, but got {model} instead."
            )

        model_path = os.path.join(checkpoint_dir, model)
        if not os.path.exists(model_path):
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            print(f"Downloading {model} to {model_path}...")
            download_file(models[model], model_path)

        # Load the checkpoint that was just located/downloaded. Passing only the
        # bare file name would leave FastSAM to resolve it on its own and ignore
        # the file in checkpoint_dir.
        super().__init__(model_path, **kwargs)

    def set_image(self, image, device=None, **kwargs):
        """Set the input image and run FastSAM on it.

        Args:
            image (str): The path to the image file or a HTTP URL (downloaded
                first). Non-string inputs are passed to the model unchanged and
                no source path is recorded.
            device (str, optional): The device to use. Defaults to "cuda" if
                available, otherwise "cpu".
            kwargs: Additional keyword arguments to pass to the FastSAM model.

        Raises:
            ValueError: If ``image`` is a path that does not exist.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                image = download_file(image)

            if not os.path.exists(image):
                raise ValueError(f"Input path {image} does not exist.")

            # Reused by save_masks as the source passed to array_to_image.
            self.source = image
        else:
            self.source = None

        # Use cuda if available
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        if device == "cuda":
            torch.cuda.empty_cache()

        everything_results = self(image, device=device, **kwargs)

        self.prompt_process = FastSAMPrompt(image, everything_results, device=device)

    def everything_prompt(self, output=None, **kwargs):
        """Segment the image with the everything prompt. Adapted from
        https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L451

        Args:
            output (str, optional): The path to save the output image. Defaults to None.
            kwargs: Additional keyword arguments passed to ``save_masks``.

        Returns:
            The prompt annotations when ``output`` is None, otherwise None.
        """
        prompt_process = self.prompt_process
        ann = prompt_process.everything_prompt()
        self.annotations = ann

        if output is not None:
            self.save_masks(output, **kwargs)
        else:
            return ann

    def point_prompt(self, points, pointlabel, output=None, **kwargs):
        """Segment the image with the point prompt. Adapted from
        https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L414

        Args:
            points (list): A list of points.
            pointlabel (list): A list of labels for each point.
            output (str, optional): The path to save the output image. Defaults to None.
            kwargs: Additional keyword arguments passed to ``save_masks``.

        Returns:
            The prompt annotations when ``output`` is None, otherwise None.
        """
        prompt_process = self.prompt_process
        ann = prompt_process.point_prompt(points, pointlabel)
        self.annotations = ann

        if output is not None:
            self.save_masks(output, **kwargs)
        else:
            return ann

    def box_prompt(self, bbox=None, bboxes=None, output=None, **kwargs):
        """Segment the image with the box prompt. Adapted from
        https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L377

        Args:
            bbox (list, optional): The bounding box. Defaults to None.
            bboxes (list, optional): A list of bounding boxes. Defaults to None.
            output (str, optional): The path to save the output image. Defaults to None.
            kwargs: Additional keyword arguments passed to ``save_masks``.

        Returns:
            The prompt annotations when ``output`` is None, otherwise None.
        """
        prompt_process = self.prompt_process
        ann = prompt_process.box_prompt(bbox, bboxes)
        self.annotations = ann

        if output is not None:
            self.save_masks(output, **kwargs)
        else:
            return ann

    def text_prompt(self, text, output=None, **kwargs):
        """Segment the image with the text prompt. Adapted from
        https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L439

        Args:
            text (str): The text to segment.
            output (str, optional): The path to save the output image. Defaults to None.
            kwargs: Additional keyword arguments passed to ``save_masks``.

        Returns:
            The prompt annotations when ``output`` is None, otherwise None.
        """
        prompt_process = self.prompt_process
        ann = prompt_process.text_prompt(text)
        self.annotations = ann

        if output is not None:
            self.save_masks(output, **kwargs)
        else:
            return ann

    def save_masks(
        self,
        output=None,
        better_quality=True,
        dtype=None,
        mask_multiplier=255,
        **kwargs,
    ) -> np.ndarray:
        """Merge the current annotations into one binary mask and save or return it.

        Adapted from
        https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L222

        Args:
            output (str, optional): Path of the raster to write with
                ``array_to_image``. If None, the mask is returned instead.
                Defaults to None.
            better_quality (bool, optional): Smooth every mask with a 3x3
                morphological close followed by an 8x8 open. Defaults to True.
            dtype (np.dtype, optional): Output data type. If None it is chosen
                from the number of objects. Defaults to None.
            mask_multiplier (int, optional): Value given to foreground pixels.
                Defaults to 255.
            kwargs: Additional keyword arguments passed to ``array_to_image``.

        Returns:
            np.ndarray: The mask of the image when ``output`` is None, else None.
        """
        annotations = self.annotations
        if len(annotations) > 0 and isinstance(annotations[0], dict):
            annotations = [annotation["segmentation"] for annotation in annotations]

        # Only the image size is needed to bring the merged mask back to full size.
        height, width = self.prompt_process.img.shape[:2]

        num_objects = len(annotations)
        if dtype is None:
            # Set output image data type based on the number of objects
            if num_objects < 255:
                dtype = np.uint8
            elif num_objects < 65535:
                dtype = np.uint16
            else:
                dtype = np.uint32

        if num_objects == 0:
            # Nothing was segmented: produce an all-background mask.
            merged = np.zeros((height, width), dtype=np.uint8)
        else:
            # Normalise to one (N, H, W) NumPy array whether the prompt returned
            # a tensor, an ndarray, or a list of tensors/arrays.
            if isinstance(annotations, torch.Tensor):
                stack = annotations.cpu().numpy()
            else:
                stack = np.stack(
                    [
                        mask.cpu().numpy()
                        if isinstance(mask, torch.Tensor)
                        else np.asarray(mask)
                        for mask in annotations
                    ]
                )

            if better_quality:
                close_kernel = np.ones((3, 3), np.uint8)
                open_kernel = np.ones((8, 8), np.uint8)
                stack = np.stack(
                    [
                        cv2.morphologyEx(
                            cv2.morphologyEx(
                                mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel
                            ),
                            cv2.MORPH_OPEN,
                            open_kernel,
                        )
                        for mask in stack
                    ]
                )

            # Union of all objects. Thresholding before a nearest-neighbour resize
            # gives the same result as thresholding after it, and resizes a
            # compact uint8 array instead of the raw sum.
            merged = (np.sum(stack, axis=0) > 0).astype(np.uint8)
            merged = cv2.resize(
                merged, (width, height), interpolation=cv2.INTER_NEAREST
            )

        masks = merged.astype(dtype) * mask_multiplier
        self.objects = masks

        if output is not None:  # Save the output image
            array_to_image(self.objects, output, self.source, **kwargs)
        else:
            return masks

    def fast_show_mask(
        self,
        random_color=False,
    ):
        """Render the current annotations as an RGBA overlay.

        Adapted from
        https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L222

        Args:
            random_color (bool, optional): Whether to use random colors for each
                object. Defaults to False (a single blue color).

        Returns:
            np.ndarray: A (height, width, 4) float array at the input image size.
                Pixels not covered by any mask are all zeros (transparent).
        """
        image = self.prompt_process.img
        target_height, target_width = image.shape[:2]

        annotations = self.annotations
        if isinstance(annotations, torch.Tensor):
            annotation = annotations.cpu().numpy()
        else:
            # Prompts may return NumPy arrays instead of tensors.
            annotation = np.asarray(annotations)

        if annotation.shape[0] == 0:
            return np.zeros((target_height, target_width, 4))

        mask_count, height, width = annotation.shape

        # Sort by ascending area so argmax below selects the smallest mask that
        # covers each pixel, i.e. small objects are drawn over large ones.
        areas = np.sum(annotation, axis=(1, 2))
        annotation = annotation[np.argsort(areas)]
        index = (annotation != 0).argmax(axis=0)

        if random_color:
            color = np.random.random((mask_count, 1, 1, 3))
        else:
            color = np.ones((mask_count, 1, 1, 3)) * np.array(
                [30 / 255, 144 / 255, 255 / 255]
            )
        transparency = np.ones((mask_count, 1, 1, 1)) * 0.6
        visual = np.concatenate([color, transparency], axis=-1)
        mask_image = np.expand_dims(annotation, -1) * visual

        # Pick, per pixel, the RGBA value of the selected mask.
        h_indices, w_indices = np.meshgrid(
            np.arange(height), np.arange(width), indexing="ij"
        )
        show = mask_image[index, h_indices, w_indices]

        return cv2.resize(
            show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )

    def raster_to_vector(
        self, image, output, simplify_tolerance=None, dst_crs="EPSG:4326", **kwargs
    ):
        """Save the result to a vector file.

        Args:
            image (str): The path to the image file.
            output (str): The path to the vector file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
            dst_crs (str, optional): Target coordinate reference system passed to
                the module-level ``raster_to_vector``. Defaults to "EPSG:4326".
            kwargs: Additional keyword arguments passed to ``raster_to_vector``.
        """
        # Inside a method this name resolves to the module-level helper.
        raster_to_vector(
            image,
            output,
            simplify_tolerance=simplify_tolerance,
            dst_crs=dst_crs,
            **kwargs,
        )

    def show_anns(
        self,
        output=None,
        **kwargs,
    ):
        """Show the annotations (objects with random color) on the input image.

        Args:
            figsize (tuple, optional): The figure size. Defaults to (12, 10).
            axis (str, optional): Whether to show the axis. Defaults to "off".
            alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
            output (str, optional): The path to the output image. Defaults to None,
                in which case a temporary PNG file is used.
            blend (bool, optional): Whether to show the input image. Defaults to True.
        """
        annotations = self.annotations
        prompt_process = self.prompt_process

        if output is None:
            output = temp_file_path(".png")

        prompt_process.plot(annotations, output, **kwargs)

        show_image(output)
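__init__(self, model='FastSAM-x.pt', **kwargs) (special method)
Initialize the FastSAM algorithm. `model` must be `FastSAM-x.pt` or `FastSAM-s.pt`. If the checkpoint is not present in `checkpoint_dir` (a keyword argument that defaults to `$TORCH_HOME`, or `~/.cache/torch/hub/checkpoints` when that variable is unset), it is downloaded there first.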
Source code in samgeo/fast_sam.py
def __init__(self, model="FastSAM-x.pt", **kwargs):
    """Initialize the FastSAM algorithm.

    Args:
        model (str, optional): Checkpoint name, either "FastSAM-x.pt" or
            "FastSAM-s.pt". Defaults to "FastSAM-x.pt".
        kwargs: ``checkpoint_dir`` (str) selects the directory holding the
            checkpoint (defaults to ``$TORCH_HOME`` or, when unset,
            ``~/.cache/torch/hub/checkpoints``). All other keyword arguments
            are forwarded to ``FastSAM``.

    Raises:
        ValueError: If ``model`` is not one of the known checkpoint names.
    """
    # An explicit checkpoint_dir=None falls back to the default location.
    checkpoint_dir = kwargs.pop("checkpoint_dir", None)
    if checkpoint_dir is None:
        checkpoint_dir = os.environ.get(
            "TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints")
        )

    models = {
        "FastSAM-x.pt": "https://github.com/opengeos/datasets/releases/download/models/FastSAM-x.pt",
        "FastSAM-s.pt": "https://github.com/opengeos/datasets/releases/download/models/FastSAM-s.pt",
    }

    if model not in models:
        raise ValueError(
            f"Model must be one of {list(models.keys())}, but got {model} instead."
        )

    model_path = os.path.join(checkpoint_dir, model)
    if not os.path.exists(model_path):
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Downloading {model} to {model_path}...")
        download_file(models[model], model_path)

    # Load the checkpoint that was just located/downloaded. Passing only the
    # bare file name would leave FastSAM to resolve it on its own and ignore
    # the file in checkpoint_dir.
    super().__init__(model_path, **kwargs)
box_prompt(self, bbox=None, bboxes=None, output=None, **kwargs)
Segment the image with the box prompt. Adapted from https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L377
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bbox | list | The bounding box. | None |
| bboxes | list | A list of bounding boxes. | None |
| output | str | The path to save the output image. If omitted, the annotations are returned instead. | None |
| kwargs | | Additional keyword arguments passed to `save_masks` when `output` is given. | {} |
Source code in samgeo/fast_sam.py
def box_prompt(self, bbox=None, bboxes=None, output=None, **kwargs):
    """Run a bounding-box prompt on the current image.

    Adapted from
    https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L377

    Args:
        bbox (list, optional): A single bounding box. Defaults to None.
        bboxes (list, optional): Several bounding boxes. Defaults to None.
        output (str, optional): Where to save the resulting mask via
            ``save_masks``. Defaults to None.
        kwargs: Extra keyword arguments forwarded to ``save_masks``.

    Returns:
        The prompt annotations when ``output`` is None; otherwise None.
    """
    result = self.prompt_process.box_prompt(bbox, bboxes)
    self.annotations = result

    if output is None:
        return result
    self.save_masks(output, **kwargs)
everything_prompt(self, output=None, **kwargs)
Segment the image with the everything prompt. Adapted from https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L451
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output | str | The path to save the output image. If omitted, the annotations are returned instead. | None |
| kwargs | | Additional keyword arguments passed to `save_masks` when `output` is given. | {} |
Source code in samgeo/fast_sam.py
def everything_prompt(self, output=None, **kwargs):
    """Segment every object in the current image.

    Adapted from
    https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L451

    Args:
        output (str, optional): Where to save the resulting mask via
            ``save_masks``. Defaults to None.
        kwargs: Extra keyword arguments forwarded to ``save_masks``.

    Returns:
        The prompt annotations when ``output`` is None; otherwise None.
    """
    self.annotations = result = self.prompt_process.everything_prompt()

    if output is None:
        return result
    self.save_masks(output, **kwargs)
fast_show_mask(self, random_color=False)
Render the current annotations as a colored RGBA overlay at the size of the input image. Adapted from https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L222
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| random_color | bool | Whether to use random colors for each object (otherwise a single blue color is used). | False |

Returns:

| Type | Description |
|---|---|
| np.ndarray | The RGBA mask overlay (height × width × 4) of the image. |
Source code in samgeo/fast_sam.py
def fast_show_mask(
    self,
    random_color=False,
):
    """Render the current annotations as an RGBA overlay.

    Adapted from
    https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L222

    Args:
        random_color (bool, optional): Whether to use random colors for each
            object. Defaults to False (a single blue color).

    Returns:
        np.ndarray: A (height, width, 4) float array at the input image size.
            Pixels not covered by any mask are all zeros (transparent).
    """
    image = self.prompt_process.img
    target_height, target_width = image.shape[:2]

    annotations = self.annotations
    if isinstance(annotations, torch.Tensor):
        annotation = annotations.cpu().numpy()
    else:
        # Prompts may return NumPy arrays instead of tensors.
        annotation = np.asarray(annotations)

    if annotation.shape[0] == 0:
        return np.zeros((target_height, target_width, 4))

    mask_count, height, width = annotation.shape

    # Sort by ascending area so argmax below selects the smallest mask that
    # covers each pixel, i.e. small objects are drawn over large ones.
    areas = np.sum(annotation, axis=(1, 2))
    annotation = annotation[np.argsort(areas)]
    index = (annotation != 0).argmax(axis=0)

    if random_color:
        color = np.random.random((mask_count, 1, 1, 3))
    else:
        color = np.ones((mask_count, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    transparency = np.ones((mask_count, 1, 1, 1)) * 0.6
    visual = np.concatenate([color, transparency], axis=-1)
    mask_image = np.expand_dims(annotation, -1) * visual

    # Pick, per pixel, the RGBA value of the selected mask.
    h_indices, w_indices = np.meshgrid(
        np.arange(height), np.arange(width), indexing="ij"
    )
    show = mask_image[index, h_indices, w_indices]

    return cv2.resize(
        show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
    )
point_prompt(self, points, pointlabel, output=None, **kwargs)
Segment the image with the point prompt. Adapted from https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L414
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| points | list | A list of points. | required |
| pointlabel | list | A list of labels for each point. | required |
| output | str | The path to save the output image. If omitted, the annotations are returned instead. | None |
| kwargs | | Additional keyword arguments passed to `save_masks` when `output` is given. | {} |
Source code in samgeo/fast_sam.py
def point_prompt(self, points, pointlabel, output=None, **kwargs):
    """Run a point prompt on the current image.

    Adapted from
    https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L414

    Args:
        points (list): The prompt points.
        pointlabel (list): One label per point.
        output (str, optional): Where to save the resulting mask via
            ``save_masks``. Defaults to None.
        kwargs: Extra keyword arguments forwarded to ``save_masks``.

    Returns:
        The prompt annotations when ``output`` is None; otherwise None.
    """
    segmented = self.prompt_process.point_prompt(points, pointlabel)
    self.annotations = segmented
    if output is not None:
        self.save_masks(output, **kwargs)
        return None
    return segmented
raster_to_vector(self, image, output, simplify_tolerance=None, dst_crs='EPSG:4326', **kwargs)
Save the result to a vector file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| image | str | The path to the image file. | required |
| output | str | The path to the vector file. | required |
| simplify_tolerance | float | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |
| dst_crs | str | Target coordinate reference system passed to the module-level `raster_to_vector`. | 'EPSG:4326' |
| kwargs | | Additional keyword arguments passed to the module-level `raster_to_vector`. | {} |
Source code in samgeo/fast_sam.py
def raster_to_vector(
    self, image, output, simplify_tolerance=None, dst_crs="EPSG:4326", **kwargs
):
    """Convert a raster file to a vector file.

    Args:
        image (str): Path of the input raster.
        output (str): Path of the vector file to create.
        simplify_tolerance (float, optional): Maximum allowed geometry
            displacement; larger values yield fewer vertices. Defaults to None.
        dst_crs (str, optional): Target coordinate reference system.
            Defaults to "EPSG:4326".
        kwargs: Extra keyword arguments forwarded to ``raster_to_vector``.
    """
    # The named parameters can never also appear in kwargs, so the merge
    # below is conflict-free.
    options = {"simplify_tolerance": simplify_tolerance, "dst_crs": dst_crs}
    options.update(kwargs)
    # Inside the class this name resolves to the module-level helper.
    raster_to_vector(image, output, **options)
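save_masks(self, output=None, better_quality=True, dtype=None, mask_multiplier=255, **kwargs)
Merge the current annotations into a single binary mask and either save it to `output` or return it. Adapted from https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L222
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output | str | Path of the output raster, written with `array_to_image` (the input image path is passed as its source). If omitted, the mask is returned instead. | None |
| better_quality | bool | Smooth each mask with a 3×3 morphological close followed by an 8×8 open. | True |
| dtype | numpy dtype | Output data type. If omitted, it is chosen from the number of objects (uint8, uint16 or uint32). | None |
| mask_multiplier | int | Value assigned to foreground pixels. | 255 |
| kwargs | | Additional keyword arguments passed to `array_to_image`. | {} |

Returns:

| Type | Description |
|---|---|
| np.ndarray | The mask of the image (returned only when `output` is omitted). |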
Source code in samgeo/fast_sam.py
def save_masks(
    self,
    output=None,
    better_quality=True,
    dtype=None,
    mask_multiplier=255,
    **kwargs,
) -> np.ndarray:
    """Merge the current annotations into one binary mask and save or return it.

    Adapted from
    https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L222

    Args:
        output (str, optional): Path of the raster to write with
            ``array_to_image``. If None, the mask is returned instead.
            Defaults to None.
        better_quality (bool, optional): Smooth every mask with a 3x3
            morphological close followed by an 8x8 open. Defaults to True.
        dtype (np.dtype, optional): Output data type. If None it is chosen
            from the number of objects. Defaults to None.
        mask_multiplier (int, optional): Value given to foreground pixels.
            Defaults to 255.
        kwargs: Additional keyword arguments passed to ``array_to_image``.

    Returns:
        np.ndarray: The mask of the image when ``output`` is None, else None.
    """
    annotations = self.annotations
    if len(annotations) > 0 and isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]

    # Only the image size is needed to bring the merged mask back to full size.
    height, width = self.prompt_process.img.shape[:2]

    num_objects = len(annotations)
    if dtype is None:
        # Set output image data type based on the number of objects
        if num_objects < 255:
            dtype = np.uint8
        elif num_objects < 65535:
            dtype = np.uint16
        else:
            dtype = np.uint32

    if num_objects == 0:
        # Nothing was segmented: produce an all-background mask.
        merged = np.zeros((height, width), dtype=np.uint8)
    else:
        # Normalise to one (N, H, W) NumPy array whether the prompt returned
        # a tensor, an ndarray, or a list of tensors/arrays.
        if isinstance(annotations, torch.Tensor):
            stack = annotations.cpu().numpy()
        else:
            stack = np.stack(
                [
                    mask.cpu().numpy()
                    if isinstance(mask, torch.Tensor)
                    else np.asarray(mask)
                    for mask in annotations
                ]
            )

        if better_quality:
            close_kernel = np.ones((3, 3), np.uint8)
            open_kernel = np.ones((8, 8), np.uint8)
            stack = np.stack(
                [
                    cv2.morphologyEx(
                        cv2.morphologyEx(
                            mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel
                        ),
                        cv2.MORPH_OPEN,
                        open_kernel,
                    )
                    for mask in stack
                ]
            )

        # Union of all objects. Thresholding before a nearest-neighbour resize
        # gives the same result as thresholding after it, and resizes a
        # compact uint8 array instead of the raw sum.
        merged = (np.sum(stack, axis=0) > 0).astype(np.uint8)
        merged = cv2.resize(merged, (width, height), interpolation=cv2.INTER_NEAREST)

    masks = merged.astype(dtype) * mask_multiplier
    self.objects = masks

    if output is not None:  # Save the output image
        array_to_image(self.objects, output, self.source, **kwargs)
    else:
        return masks
set_image(self, image, device=None, **kwargs)
Set the input image and run FastSAM on it so that the prompt methods can be used.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| image | str | The path to the image file or an HTTP URL (downloaded first). Non-string inputs are passed to the model unchanged. | required |
| device | str | The device to use. Defaults to "cuda" if available, otherwise "cpu". | None |
| kwargs | | Additional keyword arguments to pass to the FastSAM model. | {} |
Source code in samgeo/fast_sam.py
def set_image(self, image, device=None, **kwargs):
    """Load an image and run FastSAM on it, preparing the prompt processor.

    Args:
        image (str): Local path or HTTP URL of the image. A URL is downloaded
            first. Non-string inputs go to the model as-is and no source path
            is recorded.
        device (str, optional): Device for inference. When None, "cuda" is
            used if available, otherwise "cpu".
        kwargs: Extra keyword arguments forwarded to the FastSAM model call.

    Raises:
        ValueError: If ``image`` is a path that does not exist.
    """
    if not isinstance(image, str):
        self.source = None
    else:
        if image.startswith("http"):
            image = download_file(image)
        if not os.path.exists(image):
            raise ValueError(f"Input path {image} does not exist.")
        self.source = image

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        # Free unused cached CUDA memory before running the model.
        torch.cuda.empty_cache()

    results = self(image, device=device, **kwargs)
    self.prompt_process = FastSAMPrompt(image, results, device=device)
show_anns(self, output=None, **kwargs)
Show the annotations (objects with random color) on the input image. Apart from `output`, the options below are keyword arguments forwarded to the prompt processor's `plot` method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| figsize | tuple | The figure size. | (12, 10) |
| axis | str | Whether to show the axis. | "off" |
| alpha | float | The alpha value for the annotations. | 0.35 |
| output | str | The path to the output image. If omitted, a temporary PNG file is used. | None |
| blend | bool | Whether to show the input image. | True |
Source code in samgeo/fast_sam.py
def show_anns(
    self,
    output=None,
    **kwargs,
):
    """Plot the current annotations over the input image and display the result.

    Args:
        output (str, optional): Image file to plot into. When None, a
            temporary ".png" path is used. Defaults to None.
        kwargs: Plot options forwarded to ``prompt_process.plot`` (e.g.
            figsize, axis, alpha, blend).
    """
    annotations, prompt_process = self.annotations, self.prompt_process
    target = output if output is not None else temp_file_path(".png")
    prompt_process.plot(annotations, target, **kwargs)
    show_image(target)
text_prompt(self, text, output=None, **kwargs)
Segment the image with the text prompt. Adapted from https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L439
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| text | str | The text to segment. | required |
| output | str | The path to save the output image. If omitted, the annotations are returned instead. | None |
| kwargs | | Additional keyword arguments passed to `save_masks` when `output` is given. | {} |
Source code in samgeo/fast_sam.py
def text_prompt(self, text, output=None, **kwargs):
    """Run a text prompt on the current image.

    Adapted from
    https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/fastsam/prompt.py#L439

    Args:
        text (str): Description of the object(s) to segment.
        output (str, optional): Where to save the resulting mask via
            ``save_masks``. Defaults to None.
        kwargs: Extra keyword arguments forwarded to ``save_masks``.

    Returns:
        The prompt annotations when ``output`` is None; otherwise None.
    """
    processor = self.prompt_process
    found = processor.text_prompt(text)
    self.annotations = found
    if output is not None:
        self.save_masks(output, **kwargs)
        return None
    return found