Segment Anything Model for Geospatial Data
This notebook shows how to segment satellite imagery with the Segment Anything Model (SAM) in just a few lines of code.
Make sure you use a GPU runtime for this notebook. In Google Colab, go to Runtime -> Change runtime type and select GPU as the hardware accelerator.
Install dependencies
Uncomment and run the following cell to install the required dependencies.
# %pip install segment-geospatial
Import libraries
import os
import leafmap
from samgeo import SamGeoPredictor, tms_to_geotiff, get_basemaps
from segment_anything import sam_model_registry
Create an interactive map
zoom = 16
m = leafmap.Map(center=[45, -123], zoom=zoom)
m.add_basemap("SATELLITE")
m
Pan and zoom the map to select the area of interest, then use the draw tools to draw a polygon or rectangle on the map. If nothing is drawn, a default bounding box is used.
# Use the drawn region if there is one; bounds are [minx, miny, maxx, maxy] in degrees.
if m.user_roi_bounds() is not None:
    bbox = m.user_roi_bounds()
else:
    bbox = [-123.0127, 44.9957, -122.9874, 45.0045]
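Download map tiles
Set the output file name, then download map tiles and mosaic them into a single GeoTIFF file. To segment your own image instead, set image to its path and skip the tms_to_geotiff step, which would otherwise overwrite the file.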
image = "satellite.tif"
# image = '/path/to/your/own/image.tif'
Besides the Satellite basemap, you can use any of the basemaps returned by the get_basemaps() function; uncomment the next cell to list their names.
# get_basemaps().keys()
Download the tiles, passing the basemap name as source. The tiles are fetched at zoom + 1, one level more detailed than the map view.
tms_to_geotiff(
output=image, bbox=bbox, zoom=zoom + 1, source="Satellite", overwrite=True
)
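tms_to_geotiff fetches the web map tiles that cover bbox at the requested zoom level and mosaics them. To get a feel for how many tiles that involves, here is a minimal sketch of the standard XYZ ("slippy map") tile-index formula applied to the default bounding box at zoom 17. It illustrates the tiling scheme only; it is not samgeo's internal code, and the helper names are hypothetical.

```python
import math


def lonlat_to_tile(lon, lat, zoom):
    """Convert a lon/lat point to XYZ tile indices (x, y) at a zoom level."""
    n = 2 ** zoom  # number of tiles along each axis at this zoom
    x = int((lon + 180.0) / 360.0 * n)
    lat_rad = math.radians(lat)
    y = int((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n)
    return x, y


def tiles_for_bbox(bbox, zoom):
    """List the (x, y) tiles covering bbox = [minx, miny, maxx, maxy] in degrees."""
    minx, miny, maxx, maxy = bbox
    x0, y0 = lonlat_to_tile(minx, maxy, zoom)  # north-west corner
    x1, y1 = lonlat_to_tile(maxx, miny, zoom)  # south-east corner
    return [(x, y) for x in range(x0, x1 + 1) for y in range(y0, y1 + 1)]


# Hypothetical example using the default bbox; in the notebook, pass bbox and zoom + 1.
example_bbox = [-123.0127, 44.9957, -122.9874, 45.0045]
print(len(tiles_for_bbox(example_bbox, 17)), "tiles at zoom 17")
```

Each extra zoom level roughly quadruples the number of tiles for the same area, which is worth keeping in mind before raising the zoom for a large bounding box.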
m.add_raster(image, layer_name="Image")
m
Use the draw tools to draw a rectangle on the map around the area you want to segment. If nothing is drawn, a default box is used.
# Use the drawn rectangle if there is one; bounds are [minx, miny, maxx, maxy] in degrees.
if m.user_roi_bounds() is not None:
    clip_box = m.user_roi_bounds()
else:
    clip_box = [-123.0064, 44.9988, -123.0005, 45.0025]
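The rectangle used for segmentation should lie inside the area that was downloaded as the GeoTIFF; otherwise it points at pixels that do not exist in the image. A quick check, sketched here with a hypothetical helper that is not part of samgeo, catches this before the model runs:

```python
def box_within(inner, outer):
    """Return True if box `inner` lies completely inside box `outer`.

    Both boxes are [minx, miny, maxx, maxy] in the same coordinate system.
    """
    return (
        outer[0] <= inner[0] <= inner[2] <= outer[2]
        and outer[1] <= inner[1] <= inner[3] <= outer[3]
    )


# Hypothetical defaults copied from the cells above; in the notebook,
# call box_within(clip_box, bbox) instead so drawn regions are not overwritten.
default_bbox = [-123.0127, 44.9957, -122.9874, 45.0045]
default_clip_box = [-123.0064, 44.9988, -123.0005, 45.0025]
assert box_within(default_clip_box, default_bbox), "clip box must lie inside the image"
```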
clip_box
Initialize the SamGeoPredictor class
# Path to the SAM ViT-H checkpoint; the file must already exist here before the model is loaded.
out_dir = os.path.join(os.path.expanduser("~"), "Downloads")
checkpoint = os.path.join(out_dir, "sam_vit_h_4b8939.pth")
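sam_model_registry only reads the checkpoint from disk, so the file has to be downloaded first. A minimal sketch of a download-if-missing helper follows; it assumes the ViT-H weights published by the Segment Anything project at the URL below, and ensure_checkpoint is a hypothetical name, not a samgeo function. The file is large, so the first download can take a while.

```python
import os
import urllib.request

# ViT-H checkpoint published by the Segment Anything project.
SAM_VIT_H_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"


def ensure_checkpoint(path, url=SAM_VIT_H_URL):
    """Download the checkpoint to `path` unless a file already exists there."""
    if not os.path.exists(path):
        # Create the target folder (e.g. ~/Downloads) if it is missing.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        print(f"Downloading {url} -> {path}")
        urllib.request.urlretrieve(url, path)
    return path
```

In the notebook, calling ensure_checkpoint(checkpoint) before loading the model is enough; later runs skip the download because the file is already present.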
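import cv2
import torch

# OpenCV reads images in BGR order, while SAM expects RGB by default.
img_arr = cv2.cvtColor(cv2.imread(image), cv2.COLOR_BGR2RGB)
# Load the ViT-H model and move it to the GPU when one is available.
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)
# Compute the image embedding once, then prompt SAM with the clip box.
predictor = SamGeoPredictor(sam)
predictor.set_image(img_arr)
masks, _, _ = predictor.predict(src_fp=image, geo_box=clip_box)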
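SAM's box prompt is expressed in pixel coordinates in XYXY order, so the geographic clip_box has to be mapped onto the image grid; SamGeoPredictor appears to handle this internally when you pass src_fp and geo_box. The sketch below shows the idea under stated assumptions: the GeoTIFF is north-up and in Web Mercator (EPSG:3857), and its placement is described by a GDAL-style geotransform tuple. The helper names are hypothetical, and this is not samgeo's implementation.

```python
import math

EARTH_RADIUS = 6378137.0  # WGS 84 semi-major axis used by Web Mercator, in metres


def lonlat_to_web_mercator(lon, lat):
    """Project a WGS 84 lon/lat point to Web Mercator (EPSG:3857) metres."""
    x = math.radians(lon) * EARTH_RADIUS
    y = math.log(math.tan(math.pi / 4 + math.radians(lat) / 2)) * EARTH_RADIUS
    return x, y


def geo_box_to_pixel_box(geo_box, geotransform):
    """Convert [minlon, minlat, maxlon, maxlat] to a pixel box [x0, y0, x1, y1].

    `geotransform` is (origin_x, pixel_width, 0, origin_y, 0, pixel_height),
    where pixel_height is negative for a north-up raster.
    """
    origin_x, px_w, _, origin_y, _, px_h = geotransform
    minx, miny = lonlat_to_web_mercator(geo_box[0], geo_box[1])
    maxx, maxy = lonlat_to_web_mercator(geo_box[2], geo_box[3])
    col0 = (minx - origin_x) / px_w
    col1 = (maxx - origin_x) / px_w
    row0 = (maxy - origin_y) / px_h  # northern edge maps to the smaller row
    row1 = (miny - origin_y) / px_h
    # Round outwards so the pixel box fully covers the geographic box.
    return [int(col0), int(row0), math.ceil(col1), math.ceil(row1)]
```

For example, with a hypothetical 1 m pixel grid whose origin is the north-west corner of the default bbox, the bbox itself maps to a pixel box starting at [0, 0, ...] and the default clip_box falls inside it.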
masks_img = "preds.tif"
predictor.masks_to_geotiff(image, masks_img, masks.astype("uint8"))
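By default SAM's predictor returns several candidate masks for one prompt (multimask output), each with a predicted quality score; the prediction cell discards those scores with `_`, and the next cell vectorizes only band 1 (bidx=1). If you would rather keep the candidate SAM rates highest, one option is to capture the scores (masks, scores, _ = predictor.predict(...)) and keep only the best mask. The helper below is a hypothetical sketch, not a samgeo function; it works on plain sequences and on NumPy arrays alike.

```python
def select_best_mask(masks, scores):
    """Return the candidate mask with the highest score, keeping a leading axis.

    `masks` is a sequence of masks (e.g. an array of shape (N, H, W)) and
    `scores` holds one score per mask.
    """
    if len(masks) != len(scores):
        raise ValueError("masks and scores must have the same length")
    best = max(range(len(scores)), key=lambda i: scores[i])
    return masks[best : best + 1]  # shape (1, H, W) for a NumPy array
```

The result can then be written with predictor.masks_to_geotiff(image, masks_img, best_mask.astype("uint8")), so that band 1 of the output holds the top-scoring candidate.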
vector = "feats.geojson"
gdf = predictor.geotiff_to_geojson(masks_img, vector, bidx=1)
gdf.plot()
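The vector output is an ordinary GeoJSON FeatureCollection, so it can also be inspected without GeoPandas, for example outside this notebook session. A minimal sketch using only the standard library counts the features and computes the overall bounds of their exterior rings; summarize_geojson is a hypothetical helper, and the bounds are in whatever coordinate system the file was written in.

```python
import json


def summarize_geojson(path):
    """Return (feature_count, bounds) for a FeatureCollection of polygons.

    bounds is [minx, miny, maxx, maxy] over all exterior rings, or None
    if the collection has no polygon coordinates.
    """
    with open(path) as f:
        collection = json.load(f)
    xs, ys = [], []
    for feature in collection["features"]:
        geom = feature["geometry"]
        # Treat a Polygon as a MultiPolygon with a single member.
        if geom["type"] == "MultiPolygon":
            polygons = geom["coordinates"]
        else:
            polygons = [geom["coordinates"]]
        for polygon in polygons:
            for point in polygon[0]:  # exterior ring only
                xs.append(point[0])
                ys.append(point[1])
    bounds = [min(xs), min(ys), max(xs), max(ys)] if xs else None
    return len(collection["features"]), bounds
```

Calling summarize_geojson(vector) after the cell above should report the same number of features as len(gdf).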
Visualize the results
style = {
"color": "#3388ff",
"weight": 2,
"fillColor": "#7c4185",
"fillOpacity": 0.5,
}
m.add_vector(vector, layer_name="Vector", style=style)
m