Skip to content

common module

The source code is adapted from https://github.com/aliaksandr960/segment-anything-eo. Credit to the author Aliaksandr Hancharenka.

array_to_image(array, output, source=None, dtype=None, compress='deflate', **kwargs)

Save a NumPy array as a GeoTIFF using the projection information from an existing GeoTIFF file.

Parameters:

Name Type Description Default
array np.ndarray

The NumPy array to be saved as a GeoTIFF.

required
output str

The path to the output image.

required
source str

The path to an existing GeoTIFF file with map projection information. Defaults to None.

None
dtype np.dtype

The data type of the output array. Defaults to None.

None
compress str

The compression method. Can be one of the following: "deflate", "lzw", "packbits", "jpeg". Defaults to "deflate".

'deflate'
Source code in samgeo/common.py
def array_to_image(
    array, output, source=None, dtype=None, compress="deflate", **kwargs
):
    """Save a NumPy array as a GeoTIFF using the projection information from an existing GeoTIFF file.

    Args:
        array (np.ndarray): The NumPy array to be saved as a GeoTIFF.
        output (str): The path to the output image.
        source (str, optional): The path to an existing GeoTIFF file with map projection information. Defaults to None.
        dtype (np.dtype, optional): The data type of the output array. Defaults to None.
        compress (str, optional): The compression method. Can be one of the following: "deflate", "lzw", "packbits", "jpeg". Defaults to "deflate".
    """

    from PIL import Image

    if isinstance(array, str) and os.path.exists(array):
        array = cv2.imread(array)
        array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)

    if output.endswith(".tif") and source is not None:
        with rasterio.open(source) as src:
            crs = src.crs
            transform = src.transform
            if compress is None:
                compress = src.compression

        # Determine the minimum and maximum values in the array

        min_value = np.min(array)
        max_value = np.max(array)

        if dtype is None:
            # Determine the best dtype for the array
            if min_value >= 0 and max_value <= 1:
                dtype = np.float32
            elif min_value >= 0 and max_value <= 255:
                dtype = np.uint8
            elif min_value >= -128 and max_value <= 127:
                dtype = np.int8
            elif min_value >= 0 and max_value <= 65535:
                dtype = np.uint16
            elif min_value >= -32768 and max_value <= 32767:
                dtype = np.int16
            else:
                dtype = np.float64

        # Convert the array to the best dtype
        array = array.astype(dtype)

        # Define the GeoTIFF metadata
        if array.ndim == 2:
            metadata = {
                "driver": "GTiff",
                "height": array.shape[0],
                "width": array.shape[1],
                "count": 1,
                "dtype": array.dtype,
                "crs": crs,
                "transform": transform,
            }
        elif array.ndim == 3:
            metadata = {
                "driver": "GTiff",
                "height": array.shape[0],
                "width": array.shape[1],
                "count": array.shape[2],
                "dtype": array.dtype,
                "crs": crs,
                "transform": transform,
            }

        if compress is not None:
            metadata["compress"] = compress
        else:
            raise ValueError("Array must be 2D or 3D.")

        # Create a new GeoTIFF file and write the array to it
        with rasterio.open(output, "w", **metadata) as dst:
            if array.ndim == 2:
                dst.write(array, 1)
            elif array.ndim == 3:
                for i in range(array.shape[2]):
                    dst.write(array[:, :, i], i + 1)

    else:
        img = Image.fromarray(array)
        img.save(output, **kwargs)

bbox_to_xy(src_fp, coords, coord_crs='epsg:4326', **kwargs)

Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates. Note that map bbox coords is [minx, miny, maxx, maxy] from bottomleft to topright While rasterio bbox coords is [minx, max, maxx, min] from topleft to bottomright

Parameters:

Name Type Description Default
src_fp str

The source raster file path.

required
coords list

A list of coordinates in the format of [[minx, miny, maxx, maxy], [minx, miny, maxx, maxy], ...]

required
coord_crs str

The coordinate CRS of the input coordinates. Defaults to "epsg:4326".

'epsg:4326'

Returns:

Type Description
list

A list of pixel coordinates in the format of [[minx, maxy, maxx, miny], ...] from top left to bottom right.

Source code in samgeo/common.py
def bbox_to_xy(
    src_fp: str, coords: list, coord_crs: str = "epsg:4326", **kwargs
) -> list:
    """Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.
        Note that map bbox coords is [minx, miny, maxx, maxy] from bottomleft to topright
        While rasterio bbox coords is [minx, max, maxx, min] from topleft to bottomright

    Args:
        src_fp (str): The source raster file path.
        coords (list): A list of coordinates in the format of [[minx, miny, maxx, maxy], [minx, miny, maxx, maxy], ...]
        coord_crs (str, optional): The coordinate CRS of the input coordinates. Defaults to "epsg:4326".

    Returns:
        list: A list of pixel coordinates in the format of [[minx, maxy, maxx, miny], ...] from top left to bottom right.
    """

    if isinstance(coords, str):
        gdf = gpd.read_file(coords)
        coords = gdf.geometry.bounds.values.tolist()
        if gdf.crs is not None:
            coord_crs = f"epsg:{gdf.crs.to_epsg()}"
    elif isinstance(coords, np.ndarray):
        coords = coords.tolist()
    if isinstance(coords, dict):
        import json

        geojson = json.dumps(coords)
        gdf = gpd.read_file(geojson, driver="GeoJSON")
        coords = gdf.geometry.bounds.values.tolist()

    elif not isinstance(coords, list):
        raise ValueError("coords must be a list of coordinates.")

    if not isinstance(coords[0], list):
        coords = [coords]

    new_coords = []

    with rasterio.open(src_fp) as src:
        width = src.width
        height = src.height

        for coord in coords:
            minx, miny, maxx, maxy = coord

            if coord_crs != src.crs:
                minx, miny = transform_coords(minx, miny, coord_crs, src.crs, **kwargs)
                maxx, maxy = transform_coords(maxx, maxy, coord_crs, src.crs, **kwargs)

                rows1, cols1 = rasterio.transform.rowcol(
                    src.transform, minx, miny, **kwargs
                )
                rows2, cols2 = rasterio.transform.rowcol(
                    src.transform, maxx, maxy, **kwargs
                )

                new_coords.append([cols1, rows1, cols2, rows2])

            else:
                new_coords.append([minx, miny, maxx, maxy])

    result = []

    for coord in new_coords:
        minx, miny, maxx, maxy = coord

        if (
            minx >= 0
            and miny >= 0
            and maxx >= 0
            and maxy >= 0
            and minx < width
            and miny < height
            and maxx < width
            and maxy < height
        ):
            # Note that map bbox coords is [minx, miny, maxx, maxy] from bottomleft to topright
            # While rasterio bbox coords is [minx, max, maxx, min] from topleft to bottomright
            result.append([minx, maxy, maxx, miny])

    if len(result) == 0:
        print("No valid pixel coordinates found.")
        return None
    elif len(result) == 1:
        return result[0]
    elif len(result) < len(coords):
        print("Some coordinates are out of the image boundary.")

    return result

blend_images(img1, img2, alpha=0.5, output=False, show=True, figsize=(12, 10), axis='off', **kwargs)

Blends two images together using the addWeighted function from the OpenCV library.

Parameters:

Name Type Description Default
img1 numpy.ndarray

The first input image on top represented as a NumPy array.

required
img2 numpy.ndarray

The second input image at the bottom represented as a NumPy array.

required
alpha float

The weighting factor for the first image in the blend. By default, this is set to 0.5.

0.5
output str

The path to the output image. Defaults to False.

False
show bool

Whether to display the blended image. Defaults to True.

True
figsize tuple

The size of the figure. Defaults to (12, 10).

(12, 10)
axis str

The axis of the figure. Defaults to "off".

'off'
**kwargs

Additional keyword arguments to pass to the cv2.addWeighted() function.

{}

Returns:

Type Description
numpy.ndarray

The blended image as a NumPy array.

Source code in samgeo/common.py
def blend_images(
    img1,
    img2,
    alpha=0.5,
    output=False,
    show=True,
    figsize=(12, 10),
    axis="off",
    **kwargs,
):
    """
    Blends two images together using the addWeighted function from the OpenCV library.

    Args:
        img1 (numpy.ndarray): The first input image on top represented as a NumPy array.
        img2 (numpy.ndarray): The second input image at the bottom represented as a NumPy array.
        alpha (float): The weighting factor for the first image in the blend. By default, this is set to 0.5.
        output (str, optional): The path to the output image. Defaults to False.
        show (bool, optional): Whether to display the blended image. Defaults to True.
        figsize (tuple, optional): The size of the figure. Defaults to (12, 10).
        axis (str, optional): The axis of the figure. Defaults to "off".
        **kwargs: Additional keyword arguments to pass to the cv2.addWeighted() function.

    Returns:
        numpy.ndarray: The blended image as a NumPy array.
    """
    # Resize the images to have the same dimensions
    if isinstance(img1, str):
        if img1.startswith("http"):
            img1 = download_file(img1)

        if not os.path.exists(img1):
            raise ValueError(f"Input path {img1} does not exist.")

        img1 = cv2.imread(img1)

    if isinstance(img2, str):
        if img2.startswith("http"):
            img2 = download_file(img2)

        if not os.path.exists(img2):
            raise ValueError(f"Input path {img2} does not exist.")

        img2 = cv2.imread(img2)

    if img1.dtype == np.float32:
        img1 = (img1 * 255).astype(np.uint8)

    if img2.dtype == np.float32:
        img2 = (img2 * 255).astype(np.uint8)

    if img1.dtype != img2.dtype:
        img2 = img2.astype(img1.dtype)

    img1 = cv2.resize(img1, (img2.shape[1], img2.shape[0]))

    # Blend the images using the addWeighted function
    beta = 1 - alpha
    blend_img = cv2.addWeighted(img1, alpha, img2, beta, 0, **kwargs)

    if output:
        array_to_image(blend_img, output, img2)

    if show:
        plt.figure(figsize=figsize)
        plt.imshow(blend_img)
        plt.axis(axis)
        plt.show()
    else:
        return blend_img

boxes_to_vector(coords, src_crs, dst_crs='EPSG:4326', output=None, **kwargs)

Convert a list of bounding box coordinates to vector data.

Parameters:

Name Type Description Default
coords list

A list of bounding box coordinates in the format [[left, top, right, bottom], [left, top, right, bottom], ...].

required
src_crs int or str

The EPSG code or proj4 string representing the source coordinate reference system (CRS) of the input coordinates.

required
dst_crs int or str

The EPSG code or proj4 string representing the destination CRS to reproject the data (default is "EPSG:4326").

'EPSG:4326'
output str or None

The full file path (including the directory and filename without the extension) where the vector data should be saved. If None (default), the function returns the GeoDataFrame without saving it to a file.

None
**kwargs

Additional keyword arguments to pass to geopandas.GeoDataFrame.to_file() when saving the vector data.

{}

Returns:

Type Description
geopandas.GeoDataFrame or None

The GeoDataFrame with the converted vector data if output is None, otherwise None if the data is saved to a file.

Source code in samgeo/common.py
def boxes_to_vector(coords, src_crs, dst_crs="EPSG:4326", output=None, **kwargs):
    """
    Convert a list of bounding box coordinates to vector data.

    Args:
        coords (list): A list of bounding box coordinates in the format [[left, top, right, bottom], [left, top, right, bottom], ...].
        src_crs (int or str): The EPSG code or proj4 string representing the source coordinate reference system (CRS) of the input coordinates.
        dst_crs (int or str, optional): The EPSG code or proj4 string representing the destination CRS to reproject the data (default is "EPSG:4326").
        output (str or None, optional): The full file path (including the directory and filename without the extension) where the vector data should be saved.
                                       If None (default), the function returns the GeoDataFrame without saving it to a file.
        **kwargs: Additional keyword arguments to pass to geopandas.GeoDataFrame.to_file() when saving the vector data.

    Returns:
        geopandas.GeoDataFrame or None: The GeoDataFrame with the converted vector data if output is None, otherwise None if the data is saved to a file.
    """

    from shapely.geometry import box

    # Create a list of Shapely Polygon objects based on the provided coordinates
    polygons = [box(*coord) for coord in coords]

    # Create a GeoDataFrame with the Shapely Polygon objects
    gdf = gpd.GeoDataFrame({"geometry": polygons}, crs=src_crs)

    # Reproject the GeoDataFrame to the specified EPSG code
    gdf_reprojected = gdf.to_crs(dst_crs)

    if output is not None:
        gdf_reprojected.to_file(output, **kwargs)
    else:
        return gdf_reprojected

check_file_path(file_path, make_dirs=True)

Gets the absolute file path.

Parameters:

Name Type Description Default
file_path str

The path to the file.

required
make_dirs bool

Whether to create the directory if it does not exist. Defaults to True.

True

Exceptions:

Type Description
FileNotFoundError

If the directory could not be found.

TypeError

If the input directory path is not a string.

Returns:

Type Description
str

The absolute path to the file.

Source code in samgeo/common.py
def check_file_path(file_path, make_dirs=True):
    """Gets the absolute file path.

    Args:
        file_path (str): The path to the file.
        make_dirs (bool, optional): Whether to create the directory if it does not exist. Defaults to True.

    Raises:
        FileNotFoundError: If the directory could not be found.
        TypeError: If the input directory path is not a string.

    Returns:
        str: The absolute path to the file.
    """
    if isinstance(file_path, str):
        if file_path.startswith("~"):
            file_path = os.path.expanduser(file_path)
        else:
            file_path = os.path.abspath(file_path)

        file_dir = os.path.dirname(file_path)
        if not os.path.exists(file_dir) and make_dirs:
            os.makedirs(file_dir)

        return file_path

    else:
        raise TypeError("The provided file path must be a string.")

choose_device(empty_cache=True, quiet=True)

Choose a device (CPU or GPU) for deep learning.

Parameters:

Name Type Description Default
empty_cache bool

Whether to empty the CUDA cache if a GPU is used. Defaults to True.

True
quiet bool

Whether to suppress device information printout. Defaults to True.

True

Returns:

Type Description
str

The device name.

Source code in samgeo/common.py
def choose_device(empty_cache: bool = True, quiet: bool = True) -> str:
    """Choose a device (CPU or GPU) for deep learning.

    Args:
        empty_cache (bool): Whether to empty the CUDA cache if a GPU is used. Defaults to True.
        quiet (bool): Whether to suppress device information printout. Defaults to True.

    Returns:
        str: The device name.
    """
    import torch

    # if using Apple MPS, fall back to CPU for unsupported ops
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

    # select the device for computation
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    if not quiet:
        print(f"Using device: {device}")

    if device.type == "cuda":
        if empty_cache:
            torch.cuda.empty_cache()
        # use bfloat16 for the entire notebook
        torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
    elif device.type == "mps":
        if not quiet:
            print(
                "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS. "
                "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
            )
    return device

coords_to_geojson(coords, output=None)

Convert a list of coordinates (lon, lat) to a GeoJSON string or file.

Parameters:

Name Type Description Default
coords list

A list of coordinates (lon, lat).

required
output str

The output file path. Defaults to None.

None

Returns:

Type Description
dict

A GeoJSON dictionary.

Source code in samgeo/common.py
def coords_to_geojson(coords, output=None):
    """Convert a list of coordinates (lon, lat) to a GeoJSON string or file.

    Args:
        coords (list): A list of coordinates (lon, lat).
        output (str, optional): The output file path. Defaults to None.

    Returns:
        dict: A GeoJSON dictionary.
    """

    import json

    if len(coords) == 0:
        return
    # Create a GeoJSON FeatureCollection object
    feature_collection = {"type": "FeatureCollection", "features": []}

    # Iterate through the coordinates list and create a GeoJSON Feature object for each coordinate
    for coord in coords:
        feature = {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": coord},
            "properties": {},
        }
        feature_collection["features"].append(feature)

    # Convert the FeatureCollection object to a JSON string
    geojson_str = json.dumps(feature_collection)

    if output is not None:
        with open(output, "w") as f:
            f.write(geojson_str)
    else:
        return geojson_str

coords_to_xy(src_fp, coords, coord_crs='epsg:4326', return_out_of_bounds=False, **kwargs)

Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.

Parameters:

Name Type Description Default
src_fp str

The source raster file path.

required
coords list

A list of coordinates in the format of [[x1, y1], [x2, y2], ...]

required
coord_crs str

The coordinate CRS of the input coordinates. Defaults to "epsg:4326".

'epsg:4326'
return_out_of_bounds

Whether to return out of bounds coordinates. Defaults to False.

False
**kwargs

Additional keyword arguments to pass to rasterio.transform.rowcol.

{}

Returns:

Type Description
list

A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]

Source code in samgeo/common.py
def coords_to_xy(
    src_fp: str,
    coords: list,
    coord_crs: str = "epsg:4326",
    return_out_of_bounds=False,
    **kwargs,
) -> list:
    """Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.

    Args:
        src_fp: The source raster file path.
        coords: A list of coordinates in the format of [[x1, y1], [x2, y2], ...]
        coord_crs: The coordinate CRS of the input coordinates. Defaults to "epsg:4326".
        return_out_of_bounds: Whether to return out of bounds coordinates. Defaults to False.
        **kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol.

    Returns:
        A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]
    """
    out_of_bounds = []

    if isinstance(coords, np.ndarray):
        coords = coords.tolist()

    xs, ys = zip(*coords)
    with rasterio.open(src_fp) as src:
        width = src.width
        height = src.height
        if coord_crs != src.crs:
            xs, ys = transform_coords(xs, ys, coord_crs, src.crs, **kwargs)
        rows, cols = rasterio.transform.rowcol(src.transform, xs, ys, **kwargs)
    result = [[col, row] for col, row in zip(cols, rows)]

    output = []

    for i, (x, y) in enumerate(result):
        if x >= 0 and y >= 0 and x < width and y < height:
            output.append([x, y])
        else:
            out_of_bounds.append(i)

    # output = [
    #     [x, y] for x, y in result if x >= 0 and y >= 0 and x < width and y < height
    # ]
    if len(output) == 0:
        print("No valid pixel coordinates found.")
    elif len(output) < len(coords):
        print("Some coordinates are out of the image boundary.")

    if return_out_of_bounds:
        return output, out_of_bounds
    else:
        return output

download_checkpoint(model_type='vit_h', checkpoint_dir=None, hq=False)

Download the SAM model checkpoint.

Parameters:

Name Type Description Default
model_type str

The model type. Can be one of ['vit_h', 'vit_l', 'vit_b']. Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.

'vit_h'
checkpoint_dir str

The checkpoint_dir directory. Defaults to None, "~/.cache/torch/hub/checkpoints".

None
hq bool

Whether to use HQ-SAM model (https://github.com/SysCV/sam-hq). Defaults to False.

False
Source code in samgeo/common.py
def download_checkpoint(model_type="vit_h", checkpoint_dir=None, hq=False):
    """Download the SAM model checkpoint.

    Args:
        model_type (str, optional): The model type. Can be one of ['vit_h', 'vit_l', 'vit_b'].
            Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
        checkpoint_dir (str, optional): The checkpoint_dir directory. Defaults to None, "~/.cache/torch/hub/checkpoints".
        hq (bool, optional): Whether to use HQ-SAM model (https://github.com/SysCV/sam-hq). Defaults to False.
    """

    if not hq:
        model_types = {
            "vit_h": {
                "name": "sam_vit_h_4b8939.pth",
                "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
            },
            "vit_l": {
                "name": "sam_vit_l_0b3195.pth",
                "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
            },
            "vit_b": {
                "name": "sam_vit_b_01ec64.pth",
                "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
            },
        }
    else:
        model_types = {
            "vit_h": {
                "name": "sam_hq_vit_h.pth",
                "url": [
                    "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_h.zip",
                    "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_h.z01",
                ],
            },
            "vit_l": {
                "name": "sam_hq_vit_l.pth",
                "url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_l.pth",
            },
            "vit_b": {
                "name": "sam_hq_vit_b.pth",
                "url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_b.pth",
            },
            "vit_tiny": {
                "name": "sam_hq_vit_tiny.pth",
                "url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_tiny.pth",
            },
        }

    if model_type not in model_types:
        raise ValueError(
            f"Invalid model_type: {model_type}. It must be one of {', '.join(model_types)}"
        )

    if checkpoint_dir is None:
        checkpoint_dir = os.environ.get(
            "TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints")
        )

    checkpoint = os.path.join(checkpoint_dir, model_types[model_type]["name"])
    if not os.path.exists(checkpoint):
        print(f"Model checkpoint for {model_type} not found.")
        url = model_types[model_type]["url"]
        if isinstance(url, str):
            download_file(url, checkpoint)
        elif isinstance(url, list):
            download_files(url, checkpoint_dir, multi_part=True)
    return checkpoint

download_checkpoint_legacy(url=None, output=None, overwrite=False, **kwargs)

Download a checkpoint from URL. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.

Parameters:

Name Type Description Default
url str

The checkpoint URL. Defaults to None.

None
output str

The output file path. Defaults to None.

None
overwrite bool

Overwrite the file if it already exists. Defaults to False.

False

Returns:

Type Description
str

The output file path.

Source code in samgeo/common.py
def download_checkpoint_legacy(url=None, output=None, overwrite=False, **kwargs):
    """Download a checkpoint from URL. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.

    Args:
        url (str, optional): The checkpoint URL. Defaults to None.
        output (str, optional): The output file path. Defaults to None.
        overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.

    Returns:
        str: The output file path.
    """
    checkpoints = {
        "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "sam_vit_l_0b3195.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "sam_vit_b_01ec64.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    }

    if isinstance(url, str) and url in checkpoints:
        url = checkpoints[url]

    if url is None:
        url = checkpoints["sam_vit_h_4b8939.pth"]

    if output is None:
        output = os.path.basename(url)

    return download_file(url, output, overwrite=overwrite, **kwargs)

download_file(url=None, output=None, quiet=False, proxy=None, speed=None, use_cookies=True, verify=True, id=None, fuzzy=False, resume=False, unzip=True, overwrite=False, subfolder=False)

Download a file from URL, including Google Drive shared URL.

Parameters:

Name Type Description Default
url str

Google Drive URL is also supported. Defaults to None.

None
output str

Output filename. Default is basename of URL.

None
quiet bool

Suppress terminal output. Default is False.

False
proxy str

Proxy. Defaults to None.

None
speed float

Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.

None
use_cookies bool

Flag to use cookies. Defaults to True.

True
verify bool | str

Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string, in which case it must be a path to a CA bundle to use. Default is True.. Defaults to True.

True
id str

Google Drive's file ID. Defaults to None.

None
fuzzy bool

Fuzzy extraction of Google Drive's file Id. Defaults to False.

False
resume bool

Resume the download from existing tmp file if possible. Defaults to False.

False
unzip bool

Unzip the file. Defaults to True.

True
overwrite bool

Overwrite the file if it already exists. Defaults to False.

False
subfolder bool

Create a subfolder with the same name as the file. Defaults to False.

False

Returns:

Type Description
str

The output file path.

Source code in samgeo/common.py
def download_file(
    url=None,
    output=None,
    quiet=False,
    proxy=None,
    speed=None,
    use_cookies=True,
    verify=True,
    id=None,
    fuzzy=False,
    resume=False,
    unzip=True,
    overwrite=False,
    subfolder=False,
):
    """Download a file from URL, including Google Drive shared URL.

    Args:
        url (str, optional): Google Drive URL is also supported. Defaults to None.
        output (str, optional): Output filename. Default is basename of URL.
        quiet (bool, optional): Suppress terminal output. Default is False.
        proxy (str, optional): Proxy. Defaults to None.
        speed (float, optional): Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.
        use_cookies (bool, optional): Flag to use cookies. Defaults to True.
        verify (bool | str, optional): Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string,
            in which case it must be a path to a CA bundle to use. Default is True.. Defaults to True.
        id (str, optional): Google Drive's file ID. Defaults to None.
        fuzzy (bool, optional): Fuzzy extraction of Google Drive's file Id. Defaults to False.
        resume (bool, optional): Resume the download from existing tmp file if possible. Defaults to False.
        unzip (bool, optional): Unzip the file. Defaults to True.
        overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
        subfolder (bool, optional): Create a subfolder with the same name as the file. Defaults to False.

    Returns:
        str: The output file path.
    """
    import zipfile

    try:
        import gdown
    except ImportError:
        print(
            "The gdown package is required for this function. Use `pip install gdown` to install it."
        )
        return

    if output is None:
        if isinstance(url, str) and url.startswith("http"):
            output = os.path.basename(url)

    out_dir = os.path.abspath(os.path.dirname(output))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if isinstance(url, str):
        if os.path.exists(os.path.abspath(output)) and (not overwrite):
            print(
                f"{output} already exists. Skip downloading. Set overwrite=True to overwrite."
            )
            return os.path.abspath(output)
        else:
            url = github_raw_url(url)

    if "https://drive.google.com/file/d/" in url:
        fuzzy = True

    output = gdown.download(
        url, output, quiet, proxy, speed, use_cookies, verify, id, fuzzy, resume
    )

    if unzip and output.endswith(".zip"):
        with zipfile.ZipFile(output, "r") as zip_ref:
            if not quiet:
                print("Extracting files...")
            if subfolder:
                basename = os.path.splitext(os.path.basename(output))[0]

                output = os.path.join(out_dir, basename)
                if not os.path.exists(output):
                    os.makedirs(output)
                zip_ref.extractall(output)
            else:
                zip_ref.extractall(os.path.dirname(output))

    return os.path.abspath(output)

download_files(urls, out_dir=None, filenames=None, quiet=False, proxy=None, speed=None, use_cookies=True, verify=True, id=None, fuzzy=False, resume=False, unzip=True, overwrite=False, subfolder=False, multi_part=False)

Download files from URLs, including Google Drive shared URL.

Parameters:

Name Type Description Default
urls list

The list of urls to download. Google Drive URL is also supported.

required
out_dir str

The output directory. Defaults to None.

None
filenames list

Output filename. Default is basename of URL.

None
quiet bool

Suppress terminal output. Default is False.

False
proxy str

Proxy. Defaults to None.

None
speed float

Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.

None
use_cookies bool

Flag to use cookies. Defaults to True.

True
verify bool | str

Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string, in which case it must be a path to a CA bundle to use. Default is True.. Defaults to True.

True
id str

Google Drive's file ID. Defaults to None.

None
fuzzy bool

Fuzzy extraction of Google Drive's file Id. Defaults to False.

False
resume bool

Resume the download from existing tmp file if possible. Defaults to False.

False
unzip bool

Unzip the file. Defaults to True.

True
overwrite bool

Overwrite the file if it already exists. Defaults to False.

False
subfolder bool

Create a subfolder with the same name as the file. Defaults to False.

False
multi_part bool

If the file is a multi-part file. Defaults to False.

False

Examples:

files = ["sam_hq_vit_tiny.zip", "sam_hq_vit_tiny.z01", "sam_hq_vit_tiny.z02", "sam_hq_vit_tiny.z03"] base_url = "https://github.com/opengeos/datasets/releases/download/models/" urls = [base_url + f for f in files] leafmap.download_files(urls, out_dir="models", multi_part=True)

Source code in samgeo/common.py
def download_files(
    urls,
    out_dir=None,
    filenames=None,
    quiet=False,
    proxy=None,
    speed=None,
    use_cookies=True,
    verify=True,
    id=None,
    fuzzy=False,
    resume=False,
    unzip=True,
    overwrite=False,
    subfolder=False,
    multi_part=False,
):
    """Download files from URLs, including Google Drive shared URL.

    Args:
        urls (list): The list of urls to download. Google Drive URL is also supported.
        out_dir (str, optional): The output directory. Defaults to None.
        filenames (list, optional): Output filename. Default is basename of URL.
        quiet (bool, optional): Suppress terminal output. Default is False.
        proxy (str, optional): Proxy. Defaults to None.
        speed (float, optional): Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.
        use_cookies (bool, optional): Flag to use cookies. Defaults to True.
        verify (bool | str, optional): Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string, in which case it must be a path to a CA bundle to use. Default is True.. Defaults to True.
        id (str, optional): Google Drive's file ID. Defaults to None.
        fuzzy (bool, optional): Fuzzy extraction of Google Drive's file Id. Defaults to False.
        resume (bool, optional): Resume the download from existing tmp file if possible. Defaults to False.
        unzip (bool, optional): Unzip the file. Defaults to True.
        overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
        subfolder (bool, optional): Create a subfolder with the same name as the file. Defaults to False.
        multi_part (bool, optional): If the file is a multi-part file. Defaults to False.

    Examples:

        files = ["sam_hq_vit_tiny.zip", "sam_hq_vit_tiny.z01", "sam_hq_vit_tiny.z02", "sam_hq_vit_tiny.z03"]
        base_url = "https://github.com/opengeos/datasets/releases/download/models/"
        urls = [base_url + f for f in files]
        leafmap.download_files(urls, out_dir="models", multi_part=True)
    """

    if out_dir is None:
        out_dir = os.getcwd()

    if filenames is None:
        filenames = [None] * len(urls)

    filepaths = []
    for url, output in zip(urls, filenames):
        if output is None:
            filename = os.path.join(out_dir, os.path.basename(url))
        else:
            filename = os.path.join(out_dir, output)

        filepaths.append(filename)
        if multi_part:
            unzip = False

        download_file(
            url,
            filename,
            quiet,
            proxy,
            speed,
            use_cookies,
            verify,
            id,
            fuzzy,
            resume,
            unzip,
            overwrite,
            subfolder,
        )

    if multi_part:
        archive = os.path.splitext(filename)[0] + ".zip"
        out_dir = os.path.dirname(filename)
        extract_archive(archive, out_dir)

        for file in filepaths:
            os.remove(file)

extract_archive(archive, outdir=None, **kwargs)

Extracts a multipart archive.

This function uses the patoolib library to extract a multipart archive. If the patoolib library is not installed, it attempts to install it. If the archive does not end with ".zip", it appends ".zip" to the archive name. If the extraction fails (for example, if the files already exist), it skips the extraction.

Parameters:

Name Type Description Default
archive str

The path to the archive file.

required
outdir str

The directory where the archive should be extracted.

None
**kwargs

Arbitrary keyword arguments for the patoolib.extract_archive function.

{}

Returns:

Type Description

None

Exceptions:

Type Description
Exception

An exception is raised if the extraction fails for reasons other than the files already existing.

Examples:

files = ["sam_hq_vit_tiny.zip", "sam_hq_vit_tiny.z01", "sam_hq_vit_tiny.z02", "sam_hq_vit_tiny.z03"] base_url = "https://github.com/opengeos/datasets/releases/download/models/" urls = [base_url + f for f in files] leafmap.download_files(urls, out_dir="models", multi_part=True)

Source code in samgeo/common.py
def extract_archive(archive, outdir=None, **kwargs):
    """
    Extracts a multipart archive.

    This function uses the patoolib library to extract a multipart archive.
    If the patoolib library is not installed, it attempts to install it.
    If the archive does not end with ".zip", it appends ".zip" to the archive name.
    If the extraction fails (for example, if the files already exist), it skips the extraction.

    Args:
        archive (str): The path to the archive file.
        outdir (str): The directory where the archive should be extracted.
        **kwargs: Arbitrary keyword arguments for the patoolib.extract_archive function.

    Returns:
        None

    Raises:
        Exception: An exception is raised if the extraction fails for reasons other than the files already existing.

    Example:

        files = ["sam_hq_vit_tiny.zip", "sam_hq_vit_tiny.z01", "sam_hq_vit_tiny.z02", "sam_hq_vit_tiny.z03"]
        base_url = "https://github.com/opengeos/datasets/releases/download/models/"
        urls = [base_url + f for f in files]
        leafmap.download_files(urls, out_dir="models", multi_part=True)

    """
    try:
        import patoolib
    except ImportError:
        install_package("patool")
        import patoolib

    if not archive.endswith(".zip"):
        archive = archive + ".zip"

    if outdir is None:
        outdir = os.path.dirname(archive)

    try:
        patoolib.extract_archive(archive, outdir=outdir, **kwargs)
    except Exception as e:
        print("The unzipped files might already exist. Skipping extraction.")
        return

geojson_to_coords(geojson, src_crs='epsg:4326', dst_crs='epsg:4326')

Converts a geojson file or a dictionary of feature collection to a list of centroid coordinates.

Parameters:

Name Type Description Default
geojson str | dict

The geojson file path or a dictionary of feature collection.

required
src_crs str

The source CRS. Defaults to "epsg:4326".

'epsg:4326'
dst_crs str

The destination CRS. Defaults to "epsg:4326".

'epsg:4326'

Returns:

Type Description
list

A list of centroid coordinates in the format of [[x1, y1], [x2, y2], ...]

Source code in samgeo/common.py
def geojson_to_coords(
    geojson: str, src_crs: str = "epsg:4326", dst_crs: str = "epsg:4326"
) -> list:
    """Converts a geojson file or a dictionary of feature collection to a list of centroid coordinates.

    Args:
        geojson (str | dict): The geojson file path or a dictionary of feature collection.
        src_crs (str, optional): The source CRS. Defaults to "epsg:4326".
        dst_crs (str, optional): The destination CRS. Defaults to "epsg:4326".

    Returns:
        list: A list of centroid coordinates in the format of [[x1, y1], [x2, y2], ...]
    """

    import json
    import warnings

    warnings.filterwarnings("ignore")

    if isinstance(geojson, dict):
        geojson = json.dumps(geojson)
    gdf = gpd.read_file(geojson, driver="GeoJSON")
    centroids = gdf.geometry.centroid
    centroid_list = [[point.x, point.y] for point in centroids]
    if src_crs != dst_crs:
        centroid_list = transform_coords(
            [x[0] for x in centroid_list],
            [x[1] for x in centroid_list],
            src_crs,
            dst_crs,
        )
        centroid_list = [[x, y] for x, y in zip(centroid_list[0], centroid_list[1])]
    return centroid_list

geojson_to_xy(src_fp, geojson, coord_crs='epsg:4326', **kwargs)

Converts a geojson file or a dictionary of feature collection to a list of pixel coordinates.

Parameters:

Name Type Description Default
src_fp str

The source raster file path.

required
geojson str

The geojson file path or a dictionary of feature collection.

required
coord_crs str

The coordinate CRS of the input coordinates. Defaults to "epsg:4326".

'epsg:4326'
**kwargs

Additional keyword arguments to pass to rasterio.transform.rowcol.

{}

Returns:

Type Description
list

A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]

Source code in samgeo/common.py
def geojson_to_xy(
    src_fp: str, geojson: str, coord_crs: str = "epsg:4326", **kwargs
) -> list:
    """Converts a geojson file or a dictionary of feature collection to a list of pixel coordinates.

    Args:
        src_fp: The source raster file path.
        geojson: The geojson file path or a dictionary of feature collection.
        coord_crs: The coordinate CRS of the input coordinates. Defaults to "epsg:4326".
        **kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol.

    Returns:
        A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]
    """
    with rasterio.open(src_fp) as src:
        src_crs = src.crs
    coords = geojson_to_coords(geojson, coord_crs, src_crs)
    return coords_to_xy(src_fp, coords, src_crs, **kwargs)

geotiff_to_jpg(geotiff_path, output_path)

Convert a GeoTIFF file to a JPG file.

Parameters:

Name Type Description Default
geotiff_path str

The path to the input GeoTIFF file.

required
output_path str

The path to the output JPG file.

required
Source code in samgeo/common.py
def geotiff_to_jpg(geotiff_path: str, output_path: str) -> None:
    """Convert a GeoTIFF file to a JPG file.

    Args:
        geotiff_path (str): The path to the input GeoTIFF file.
        output_path (str): The path to the output JPG file.
    """

    from PIL import Image

    # Open the GeoTIFF file
    with rasterio.open(geotiff_path) as src:
        # Read the first band (for grayscale) or all bands
        array = src.read()

        # If the array has more than 3 bands, reduce it to the first 3 (RGB)
        if array.shape[0] >= 3:
            array = array[:3, :, :]  # Select the first 3 bands (R, G, B)
        elif array.shape[0] == 1:
            # For single-band images, repeat the band to create a grayscale RGB
            array = np.repeat(array, 3, axis=0)

        # Transpose the array from (bands, height, width) to (height, width, bands)
        array = np.transpose(array, (1, 2, 0))

        # Normalize the array to 8-bit (0-255) range for JPG
        array = array.astype(np.float32)
        array -= array.min()
        array /= array.max()
        array *= 255
        array = array.astype(np.uint8)

        # Convert to a PIL Image and save as JPG
        image = Image.fromarray(array)
        image.save(output_path)

geotiff_to_jpg_batch(input_folder, output_folder=None)

Convert all GeoTIFF files in a folder to JPG files.

Parameters:

Name Type Description Default
input_folder str

The path to the folder containing GeoTIFF files.

required
output_folder str

The path to the folder to save the output JPG files.

None

Returns:

Type Description
str

The path to the output folder containing the JPG files.

Source code in samgeo/common.py
def geotiff_to_jpg_batch(input_folder: str, output_folder: str = None) -> str:
    """Convert all GeoTIFF files in a folder to JPG files.

    Args:
        input_folder (str): The path to the folder containing GeoTIFF files.
        output_folder (str): The path to the folder to save the output JPG files.

    Returns:
        str: The path to the output folder containing the JPG files.
    """

    if output_folder is None:
        output_folder = make_temp_dir()

    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    geotiff_files = [
        f for f in os.listdir(input_folder) if f.endswith(".tif") or f.endswith(".tiff")
    ]

    # Initialize tqdm progress bar
    for filename in tqdm(geotiff_files, desc="Converting GeoTIFF to JPG"):
        geotiff_path = os.path.join(input_folder, filename)
        jpg_filename = os.path.splitext(filename)[0] + ".jpg"
        output_path = os.path.join(output_folder, jpg_filename)
        geotiff_to_jpg(geotiff_path, output_path)

    return output_folder

get_basemaps(free_only=True)

Returns a dictionary of xyz basemaps.

Parameters:

Name Type Description Default
free_only bool

Whether to return only free xyz tile services that do not require an access token. Defaults to True.

True

Returns:

Type Description
dict

A dictionary of xyz basemaps.

Source code in samgeo/common.py
def get_basemaps(free_only=True):
    """Returns a dictionary of xyz basemaps.

    Args:
        free_only (bool, optional): Whether to return only free xyz tile services that do not require an access token. Defaults to True.

    Returns:
        dict: A dictionary of xyz basemaps.
    """

    basemaps = {}
    xyz_dict = get_xyz_dict(free_only=free_only)
    for item in xyz_dict:
        name = xyz_dict[item].name
        url = xyz_dict[item].build_url()
        basemaps[name] = url

    return basemaps

get_vector_crs(filename, **kwargs)

Gets the CRS of a vector file.

Parameters:

Name Type Description Default
filename str

The vector file path.

required

Returns:

Type Description
str

The CRS of the vector file.

Source code in samgeo/common.py
def get_vector_crs(filename, **kwargs):
    """Gets the CRS of a vector file.

    Args:
        filename (str): The vector file path.

    Returns:
        str: The CRS of the vector file.
    """
    gdf = gpd.read_file(filename, **kwargs)
    epsg = gdf.crs.to_epsg()
    if epsg is None:
        return gdf.crs
    else:
        return f"EPSG:{epsg}"

get_xyz_dict(free_only=True)

Returns a dictionary of xyz services.

Parameters:

Name Type Description Default
free_only bool

Whether to return only free xyz tile services that do not require an access token. Defaults to True.

True

Returns:

Type Description
dict

A dictionary of xyz services.

Source code in samgeo/common.py
def get_xyz_dict(free_only=True):
    """Returns a dictionary of xyz services.

    Args:
        free_only (bool, optional): Whether to return only free xyz tile services that do not require an access token. Defaults to True.

    Returns:
        dict: A dictionary of xyz services.
    """
    import collections
    import xyzservices.providers as xyz

    def _unpack_sub_parameters(var, param):
        temp = var
        for sub_param in param.split("."):
            temp = getattr(temp, sub_param)
        return temp

    xyz_dict = {}
    for item in xyz.values():
        try:
            name = item["name"]
            tile = _unpack_sub_parameters(xyz, name)
            if _unpack_sub_parameters(xyz, name).requires_token():
                if free_only:
                    pass
                else:
                    xyz_dict[name] = tile
            else:
                xyz_dict[name] = tile

        except Exception:
            for sub_item in item:
                name = item[sub_item]["name"]
                tile = _unpack_sub_parameters(xyz, name)
                if _unpack_sub_parameters(xyz, name).requires_token():
                    if free_only:
                        pass
                    else:
                        xyz_dict[name] = tile
                else:
                    xyz_dict[name] = tile

    xyz_dict = collections.OrderedDict(sorted(xyz_dict.items()))
    return xyz_dict

github_raw_url(url)

Get the raw URL for a GitHub file.

Parameters:

Name Type Description Default
url str

The GitHub URL.

required

Returns:

Type Description
str

The raw URL.

Source code in samgeo/common.py
def github_raw_url(url):
    """Get the raw URL for a GitHub file.

    Args:
        url (str): The GitHub URL.
    Returns:
        str: The raw URL.
    """
    if isinstance(url, str) and url.startswith("https://github.com/") and "blob" in url:
        url = url.replace("github.com", "raw.githubusercontent.com").replace(
            "blob/", ""
        )
    return url

image_to_cog(source, dst_path=None, profile='deflate', **kwargs)

Converts an image to a COG file.

Parameters:

Name Type Description Default
source str

A dataset path, URL or rasterio.io.DatasetReader object.

required
dst_path str

An output dataset path or or PathLike object. Defaults to None.

None
profile str

COG profile. More at https://cogeotiff.github.io/rio-cogeo/profile. Defaults to "deflate".

'deflate'

Exceptions:

Type Description
ImportError

If rio-cogeo is not installed.

FileNotFoundError

If the source file could not be found.

Source code in samgeo/common.py
def image_to_cog(source, dst_path=None, profile="deflate", **kwargs):
    """Converts an image to a COG file.

    Args:
        source (str): A dataset path, URL or rasterio.io.DatasetReader object.
        dst_path (str, optional): An output dataset path or or PathLike object. Defaults to None.
        profile (str, optional): COG profile. More at https://cogeotiff.github.io/rio-cogeo/profile. Defaults to "deflate".

    Raises:
        ImportError: If rio-cogeo is not installed.
        FileNotFoundError: If the source file could not be found.
    """
    try:
        from rio_cogeo.cogeo import cog_translate
        from rio_cogeo.profiles import cog_profiles

    except ImportError:
        raise ImportError(
            "The rio-cogeo package is not installed. Please install it with `pip install rio-cogeo` or `conda install rio-cogeo -c conda-forge`."
        )

    if not source.startswith("http"):
        source = check_file_path(source)

        if not os.path.exists(source):
            raise FileNotFoundError("The provided input file could not be found.")

    if dst_path is None:
        if not source.startswith("http"):
            dst_path = os.path.splitext(source)[0] + "_cog.tif"
        else:
            dst_path = temp_file_path(extension=".tif")

    dst_path = check_file_path(dst_path)

    dst_profile = cog_profiles.get(profile)
    cog_translate(source, dst_path, dst_profile, **kwargs)

images_to_video(images, output_video, fps=30, video_size=None)

Converts a series of images into a video. The input can be either a directory containing the images or a list of image file paths.

Parameters:

Name Type Description Default
images Union[str, List[str]]

A directory containing images or a list of image file paths.

required
output_video str

The filename of the output video (e.g., 'output.mp4').

required
fps int

Frames per second for the output video. Default is 30.

30
video_size Optional[tuple]

The size (width, height) of the video. If not provided, the size of the first image is used.

None

Exceptions:

Type Description
ValueError

If the provided path is not a directory, if the images list is empty, or if the first image cannot be read.

Example usage: images_to_video('path_to_image_directory', 'output_video.mp4', fps=30, video_size=(1280, 720)) images_to_video(['image1.jpg', 'image2.jpg', 'image3.jpg'], 'output_video.mp4', fps=30)

Source code in samgeo/common.py
def images_to_video(
    images: Union[str, List[str]],
    output_video: str,
    fps: int = 30,
    video_size: Optional[tuple] = None,
) -> None:
    """
    Converts a series of images into a video. The input can be either a directory
    containing the images or a list of image file paths.

    Args:
        images (Union[str, List[str]]): A directory containing images or a list
            of image file paths.
        output_video (str): The filename of the output video (e.g., 'output.mp4').
        fps (int, optional): Frames per second for the output video. Default is 30.
        video_size (Optional[tuple], optional): The size (width, height) of the
            video. If not provided, the size of the first image is used.

    Raises:
        ValueError: If the provided path is not a directory, if the images list
            is empty, or if the first image cannot be read.

    Example usage:
        images_to_video('path_to_image_directory', 'output_video.mp4', fps=30, video_size=(1280, 720))
        images_to_video(['image1.jpg', 'image2.jpg', 'image3.jpg'], 'output_video.mp4', fps=30)
    """
    if isinstance(images, str):
        if not os.path.isdir(images):
            raise ValueError(f"The provided path {images} is not a valid directory.")

        # Get all image files in the directory (sorted by filename)

        files = sorted(os.listdir(images))
        if len(files) == 0:
            raise ValueError(f"No image files found in the directory {images}")
        elif files[0].endswith(".tif"):
            images = geotiff_to_jpg_batch(images)

        images = [
            os.path.join(images, img)
            for img in sorted(os.listdir(images))
            if img.endswith((".jpg", ".png"))
        ]

    if not isinstance(images, list) or not images:
        raise ValueError(
            "The images parameter should either be a non-empty list of image paths or a valid directory."
        )

    # Read the first image to get the dimensions if video_size is not provided
    first_image_path = images[0]
    frame = cv2.imread(first_image_path)

    if frame is None:
        raise ValueError(f"Error reading the first image {first_image_path}")

    if video_size is None:
        height, width, _ = frame.shape
        video_size = (width, height)

    fourcc = cv2.VideoWriter_fourcc(*"avc1")  # Define the codec for mp4
    video_writer = cv2.VideoWriter(output_video, fourcc, fps, video_size)

    for image_path in images:
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Warning: Could not read image {image_path}. Skipping.")
            continue

        if video_size != (frame.shape[1], frame.shape[0]):
            frame = cv2.resize(frame, video_size)

        video_writer.write(frame)

    video_writer.release()
    print(f"Video saved as {output_video}")

install_package(package)

Install a Python package.

Parameters:

Name Type Description Default
package str | list

The package name or a GitHub URL or a list of package names or GitHub URLs.

required
Source code in samgeo/common.py
def install_package(package):
    """Install a Python package.

    Args:
        package (str | list): The package name or a GitHub URL or a list of package names or GitHub URLs.
    """
    import subprocess

    if isinstance(package, str):
        packages = [package]

    for package in packages:
        if package.startswith("https://github.com"):
            package = f"git+{package}"

        # Execute pip install command and show output in real-time
        command = f"pip install {package}"
        process = subprocess.Popen(command.split(), stdout=subprocess.PIPE)

        # Print output in real-time
        while True:
            output = process.stdout.readline()
            if output == b"" and process.poll() is not None:
                break
            if output:
                print(output.decode("utf-8").strip())

        # Wait for process to complete
        process.wait()

is_colab()

Tests if the code is being executed within Google Colab.

Source code in samgeo/common.py
def is_colab():
    """Tests if the code is being executed within Google Colab."""
    import sys

    if "google.colab" in sys.modules:
        return True
    else:
        return False

make_temp_dir(**kwargs)

Create a temporary directory and return the path.

Returns:

Type Description
str

The path to the temporary directory.

Source code in samgeo/common.py
def make_temp_dir(**kwargs) -> str:
    """Create a temporary directory and return the path.

    Returns:
        str: The path to the temporary directory.
    """
    import tempfile

    temp_dir = tempfile.mkdtemp(**kwargs)
    return temp_dir

merge_rasters(input_dir, output, input_pattern='*.tif', output_format='GTiff', output_nodata=None, output_options=['COMPRESS=DEFLATE'])

Merge a directory of rasters into a single raster.

Parameters:

Name Type Description Default
input_dir str

The path to the input directory.

required
output str

The path to the output raster.

required
input_pattern str

The pattern to match the input files. Defaults to "*.tif".

'*.tif'
output_format str

The output format. Defaults to "GTiff".

'GTiff'
output_nodata float

The output nodata value. Defaults to None.

None
output_options list

A list of output options. Defaults to ["COMPRESS=DEFLATE"].

['COMPRESS=DEFLATE']

Exceptions:

Type Description
ImportError

Raised if GDAL is not installed.

Source code in samgeo/common.py
def merge_rasters(
    input_dir,
    output,
    input_pattern="*.tif",
    output_format="GTiff",
    output_nodata=None,
    output_options=["COMPRESS=DEFLATE"],
):
    """Merge a directory of rasters into a single raster.

    Args:
        input_dir (str): The path to the input directory.
        output (str): The path to the output raster.
        input_pattern (str, optional): The pattern to match the input files. Defaults to "*.tif".
        output_format (str, optional): The output format. Defaults to "GTiff".
        output_nodata (float, optional): The output nodata value. Defaults to None.
        output_options (list, optional): A list of output options. Defaults to ["COMPRESS=DEFLATE"].

    Raises:
        ImportError: Raised if GDAL is not installed.
    """

    import glob

    try:
        from osgeo import gdal
    except ImportError:
        raise ImportError(
            "GDAL is required to use this function. Install it with `conda install gdal -c conda-forge`"
        )

    # Get a list of all the input files
    input_files = glob.glob(os.path.join(input_dir, input_pattern))

    # Prepare the gdal.Warp options
    warp_options = gdal.WarpOptions(
        format=output_format, dstNodata=output_nodata, creationOptions=output_options
    )

    # Merge the input files into a single output file
    gdal.Warp(
        destNameOrDestDS=output,
        srcDSOrSrcDSTab=input_files,
        options=warp_options,
    )

overlay_images(image1, image2, alpha=0.5, backend='TkAgg', height_ratios=[10, 1], show_args1={}, show_args2={})

Overlays two images using a slider to control the opacity of the top image.

Parameters:

Name Type Description Default
image1 str | np.ndarray

The first input image at the bottom represented as a NumPy array or the path to the image.

required
image2 _type_

The second input image on top represented as a NumPy array or the path to the image.

required
alpha float

The alpha value of the top image. Defaults to 0.5.

0.5
backend str

The backend of the matplotlib plot. Defaults to "TkAgg".

'TkAgg'
height_ratios list

The height ratios of the two subplots. Defaults to [10, 1].

[10, 1]
show_args1 dict

The keyword arguments to pass to the imshow() function for the first image. Defaults to {}.

{}
show_args2 dict

The keyword arguments to pass to the imshow() function for the second image. Defaults to {}.

{}
Source code in samgeo/common.py
def overlay_images(
    image1,
    image2,
    alpha=0.5,
    backend="TkAgg",
    height_ratios=[10, 1],
    show_args1={},
    show_args2={},
):
    """Overlays two images using a slider to control the opacity of the top image.

    Args:
        image1 (str | np.ndarray): The first input image at the bottom represented as a NumPy array or the path to the image.
        image2 (_type_): The second input image on top represented as a NumPy array or the path to the image.
        alpha (float, optional): The alpha value of the top image. Defaults to 0.5.
        backend (str, optional): The backend of the matplotlib plot. Defaults to "TkAgg".
        height_ratios (list, optional): The height ratios of the two subplots. Defaults to [10, 1].
        show_args1 (dict, optional): The keyword arguments to pass to the imshow() function for the first image. Defaults to {}.
        show_args2 (dict, optional): The keyword arguments to pass to the imshow() function for the second image. Defaults to {}.

    """
    import sys
    import matplotlib
    import matplotlib.widgets as mpwidgets

    if "google.colab" in sys.modules:
        backend = "inline"
        print(
            "The TkAgg backend is not supported in Google Colab. The overlay_images function will not work on Colab."
        )
        return

    matplotlib.use(backend)

    if isinstance(image1, str):
        if image1.startswith("http"):
            image1 = download_file(image1)

        if not os.path.exists(image1):
            raise ValueError(f"Input path {image1} does not exist.")

    if isinstance(image2, str):
        if image2.startswith("http"):
            image2 = download_file(image2)

        if not os.path.exists(image2):
            raise ValueError(f"Input path {image2} does not exist.")

    # Load the two images
    x = plt.imread(image1)
    y = plt.imread(image2)

    # Create the plot
    fig, (ax0, ax1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": height_ratios})
    img0 = ax0.imshow(x, **show_args1)
    img1 = ax0.imshow(y, alpha=alpha, **show_args2)

    # Define the update function
    def update(value):
        img1.set_alpha(value)
        fig.canvas.draw_idle()

    # Create the slider
    slider0 = mpwidgets.Slider(ax=ax1, label="alpha", valmin=0, valmax=1, valinit=alpha)
    slider0.on_changed(update)

    # Display the plot
    plt.show()

random_string(string_length=6)

Generates a random string of fixed length.

Parameters:

Name Type Description Default
string_length int

Fixed length. Defaults to 3.

6

Returns:

Type Description
str

A random string

Source code in samgeo/common.py
def random_string(string_length=6):
    """Generates a random string of fixed length.

    Args:
        string_length (int, optional): Fixed length. Defaults to 3.

    Returns:
        str: A random string
    """
    import random
    import string

    # random.seed(1001)
    letters = string.ascii_lowercase
    return "".join(random.choice(letters) for i in range(string_length))

raster_to_geojson(tiff_path, output, simplify_tolerance=None, **kwargs)

Convert a tiff file to a GeoJSON file.

Parameters:

Name Type Description Default
tiff_path str

The path to the tiff file.

required
output str

The path to the GeoJSON file.

required
simplify_tolerance float

The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.

None
Source code in samgeo/common.py
def raster_to_geojson(tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a GeoJSON file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the GeoJSON file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """

    if not output.endswith(".geojson"):
        output += ".geojson"

    raster_to_vector(tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs)

raster_to_gpkg(tiff_path, output, simplify_tolerance=None, **kwargs)

Convert a tiff file to a gpkg file.

Parameters:

Name Type Description Default
tiff_path str

The path to the tiff file.

required
output str

The path to the gpkg file.

required
simplify_tolerance float

The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.

None
Source code in samgeo/common.py
def raster_to_gpkg(tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a gpkg file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the gpkg file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """

    if not output.endswith(".gpkg"):
        output += ".gpkg"

    raster_to_vector(tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs)

raster_to_shp(tiff_path, output, simplify_tolerance=None, **kwargs)

Convert a tiff file to a shapefile.

Parameters:

Name Type Description Default
tiff_path str

The path to the tiff file.

required
output str

The path to the shapefile.

required
simplify_tolerance float

The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.

None
Source code in samgeo/common.py
def raster_to_shp(tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a shapefile.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the shapefile.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """

    if not output.endswith(".shp"):
        output += ".shp"

    raster_to_vector(tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs)

raster_to_vector(source, output, simplify_tolerance=None, dst_crs=None, **kwargs)

Vectorize a raster dataset.

Parameters:

Name Type Description Default
source str

The path to the tiff file.

required
output str

The path to the vector file.

required
simplify_tolerance float

The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.

None
Source code in samgeo/common.py
def raster_to_vector(source, output, simplify_tolerance=None, dst_crs=None, **kwargs):
    """Vectorize a raster dataset.

    Args:
        source (str): The path to the tiff file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """
    from rasterio import features

    with rasterio.open(source) as src:
        band = src.read()

        mask = band != 0
        shapes = features.shapes(band, mask=mask, transform=src.transform)

    fc = [
        {"geometry": shapely.geometry.shape(shape), "properties": {"value": value}}
        for shape, value in shapes
    ]
    if simplify_tolerance is not None:
        for i in fc:
            i["geometry"] = i["geometry"].simplify(tolerance=simplify_tolerance)

    gdf = gpd.GeoDataFrame.from_features(fc)
    if src.crs is not None:
        gdf.set_crs(crs=src.crs, inplace=True)

    if dst_crs is not None:
        gdf = gdf.to_crs(dst_crs)

    gdf.to_file(output, **kwargs)

regularize(source, output=None, crs='EPSG:4326', **kwargs)

Regularize a polygon GeoDataFrame.

Parameters:

Name Type Description Default
source str | gpd.GeoDataFrame

The input file path or a GeoDataFrame.

required
output str

The output file path. Defaults to None.

None

Returns:

Type Description
gpd.GeoDataFrame

The output GeoDataFrame.

Source code in samgeo/common.py
def regularize(source, output=None, crs="EPSG:4326", **kwargs):
    """Regularize a polygon GeoDataFrame.

    Args:
        source (str | gpd.GeoDataFrame): The input file path or a GeoDataFrame.
        output (str, optional): The output file path. Defaults to None.


    Returns:
        gpd.GeoDataFrame: The output GeoDataFrame.
    """
    if isinstance(source, str):
        gdf = gpd.read_file(source)
    elif isinstance(source, gpd.GeoDataFrame):
        gdf = source
    else:
        raise ValueError("The input source must be a GeoDataFrame or a file path.")

    polygons = gdf.geometry.apply(lambda geom: geom.minimum_rotated_rectangle)
    result = gpd.GeoDataFrame(geometry=polygons, data=gdf.drop("geometry", axis=1))

    if crs is not None:
        result.to_crs(crs, inplace=True)
    if output is not None:
        result.to_file(output, **kwargs)
    else:
        return result

reproject(image, output, dst_crs='EPSG:4326', resampling='nearest', to_cog=True, **kwargs)

Reprojects an image.

Parameters:

Name Type Description Default
image str

The input image filepath.

required
output str

The output image filepath.

required
dst_crs str

The destination CRS. Defaults to "EPSG:4326".

'EPSG:4326'
resampling Resampling

The resampling method. Defaults to "nearest".

'nearest'
to_cog bool

Whether to convert the output image to a Cloud Optimized GeoTIFF. Defaults to True.

True
**kwargs

Additional keyword arguments to pass to rasterio.open.

{}
Source code in samgeo/common.py
def reproject(
    image, output, dst_crs="EPSG:4326", resampling="nearest", to_cog=True, **kwargs
):
    """Reprojects an image.

    Args:
        image (str): The input image filepath.
        output (str): The output image filepath.
        dst_crs (str, optional): The destination CRS. Defaults to "EPSG:4326".
        resampling (Resampling, optional): The resampling method. Defaults to "nearest".
        to_cog (bool, optional): Whether to convert the output image to a Cloud Optimized GeoTIFF. Defaults to True.
        **kwargs: Additional keyword arguments to pass to rasterio.open.

    """
    import rasterio as rio
    from rasterio.warp import calculate_default_transform, reproject, Resampling

    if isinstance(resampling, str):
        resampling = getattr(Resampling, resampling)

    image = os.path.abspath(image)
    output = os.path.abspath(output)

    if not os.path.exists(os.path.dirname(output)):
        os.makedirs(os.path.dirname(output))

    with rio.open(image, **kwargs) as src:
        transform, width, height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds
        )
        kwargs = src.meta.copy()
        kwargs.update(
            {
                "crs": dst_crs,
                "transform": transform,
                "width": width,
                "height": height,
            }
        )

        with rio.open(output, "w", **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rio.band(src, i),
                    destination=rio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=dst_crs,
                    resampling=resampling,
                    **kwargs,
                )

    if to_cog:
        image_to_cog(output, output)

rowcol_to_xy(src_fp, rows=None, cols=None, boxes=None, zs=None, offset='center', output=None, dst_crs='EPSG:4326', **kwargs)

Converts a list of (row, col) coordinates to (x, y) coordinates.

Parameters:

Name Type Description Default
src_fp str

The source raster file path.

required
rows list

A list of row coordinates. Defaults to None.

None
cols list

A list of col coordinates. Defaults to None.

None
boxes list

A list of (row, col) coordinates in the format of [[left, top, right, bottom], [left, top, right, bottom], ...]

None
zs

zs (list or float, optional): Height associated with coordinates. Primarily used for RPC based coordinate transformations.

None
offset str

Determines if the returned coordinates are for the center of the pixel or for a corner.

'center'
output str

The output vector file path. Defaults to None.

None
dst_crs str

The destination CRS. Defaults to "EPSG:4326".

'EPSG:4326'
**kwargs

Additional keyword arguments to pass to rasterio.transform.xy.

{}

Returns:

Type Description

A list of (x, y) coordinates.

Source code in samgeo/common.py
def rowcol_to_xy(
    src_fp,
    rows=None,
    cols=None,
    boxes=None,
    zs=None,
    offset="center",
    output=None,
    dst_crs="EPSG:4326",
    **kwargs,
):
    """Converts a list of (row, col) coordinates to (x, y) coordinates.

    Args:
        src_fp (str): The source raster file path.
        rows (list, optional): A list of row coordinates. Defaults to None.
        cols (list, optional): A list of col coordinates. Defaults to None.
        boxes (list, optional): A list of (row, col) coordinates in the format of [[left, top, right, bottom], [left, top, right, bottom], ...]
        zs: zs (list or float, optional): Height associated with coordinates. Primarily used for RPC based coordinate transformations.
        offset (str, optional): Determines if the returned coordinates are for the center of the pixel or for a corner.
        output (str, optional): The output vector file path. Defaults to None.
        dst_crs (str, optional): The destination CRS. Defaults to "EPSG:4326".
        **kwargs: Additional keyword arguments to pass to rasterio.transform.xy.

    Returns:
        A list of (x, y) coordinates.
    """

    if boxes is not None:
        rows = []
        cols = []

        for box in boxes:
            rows.append(box[1])
            rows.append(box[3])
            cols.append(box[0])
            cols.append(box[2])

    if rows is None or cols is None:
        raise ValueError("rows and cols must be provided.")

    with rasterio.open(src_fp) as src:
        xs, ys = rasterio.transform.xy(src.transform, rows, cols, zs, offset, **kwargs)
        src_crs = src.crs

    if boxes is None:
        return [[x, y] for x, y in zip(xs, ys)]
    else:
        result = [[xs[i], ys[i + 1], xs[i + 1], ys[i]] for i in range(0, len(xs), 2)]

        if output is not None:
            boxes_to_vector(result, src_crs, dst_crs, output)
        else:
            return result

sam_map_gui(sam, basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)

Display the SAM Map GUI.

Parameters:

Name Type Description Default
sam SamGeo required
basemap str

The basemap to use. Defaults to "SATELLITE".

'SATELLITE'
repeat_mode bool

Whether to use the repeat mode for the draw control. Defaults to True.

True
out_dir str

The output directory. Defaults to None.

None
Source code in samgeo/common.py
def sam_map_gui(sam, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
    """Display the SAM Map GUI.

    Args:
        sam (SamGeo):
        basemap (str, optional): The basemap to use. Defaults to "SATELLITE".
        repeat_mode (bool, optional): Whether to use the repeat mode for the draw control. Defaults to True.
        out_dir (str, optional): The output directory. Defaults to None.

    """
    try:
        import shutil
        import tempfile
        import leafmap
        import ipyleaflet
        import ipyevents
        import ipywidgets as widgets
        from ipyfilechooser import FileChooser
    except ImportError:
        raise ImportError(
            "The sam_map function requires the leafmap package. Please install it first."
        )

    if out_dir is None:
        out_dir = tempfile.gettempdir()

    m = leafmap.Map(repeat_mode=repeat_mode, **kwargs)
    m.default_style = {"cursor": "crosshair"}
    m.add_basemap(basemap, show=False)

    # Skip the image layer if localtileserver is not available
    try:
        m.add_raster(sam.source, layer_name="Image")
    except:
        pass

    m.fg_markers = []
    m.bg_markers = []

    fg_layer = ipyleaflet.LayerGroup(layers=m.fg_markers, name="Foreground")
    bg_layer = ipyleaflet.LayerGroup(layers=m.bg_markers, name="Background")
    m.add(fg_layer)
    m.add(bg_layer)
    m.fg_layer = fg_layer
    m.bg_layer = bg_layer

    widget_width = "280px"
    button_width = "90px"
    padding = "0px 0px 0px 4px"  # upper, right, bottom, left
    style = {"description_width": "initial"}

    toolbar_button = widgets.ToggleButton(
        value=True,
        tooltip="Toolbar",
        icon="gear",
        layout=widgets.Layout(width="28px", height="28px", padding=padding),
    )

    close_button = widgets.ToggleButton(
        value=False,
        tooltip="Close the tool",
        icon="times",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding=padding),
    )

    plus_button = widgets.ToggleButton(
        value=False,
        tooltip="Load foreground points",
        icon="plus-circle",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding=padding),
    )

    minus_button = widgets.ToggleButton(
        value=False,
        tooltip="Load background points",
        icon="minus-circle",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding=padding),
    )

    radio_buttons = widgets.RadioButtons(
        options=["Foreground", "Background"],
        description="Class Type:",
        disabled=False,
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    fg_count = widgets.IntText(
        value=0,
        description="Foreground #:",
        disabled=True,
        style=style,
        layout=widgets.Layout(width="135px", padding=padding),
    )
    bg_count = widgets.IntText(
        value=0,
        description="Background #:",
        disabled=True,
        style=style,
        layout=widgets.Layout(width="135px", padding=padding),
    )

    segment_button = widgets.ToggleButton(
        description="Segment",
        value=False,
        button_style="primary",
        layout=widgets.Layout(padding=padding),
    )

    save_button = widgets.ToggleButton(
        description="Save", value=False, button_style="primary"
    )

    reset_button = widgets.ToggleButton(
        description="Reset", value=False, button_style="primary"
    )
    segment_button.layout.width = button_width
    save_button.layout.width = button_width
    reset_button.layout.width = button_width

    opacity_slider = widgets.FloatSlider(
        description="Mask opacity:",
        min=0,
        max=1,
        value=0.5,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    rectangular = widgets.Checkbox(
        value=False,
        description="Regularize",
        layout=widgets.Layout(width="130px", padding=padding),
        style=style,
    )

    colorpicker = widgets.ColorPicker(
        concise=False,
        description="Color",
        value="#ffff00",
        layout=widgets.Layout(width="140px", padding=padding),
        style=style,
    )

    buttons = widgets.VBox(
        [
            radio_buttons,
            widgets.HBox([fg_count, bg_count]),
            opacity_slider,
            widgets.HBox([rectangular, colorpicker]),
            widgets.HBox(
                [segment_button, save_button, reset_button],
                layout=widgets.Layout(padding="0px 4px 0px 4px"),
            ),
        ]
    )

    def opacity_changed(change):
        if change["new"]:
            mask_layer = m.find_layer("Masks")
            if mask_layer is not None:
                mask_layer.interact(opacity=opacity_slider.value)

    opacity_slider.observe(opacity_changed, "value")

    output = widgets.Output(
        layout=widgets.Layout(
            width=widget_width, padding=padding, max_width=widget_width
        )
    )

    toolbar_header = widgets.HBox()
    toolbar_header.children = [close_button, plus_button, minus_button, toolbar_button]
    toolbar_footer = widgets.VBox()
    toolbar_footer.children = [
        buttons,
        output,
    ]
    toolbar_widget = widgets.VBox()
    toolbar_widget.children = [toolbar_header, toolbar_footer]

    toolbar_event = ipyevents.Event(
        source=toolbar_widget, watched_events=["mouseenter", "mouseleave"]
    )

    def marker_callback(chooser):
        with output:
            if chooser.selected is not None:
                try:
                    gdf = gpd.read_file(chooser.selected)
                    centroids = gdf.centroid
                    coords = [[point.x, point.y] for point in centroids]
                    for coord in coords:
                        if plus_button.value:
                            if is_colab():  # Colab does not support AwesomeIcon
                                marker = ipyleaflet.CircleMarker(
                                    location=(coord[1], coord[0]),
                                    radius=2,
                                    color="green",
                                    fill_color="green",
                                )
                            else:
                                marker = ipyleaflet.Marker(
                                    location=[coord[1], coord[0]],
                                    icon=ipyleaflet.AwesomeIcon(
                                        name="plus-circle",
                                        marker_color="green",
                                        icon_color="darkred",
                                    ),
                                )
                            m.fg_layer.add(marker)
                            m.fg_markers.append(marker)
                            fg_count.value = len(m.fg_markers)
                        elif minus_button.value:
                            if is_colab():
                                marker = ipyleaflet.CircleMarker(
                                    location=(coord[1], coord[0]),
                                    radius=2,
                                    color="red",
                                    fill_color="red",
                                )
                            else:
                                marker = ipyleaflet.Marker(
                                    location=[coord[1], coord[0]],
                                    icon=ipyleaflet.AwesomeIcon(
                                        name="minus-circle",
                                        marker_color="red",
                                        icon_color="darkred",
                                    ),
                                )
                            m.bg_layer.add(marker)
                            m.bg_markers.append(marker)
                            bg_count.value = len(m.bg_markers)

                except Exception as e:
                    print(e)

            if m.marker_control in m.controls:
                m.remove_control(m.marker_control)
                delattr(m, "marker_control")

            plus_button.value = False
            minus_button.value = False

    def marker_button_click(change):
        if change["new"]:
            sandbox_path = os.environ.get("SANDBOX_PATH")
            filechooser = FileChooser(
                path=os.getcwd(),
                sandbox_path=sandbox_path,
                layout=widgets.Layout(width="454px"),
            )
            filechooser.use_dir_icons = True
            filechooser.filter_pattern = ["*.shp", "*.geojson", "*.gpkg"]
            filechooser.register_callback(marker_callback)
            marker_control = ipyleaflet.WidgetControl(
                widget=filechooser, position="topright"
            )
            m.add_control(marker_control)
            m.marker_control = marker_control
        else:
            if hasattr(m, "marker_control") and m.marker_control in m.controls:
                m.remove_control(m.marker_control)
                m.marker_control.close()

    plus_button.observe(marker_button_click, "value")
    minus_button.observe(marker_button_click, "value")

    def handle_toolbar_event(event):
        if event["type"] == "mouseenter":
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        elif event["type"] == "mouseleave":
            if not toolbar_button.value:
                toolbar_widget.children = [toolbar_button]
                toolbar_button.value = False
                close_button.value = False

    toolbar_event.on_dom_event(handle_toolbar_event)

    def toolbar_btn_click(change):
        if change["new"]:
            close_button.value = False
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        else:
            if not close_button.value:
                toolbar_widget.children = [toolbar_button]

    toolbar_button.observe(toolbar_btn_click, "value")

    def close_btn_click(change):
        if change["new"]:
            toolbar_button.value = False
            if m.toolbar_control in m.controls:
                m.remove_control(m.toolbar_control)
            toolbar_widget.close()

    close_button.observe(close_btn_click, "value")

    def handle_map_interaction(**kwargs):
        try:
            if kwargs.get("type") == "click":
                latlon = kwargs.get("coordinates")
                if radio_buttons.value == "Foreground":
                    if is_colab():
                        marker = ipyleaflet.CircleMarker(
                            location=tuple(latlon),
                            radius=2,
                            color="green",
                            fill_color="green",
                        )
                    else:
                        marker = ipyleaflet.Marker(
                            location=latlon,
                            icon=ipyleaflet.AwesomeIcon(
                                name="plus-circle",
                                marker_color="green",
                                icon_color="darkred",
                            ),
                        )
                    fg_layer.add(marker)
                    m.fg_markers.append(marker)
                    fg_count.value = len(m.fg_markers)
                elif radio_buttons.value == "Background":
                    if is_colab():
                        marker = ipyleaflet.CircleMarker(
                            location=tuple(latlon),
                            radius=2,
                            color="red",
                            fill_color="red",
                        )
                    else:
                        marker = ipyleaflet.Marker(
                            location=latlon,
                            icon=ipyleaflet.AwesomeIcon(
                                name="minus-circle",
                                marker_color="red",
                                icon_color="darkred",
                            ),
                        )
                    bg_layer.add(marker)
                    m.bg_markers.append(marker)
                    bg_count.value = len(m.bg_markers)

        except (TypeError, KeyError) as e:
            print(f"Error handling map interaction: {e}")

    m.on_interaction(handle_map_interaction)

    def segment_button_click(change):
        if change["new"]:
            segment_button.value = False
            with output:
                output.clear_output()
                if len(m.fg_markers) == 0:
                    print("Please add some foreground markers.")
                    segment_button.value = False
                    return

                else:
                    try:
                        fg_points = [
                            [marker.location[1], marker.location[0]]
                            for marker in m.fg_markers
                        ]
                        bg_points = [
                            [marker.location[1], marker.location[0]]
                            for marker in m.bg_markers
                        ]
                        point_coords = fg_points + bg_points
                        point_labels = [1] * len(fg_points) + [0] * len(bg_points)

                        filename = f"masks_{random_string()}.tif"
                        filename = os.path.join(out_dir, filename)
                        sam.predict(
                            point_coords=point_coords,
                            point_labels=point_labels,
                            point_crs="EPSG:4326",
                            output=filename,
                        )
                        if m.find_layer("Masks") is not None:
                            m.remove_layer(m.find_layer("Masks"))
                        if m.find_layer("Regularized") is not None:
                            m.remove_layer(m.find_layer("Regularized"))

                        if hasattr(sam, "prediction_fp") and os.path.exists(
                            sam.prediction_fp
                        ):
                            try:
                                os.remove(sam.prediction_fp)
                            except:
                                pass

                        # Skip the image layer if localtileserver is not available
                        try:
                            m.add_raster(
                                filename,
                                nodata=0,
                                cmap="Blues",
                                opacity=opacity_slider.value,
                                layer_name="Masks",
                                zoom_to_layer=False,
                            )

                            if rectangular.value:
                                vector = filename.replace(".tif", ".gpkg")
                                vector_rec = filename.replace(".tif", "_rect.gpkg")
                                raster_to_vector(filename, vector)
                                regularize(vector, vector_rec)
                                vector_style = {"color": colorpicker.value}
                                m.add_vector(
                                    vector_rec,
                                    layer_name="Regularized",
                                    style=vector_style,
                                    info_mode=None,
                                    zoom_to_layer=False,
                                )

                        except:
                            pass
                        output.clear_output()
                        segment_button.value = False
                        sam.prediction_fp = filename
                    except Exception as e:
                        segment_button.value = False
                        print(e)

    segment_button.observe(segment_button_click, "value")

    def filechooser_callback(chooser):
        with output:
            if chooser.selected is not None:
                try:
                    filename = chooser.selected
                    shutil.copy(sam.prediction_fp, filename)
                    vector = filename.replace(".tif", ".gpkg")
                    raster_to_vector(filename, vector)
                    if rectangular.value:
                        vector_rec = filename.replace(".tif", "_rect.gpkg")
                        regularize(vector, vector_rec)

                    fg_points = [
                        [marker.location[1], marker.location[0]]
                        for marker in m.fg_markers
                    ]
                    bg_points = [
                        [marker.location[1], marker.location[0]]
                        for marker in m.bg_markers
                    ]

                    coords_to_geojson(
                        fg_points, filename.replace(".tif", "_fg_markers.geojson")
                    )
                    coords_to_geojson(
                        bg_points, filename.replace(".tif", "_bg_markers.geojson")
                    )

                except Exception as e:
                    print(e)

                if hasattr(m, "save_control") and m.save_control in m.controls:
                    m.remove_control(m.save_control)
                    delattr(m, "save_control")
                save_button.value = False

    def save_button_click(change):
        if change["new"]:
            with output:
                sandbox_path = os.environ.get("SANDBOX_PATH")
                filechooser = FileChooser(
                    path=os.getcwd(),
                    filename="masks.tif",
                    sandbox_path=sandbox_path,
                    layout=widgets.Layout(width="454px"),
                )
                filechooser.use_dir_icons = True
                filechooser.filter_pattern = ["*.tif"]
                filechooser.register_callback(filechooser_callback)
                save_control = ipyleaflet.WidgetControl(
                    widget=filechooser, position="topright"
                )
                m.add_control(save_control)
                m.save_control = save_control
        else:
            if hasattr(m, "save_control") and m.save_control in m.controls:
                m.remove_control(m.save_control)
                delattr(m, "save_control")

    save_button.observe(save_button_click, "value")

    def reset_button_click(change):
        if change["new"]:
            segment_button.value = False
            reset_button.value = False
            opacity_slider.value = 0.5
            rectangular.value = False
            colorpicker.value = "#ffff00"
            output.clear_output()
            try:
                m.remove_layer(m.find_layer("Masks"))
                if m.find_layer("Regularized") is not None:
                    m.remove_layer(m.find_layer("Regularized"))
                m.clear_drawings()
                if hasattr(m, "fg_markers"):
                    m.user_rois = None
                    m.fg_markers = []
                    m.bg_markers = []
                    m.fg_layer.clear_layers()
                    m.bg_layer.clear_layers()
                    fg_count.value = 0
                    bg_count.value = 0
                try:
                    os.remove(sam.prediction_fp)
                except:
                    pass
            except:
                pass

    reset_button.observe(reset_button_click, "value")

    toolbar_control = ipyleaflet.WidgetControl(
        widget=toolbar_widget, position="topright"
    )
    m.add_control(toolbar_control)
    m.toolbar_control = toolbar_control

    return m

show_canvas(image, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)

Show a canvas to collect foreground and background points.

Parameters:

Name Type Description Default
image str | np.ndarray

The input image.

required
fg_color tuple

The color for the foreground points. Defaults to (0, 255, 0).

(0, 255, 0)
bg_color tuple

The color for the background points. Defaults to (0, 0, 255).

(0, 0, 255)
radius int

The radius of the points. Defaults to 5.

5

Returns:

Type Description
tuple

A tuple of two lists of foreground and background points.

Source code in samgeo/common.py
def show_canvas(image, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
    """Show a canvas to collect foreground and background points.

    Args:
        image (str | np.ndarray): The input image.
        fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
        bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
        radius (int, optional): The radius of the points. Defaults to 5.

    Returns:
        tuple: A tuple of two lists of foreground and background points.
    """
    if isinstance(image, str):
        if image.startswith("http"):
            image = download_file(image)

        image = cv2.imread(image)
    elif isinstance(image, np.ndarray):
        pass
    else:
        raise ValueError("Input image must be a URL or a NumPy array.")

    # Create an empty list to store the mouse click coordinates
    left_clicks = []
    right_clicks = []

    # Create a mouse callback function
    def get_mouse_coordinates(event, x, y):
        if event == cv2.EVENT_LBUTTONDOWN:
            # Append the coordinates to the mouse_clicks list
            left_clicks.append((x, y))

            # Draw a green circle at the mouse click coordinates
            cv2.circle(image, (x, y), radius, fg_color, -1)

            # Show the updated image with the circle
            cv2.imshow("Image", image)

        elif event == cv2.EVENT_RBUTTONDOWN:
            # Append the coordinates to the mouse_clicks list
            right_clicks.append((x, y))

            # Draw a red circle at the mouse click coordinates
            cv2.circle(image, (x, y), radius, bg_color, -1)

            # Show the updated image with the circle
            cv2.imshow("Image", image)

    # Create a window to display the image
    cv2.namedWindow("Image")

    # Set the mouse callback function for the window
    cv2.setMouseCallback("Image", get_mouse_coordinates)

    # Display the image in the window
    cv2.imshow("Image", image)

    # Wait for a key press to exit
    cv2.waitKey(0)

    # Destroy the window
    cv2.destroyAllWindows()

    return left_clicks, right_clicks

show_image_gui(path)

Show an interactive GUI to explore images.

Parameters:

Name Type Description Default
path str

The path to the image file or directory containing images.

required
Source code in samgeo/common.py
def show_image_gui(path: str) -> None:
    """Show an interactive GUI to explore images.
    Args:
        path (str): The path to the image file or directory containing images.
    """

    from PIL import Image
    from ipywidgets import interact, IntSlider
    import matplotlib

    def setup_interactive_matplotlib():
        """Sets up ipympl backend for interactive plotting in Jupyter."""
        # Use the ipympl backend for interactive plotting
        try:
            import ipympl

            matplotlib.use("module://ipympl.backend_nbagg")
        except ImportError:
            print("ipympl is not installed. Falling back to default backend.")

    def load_images_from_folder(folder):
        """Load all images from the specified folder."""
        images = []
        filenames = []
        for filename in sorted(os.listdir(folder)):
            if filename.endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif")):
                img = Image.open(os.path.join(folder, filename))
                img_array = np.array(img)
                images.append(img_array)
                filenames.append(filename)
        return images, filenames

    def load_single_image(image_path):
        """Load a single image from the specified image file path."""
        img = Image.open(image_path)
        img_array = np.array(img)
        return [img_array], [
            os.path.basename(image_path)
        ]  # Return as lists for consistency

    # Check if the input path is a file or a directory
    if os.path.isfile(path):
        images, filenames = load_single_image(path)
    elif os.path.isdir(path):
        images, filenames = load_images_from_folder(path)
    else:
        print("Invalid path. Please provide a valid image file or directory.")
        return

    total_images = len(images)

    if total_images == 0:
        print("No images found.")
        return

    # Set up interactive plotting
    setup_interactive_matplotlib()

    fig, ax = plt.subplots()
    fig.canvas.toolbar_visible = True
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = True

    # Display the first image initially
    im_display = ax.imshow(images[0])
    ax.set_title(f"Image: {filenames[0]}")
    plt.tight_layout()

    # Function to update the image when the slider changes (for multiple images)
    def update_image(image_index):
        im_display.set_data(images[image_index])
        ax.set_title(f"Image: {filenames[image_index]}")
        fig.canvas.draw()

    # Function to show pixel information on click
    def onclick(event):
        if event.xdata is not None and event.ydata is not None:
            col = int(event.xdata)
            row = int(event.ydata)
            pixel_value = images[current_image_index][
                row, col
            ]  # Use current image index
            ax.set_title(
                f"Image: {filenames[current_image_index]} - X: {col}, Y: {row}, Pixel Value: {pixel_value}"
            )
            fig.canvas.draw()

    # Track the current image index (whether from slider or for single image)
    current_image_index = 0

    # Slider widget to choose between images (only if there is more than one image)
    if total_images > 1:
        slider = IntSlider(min=0, max=total_images - 1, step=1, description="Image")

        def on_slider_change(change):
            nonlocal current_image_index
            current_image_index = change["new"]  # Update current image index
            update_image(current_image_index)

        slider.observe(on_slider_change, names="value")
        fig.canvas.mpl_connect("button_press_event", onclick)
        interact(update_image, image_index=slider)
    else:
        # If there's only one image, no need for a slider, just show pixel info on click
        fig.canvas.mpl_connect("button_press_event", onclick)

    # Show the plot
    plt.show()

split_raster(filename, out_dir, tile_size=256, overlap=0)

Split a raster into tiles.

Parameters:

Name Type Description Default
filename str

The path or http URL to the raster file.

required
out_dir str

The path to the output directory.

required
tile_size int | tuple

The size of the tiles. Can be an integer or a tuple of (width, height). Defaults to 256.

256
overlap int

The number of pixels to overlap between tiles. Defaults to 0.

0

Exceptions:

Type Description
ImportError

Raised if GDAL is not installed.

Source code in samgeo/common.py
def split_raster(filename, out_dir, tile_size=256, overlap=0):
    """Split a raster into tiles.

    Args:
        filename (str): The path or http URL to the raster file.
        out_dir (str): The path to the output directory.
        tile_size (int | tuple, optional): The size of the tiles. Can be an integer or a tuple of (width, height). Defaults to 256.
        overlap (int, optional): The number of pixels to overlap between tiles. Defaults to 0.

    Raises:
        ImportError: Raised if GDAL is not installed.
    """

    try:
        from osgeo import gdal
    except ImportError:
        raise ImportError(
            "GDAL is required to use this function. Install it with `conda install gdal -c conda-forge`"
        )

    if isinstance(filename, str):
        if filename.startswith("http"):
            output = filename.split("/")[-1]
            download_file(filename, output)
            filename = output

    # Open the input GeoTIFF file
    ds = gdal.Open(filename)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if isinstance(tile_size, int):
        tile_width = tile_size
        tile_height = tile_size
    elif isinstance(tile_size, tuple):
        tile_width = tile_size[0]
        tile_height = tile_size[1]

    # Get the size of the input raster
    width = ds.RasterXSize
    height = ds.RasterYSize

    # Calculate the number of tiles needed in both directions, taking into account the overlap
    num_tiles_x = (width - overlap) // (tile_width - overlap) + int(
        (width - overlap) % (tile_width - overlap) > 0
    )
    num_tiles_y = (height - overlap) // (tile_height - overlap) + int(
        (height - overlap) % (tile_height - overlap) > 0
    )

    # Get the georeferencing information of the input raster
    geotransform = ds.GetGeoTransform()

    # Loop over all the tiles
    for i in range(num_tiles_x):
        for j in range(num_tiles_y):
            # Calculate the pixel coordinates of the tile, taking into account the overlap and clamping to the edge of the raster
            x_min = i * (tile_width - overlap)
            y_min = j * (tile_height - overlap)
            x_max = min(x_min + tile_width, width)
            y_max = min(y_min + tile_height, height)

            # Adjust the size of the last tile in each row and column to include any remaining pixels
            if i == num_tiles_x - 1:
                x_min = max(x_max - tile_width, 0)
            if j == num_tiles_y - 1:
                y_min = max(y_max - tile_height, 0)

            # Calculate the size of the tile, taking into account the overlap
            tile_width = x_max - x_min
            tile_height = y_max - y_min

            # Set the output file name
            output_file = f"{out_dir}/tile_{i}_{j}.tif"

            # Create a new dataset for the tile
            driver = gdal.GetDriverByName("GTiff")
            tile_ds = driver.Create(
                output_file,
                tile_width,
                tile_height,
                ds.RasterCount,
                ds.GetRasterBand(1).DataType,
            )

            # Calculate the georeferencing information for the output tile
            tile_geotransform = (
                geotransform[0] + x_min * geotransform[1],
                geotransform[1],
                0,
                geotransform[3] + y_min * geotransform[5],
                0,
                geotransform[5],
            )

            # Set the geotransform and projection of the tile
            tile_ds.SetGeoTransform(tile_geotransform)
            tile_ds.SetProjection(ds.GetProjection())

            # Read the data from the input raster band(s) and write it to the tile band(s)
            for k in range(ds.RasterCount):
                band = ds.GetRasterBand(k + 1)
                tile_band = tile_ds.GetRasterBand(k + 1)
                tile_data = band.ReadAsArray(x_min, y_min, tile_width, tile_height)
                tile_band.WriteArray(tile_data)

            # Close the tile dataset
            tile_ds = None

    # Close the input dataset
    ds = None

temp_file_path(extension)

Returns a temporary file path.

Parameters:

Name Type Description Default
extension str

The file extension.

required

Returns:

Type Description
str

The temporary file path.

Source code in samgeo/common.py
def temp_file_path(extension):
    """Returns a temporary file path.

    Args:
        extension (str): The file extension.

    Returns:
        str: The temporary file path.
    """

    import tempfile
    import uuid

    if not extension.startswith("."):
        extension = "." + extension
    file_id = str(uuid.uuid4())
    file_path = os.path.join(tempfile.gettempdir(), f"{file_id}{extension}")

    return file_path

text_sam_gui(sam, basemap='SATELLITE', out_dir=None, box_threshold=0.25, text_threshold=0.25, cmap='viridis', opacity=0.5, **kwargs)

Display the SAM Map GUI.

Parameters:

Name Type Description Default
sam SamGeo required
basemap str

The basemap to use. Defaults to "SATELLITE".

'SATELLITE'
out_dir str

The output directory. Defaults to None.

None
Source code in samgeo/common.py
def text_sam_gui(
    sam,
    basemap="SATELLITE",
    out_dir=None,
    box_threshold=0.25,
    text_threshold=0.25,
    cmap="viridis",
    opacity=0.5,
    **kwargs,
):
    """Display the SAM Map GUI.

    Args:
        sam (SamGeo):
        basemap (str, optional): The basemap to use. Defaults to "SATELLITE".
        out_dir (str, optional): The output directory. Defaults to None.

    """
    try:
        import shutil
        import tempfile
        import leafmap
        import ipyleaflet
        import ipyevents
        import ipywidgets as widgets
        import leafmap.colormaps as cm
        from ipyfilechooser import FileChooser
    except ImportError:
        raise ImportError(
            "The sam_map function requires the leafmap package. Please install it first."
        )

    if out_dir is None:
        out_dir = tempfile.gettempdir()

    m = leafmap.Map(**kwargs)
    m.default_style = {"cursor": "crosshair"}
    m.add_basemap(basemap, show=False)

    # Skip the image layer if localtileserver is not available
    try:
        m.add_raster(sam.source, layer_name="Image")
    except:
        pass

    widget_width = "280px"
    button_width = "90px"
    padding = "0px 4px 0px 4px"  # upper, right, bottom, left
    style = {"description_width": "initial"}

    toolbar_button = widgets.ToggleButton(
        value=True,
        tooltip="Toolbar",
        icon="gear",
        layout=widgets.Layout(width="28px", height="28px", padding="0px 0px 0px 4px"),
    )

    close_button = widgets.ToggleButton(
        value=False,
        tooltip="Close the tool",
        icon="times",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding="0px 0px 0px 4px"),
    )

    text_prompt = widgets.Text(
        description="Text prompt:",
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    box_slider = widgets.FloatSlider(
        description="Box threshold:",
        min=0,
        max=1,
        value=box_threshold,
        step=0.01,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    text_slider = widgets.FloatSlider(
        description="Text threshold:",
        min=0,
        max=1,
        step=0.01,
        value=text_threshold,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    cmap_dropdown = widgets.Dropdown(
        description="Palette:",
        options=cm.list_colormaps(),
        value=cmap,
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    opacity_slider = widgets.FloatSlider(
        description="Opacity:",
        min=0,
        max=1,
        value=opacity,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    def opacity_changed(change):
        if change["new"]:
            if hasattr(m, "layer_name"):
                mask_layer = m.find_layer(m.layer_name)
                if mask_layer is not None:
                    mask_layer.interact(opacity=opacity_slider.value)

    opacity_slider.observe(opacity_changed, "value")

    rectangular = widgets.Checkbox(
        value=False,
        description="Regularize",
        layout=widgets.Layout(width="130px", padding=padding),
        style=style,
    )

    colorpicker = widgets.ColorPicker(
        concise=False,
        description="Color",
        value="#ffff00",
        layout=widgets.Layout(width="140px", padding=padding),
        style=style,
    )

    segment_button = widgets.ToggleButton(
        description="Segment",
        value=False,
        button_style="primary",
        layout=widgets.Layout(padding=padding),
    )

    save_button = widgets.ToggleButton(
        description="Save", value=False, button_style="primary"
    )

    reset_button = widgets.ToggleButton(
        description="Reset", value=False, button_style="primary"
    )
    segment_button.layout.width = button_width
    save_button.layout.width = button_width
    reset_button.layout.width = button_width

    output = widgets.Output(
        layout=widgets.Layout(
            width=widget_width, padding=padding, max_width=widget_width
        )
    )

    toolbar_header = widgets.HBox()
    toolbar_header.children = [close_button, toolbar_button]
    toolbar_footer = widgets.VBox()
    toolbar_footer.children = [
        text_prompt,
        box_slider,
        text_slider,
        cmap_dropdown,
        opacity_slider,
        widgets.HBox([rectangular, colorpicker]),
        widgets.HBox(
            [segment_button, save_button, reset_button],
            layout=widgets.Layout(padding="0px 4px 0px 4px"),
        ),
        output,
    ]
    toolbar_widget = widgets.VBox()
    toolbar_widget.children = [toolbar_header, toolbar_footer]

    toolbar_event = ipyevents.Event(
        source=toolbar_widget, watched_events=["mouseenter", "mouseleave"]
    )

    def handle_toolbar_event(event):
        if event["type"] == "mouseenter":
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        elif event["type"] == "mouseleave":
            if not toolbar_button.value:
                toolbar_widget.children = [toolbar_button]
                toolbar_button.value = False
                close_button.value = False

    toolbar_event.on_dom_event(handle_toolbar_event)

    def toolbar_btn_click(change):
        if change["new"]:
            close_button.value = False
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        else:
            if not close_button.value:
                toolbar_widget.children = [toolbar_button]

    toolbar_button.observe(toolbar_btn_click, "value")

    def close_btn_click(change):
        if change["new"]:
            toolbar_button.value = False
            if m.toolbar_control in m.controls:
                m.remove_control(m.toolbar_control)
            toolbar_widget.close()

    close_button.observe(close_btn_click, "value")

    def segment_button_click(change):
        if change["new"]:
            segment_button.value = False
            with output:
                output.clear_output()
                if len(text_prompt.value) == 0:
                    print("Please enter a text prompt first.")
                elif sam.source is None:
                    print("Please run sam.set_image() first.")
                else:
                    print("Segmenting...")
                    layer_name = text_prompt.value.replace(" ", "_")
                    filename = os.path.join(
                        out_dir, f"{layer_name}_{random_string()}.tif"
                    )
                    try:
                        sam.predict(
                            sam.source,
                            text_prompt.value,
                            box_slider.value,
                            text_slider.value,
                            output=filename,
                        )
                        sam.output = filename
                        if m.find_layer(layer_name) is not None:
                            m.remove_layer(m.find_layer(layer_name))
                        if m.find_layer(f"{layer_name}_rect") is not None:
                            m.remove_layer(m.find_layer(f"{layer_name} Regularized"))
                    except Exception as e:
                        output.clear_output()
                        print(e)
                    if os.path.exists(filename):
                        try:
                            m.add_raster(
                                filename,
                                layer_name=layer_name,
                                palette=cmap_dropdown.value,
                                opacity=opacity_slider.value,
                                nodata=0,
                                zoom_to_layer=False,
                            )
                            m.layer_name = layer_name

                            if rectangular.value:
                                vector = filename.replace(".tif", ".gpkg")
                                vector_rec = filename.replace(".tif", "_rect.gpkg")
                                raster_to_vector(filename, vector)
                                regularize(vector, vector_rec)
                                vector_style = {"color": colorpicker.value}
                                m.add_vector(
                                    vector_rec,
                                    layer_name=f"{layer_name} Regularized",
                                    style=vector_style,
                                    info_mode=None,
                                    zoom_to_layer=False,
                                )

                            output.clear_output()
                        except Exception as e:
                            print(e)

    segment_button.observe(segment_button_click, "value")

    def filechooser_callback(chooser):
        with output:
            if chooser.selected is not None:
                try:
                    filename = chooser.selected
                    shutil.copy(sam.output, filename)
                    vector = filename.replace(".tif", ".gpkg")
                    raster_to_vector(filename, vector)
                    if rectangular.value:
                        vector_rec = filename.replace(".tif", "_rect.gpkg")
                        regularize(vector, vector_rec)
                except Exception as e:
                    print(e)

                if hasattr(m, "save_control") and m.save_control in m.controls:
                    m.remove_control(m.save_control)
                    delattr(m, "save_control")
                save_button.value = False

    def save_button_click(change):
        if change["new"]:
            with output:
                output.clear_output()
                if not hasattr(m, "layer_name"):
                    print("Please click the Segment button first.")
                else:
                    sandbox_path = os.environ.get("SANDBOX_PATH")
                    filechooser = FileChooser(
                        path=os.getcwd(),
                        filename=f"{m.layer_name}.tif",
                        sandbox_path=sandbox_path,
                        layout=widgets.Layout(width="454px"),
                    )
                    filechooser.use_dir_icons = True
                    filechooser.filter_pattern = ["*.tif"]
                    filechooser.register_callback(filechooser_callback)
                    save_control = ipyleaflet.WidgetControl(
                        widget=filechooser, position="topright"
                    )
                    m.add_control(save_control)
                    m.save_control = save_control

        else:
            if hasattr(m, "save_control") and m.save_control in m.controls:
                m.remove_control(m.save_control)
                delattr(m, "save_control")

    save_button.observe(save_button_click, "value")

    def reset_button_click(change):
        if change["new"]:
            segment_button.value = False
            save_button.value = False
            reset_button.value = False
            opacity_slider.value = 0.5
            box_slider.value = 0.25
            text_slider.value = 0.25
            cmap_dropdown.value = "viridis"
            text_prompt.value = ""
            output.clear_output()
            try:
                if hasattr(m, "layer_name") and m.find_layer(m.layer_name) is not None:
                    m.remove_layer(m.find_layer(m.layer_name))
                m.clear_drawings()
            except:
                pass

    reset_button.observe(reset_button_click, "value")

    toolbar_control = ipyleaflet.WidgetControl(
        widget=toolbar_widget, position="topright"
    )
    m.add_control(toolbar_control)
    m.toolbar_control = toolbar_control

    return m

tms_to_geotiff(output, bbox, zoom=None, resolution=None, source='OpenStreetMap', crs='EPSG:3857', to_cog=False, return_image=False, overwrite=False, quiet=False, **kwargs)

Download TMS tiles and convert them to a GeoTIFF. The source is adapted from https://github.com/gumblex/tms2geotiff. Credits to the GitHub user @gumblex.

Parameters:

Name Type Description Default
output str

The output GeoTIFF file.

required
bbox list

The bounding box [minx, miny, maxx, maxy], e.g., [-122.5216, 37.733, -122.3661, 37.8095]

required
zoom int

The map zoom level. Defaults to None.

None
resolution float

The resolution in meters. Defaults to None.

None
source str

The tile source. It can be one of the following: "OPENSTREETMAP", "ROADMAP", "SATELLITE", "TERRAIN", "HYBRID", or an HTTP URL. Defaults to "OpenStreetMap".

'OpenStreetMap'
crs str

The output CRS. Defaults to "EPSG:3857".

'EPSG:3857'
to_cog bool

Convert to Cloud Optimized GeoTIFF. Defaults to False.

False
return_image bool

Return the image as PIL.Image. Defaults to False.

False
overwrite bool

Overwrite the output file if it already exists. Defaults to False.

False
quiet bool

Suppress output. Defaults to False.

False
**kwargs

Additional arguments to pass to gdal.GetDriverByName("GTiff").Create().

{}
Source code in samgeo/common.py
def tms_to_geotiff(
    output,
    bbox,
    zoom=None,
    resolution=None,
    source="OpenStreetMap",
    crs="EPSG:3857",
    to_cog=False,
    return_image=False,
    overwrite=False,
    quiet=False,
    **kwargs,
):
    """Download TMS tiles and convert them to a GeoTIFF. The source is adapted from https://github.com/gumblex/tms2geotiff.
        Credits to the GitHub user @gumblex.

    Args:
        output (str): The output GeoTIFF file.
        bbox (list): The bounding box [minx, miny, maxx, maxy], e.g., [-122.5216, 37.733, -122.3661, 37.8095]
        zoom (int, optional): The map zoom level. Defaults to None.
        resolution (float, optional): The resolution in meters. Defaults to None.
        source (str, optional): The tile source. It can be one of the following: "OPENSTREETMAP", "ROADMAP",
            "SATELLITE", "TERRAIN", "HYBRID", or an HTTP URL. Defaults to "OpenStreetMap".
        crs (str, optional): The output CRS. Defaults to "EPSG:3857".
        to_cog (bool, optional): Convert to Cloud Optimized GeoTIFF. Defaults to False.
        return_image (bool, optional): Return the image as PIL.Image. Defaults to False.
        overwrite (bool, optional): Overwrite the output file if it already exists. Defaults to False.
        quiet (bool, optional): Suppress output. Defaults to False.
        **kwargs: Additional arguments to pass to gdal.GetDriverByName("GTiff").Create().

    """

    import os
    import io
    import math
    import itertools
    import concurrent.futures

    import numpy
    from PIL import Image

    try:
        from osgeo import gdal, osr
    except ImportError:
        raise ImportError("GDAL is not installed. Install it with pip install GDAL")

    try:
        import httpx

        SESSION = httpx.Client()
    except ImportError:
        import requests

        SESSION = requests.Session()

    if not overwrite and os.path.exists(output):
        print(
            f"The output file {output} already exists. Use `overwrite=True` to overwrite it."
        )
        return

    xyz_tiles = {
        "OPENSTREETMAP": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
        "ROADMAP": "https://mt1.google.com/vt/lyrs=m&x={x}&y={y}&z={z}",
        "SATELLITE": "https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}",
        "TERRAIN": "https://mt1.google.com/vt/lyrs=p&x={x}&y={y}&z={z}",
        "HYBRID": "https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}",
    }

    basemaps = get_basemaps()

    if isinstance(source, str):
        if source.upper() in xyz_tiles:
            source = xyz_tiles[source.upper()]
        elif source in basemaps:
            source = basemaps[source]
        elif source.startswith("http"):
            pass
    else:
        raise ValueError(
            'source must be one of "OpenStreetMap", "ROADMAP", "SATELLITE", "TERRAIN", "HYBRID", or a URL'
        )

    def resolution_to_zoom_level(resolution):
        """
        Convert map resolution in meters to zoom level for Web Mercator (EPSG:3857) tiles.
        """
        # Web Mercator tile size in meters at zoom level 0
        initial_resolution = 156543.03392804097

        # Calculate the zoom level
        zoom_level = math.log2(initial_resolution / resolution)

        return int(zoom_level)

    if isinstance(bbox, list) and len(bbox) == 4:
        west, south, east, north = bbox
    else:
        raise ValueError(
            "bbox must be a list of 4 coordinates in the format of [xmin, ymin, xmax, ymax]"
        )

    if zoom is None and resolution is None:
        raise ValueError("Either zoom or resolution must be provided")
    elif zoom is not None and resolution is not None:
        raise ValueError("Only one of zoom or resolution can be provided")

    if resolution is not None:
        zoom = resolution_to_zoom_level(resolution)

    EARTH_EQUATORIAL_RADIUS = 6378137.0

    Image.MAX_IMAGE_PIXELS = None

    gdal.UseExceptions()
    web_mercator = osr.SpatialReference()
    web_mercator.ImportFromEPSG(3857)

    WKT_3857 = web_mercator.ExportToWkt()

    def from4326_to3857(lat, lon):
        xtile = math.radians(lon) * EARTH_EQUATORIAL_RADIUS
        ytile = (
            math.log(math.tan(math.radians(45 + lat / 2.0))) * EARTH_EQUATORIAL_RADIUS
        )
        return (xtile, ytile)

    def deg2num(lat, lon, zoom):
        lat_r = math.radians(lat)
        n = 2**zoom
        xtile = (lon + 180) / 360 * n
        ytile = (1 - math.log(math.tan(lat_r) + 1 / math.cos(lat_r)) / math.pi) / 2 * n
        return (xtile, ytile)

    def is_empty(im):
        extrema = im.getextrema()
        if len(extrema) >= 3:
            if len(extrema) > 3 and extrema[-1] == (0, 0):
                return True
            for ext in extrema[:3]:
                if ext != (0, 0):
                    return False
            return True
        else:
            return extrema[0] == (0, 0)

    def paste_tile(bigim, base_size, tile, corner_xy, bbox):
        if tile is None:
            return bigim
        im = Image.open(io.BytesIO(tile))
        mode = "RGB" if im.mode == "RGB" else "RGBA"
        size = im.size
        if bigim is None:
            base_size[0] = size[0]
            base_size[1] = size[1]
            newim = Image.new(
                mode, (size[0] * (bbox[2] - bbox[0]), size[1] * (bbox[3] - bbox[1]))
            )
        else:
            newim = bigim

        dx = abs(corner_xy[0] - bbox[0])
        dy = abs(corner_xy[1] - bbox[1])
        xy0 = (size[0] * dx, size[1] * dy)
        if mode == "RGB":
            newim.paste(im, xy0)
        else:
            if im.mode != mode:
                im = im.convert(mode)
            if not is_empty(im):
                newim.paste(im, xy0)
        im.close()
        return newim

    def finish_picture(bigim, base_size, bbox, x0, y0, x1, y1):
        xfrac = x0 - bbox[0]
        yfrac = y0 - bbox[1]
        x2 = round(base_size[0] * xfrac)
        y2 = round(base_size[1] * yfrac)
        imgw = round(base_size[0] * (x1 - x0))
        imgh = round(base_size[1] * (y1 - y0))
        retim = bigim.crop((x2, y2, x2 + imgw, y2 + imgh))
        if retim.mode == "RGBA" and retim.getextrema()[3] == (255, 255):
            retim = retim.convert("RGB")
        bigim.close()
        return retim

    def get_tile(url):
        retry = 3
        while 1:
            try:
                r = SESSION.get(url, timeout=60)
                break
            except Exception:
                retry -= 1
                if not retry:
                    raise
        if r.status_code == 404:
            return None
        elif not r.content:
            return None
        r.raise_for_status()
        return r.content

    def draw_tile(
        source, lat0, lon0, lat1, lon1, zoom, filename, quiet=False, **kwargs
    ):
        x0, y0 = deg2num(lat0, lon0, zoom)
        x1, y1 = deg2num(lat1, lon1, zoom)
        x0, x1 = sorted([x0, x1])
        y0, y1 = sorted([y0, y1])
        corners = tuple(
            itertools.product(
                range(math.floor(x0), math.ceil(x1)),
                range(math.floor(y0), math.ceil(y1)),
            )
        )
        totalnum = len(corners)
        futures = []
        with concurrent.futures.ThreadPoolExecutor(5) as executor:
            for x, y in corners:
                futures.append(
                    executor.submit(get_tile, source.format(z=zoom, x=x, y=y))
                )
            bbox = (math.floor(x0), math.floor(y0), math.ceil(x1), math.ceil(y1))
            bigim = None
            base_size = [256, 256]
            for k, (fut, corner_xy) in enumerate(zip(futures, corners), 1):
                bigim = paste_tile(bigim, base_size, fut.result(), corner_xy, bbox)
                if not quiet:
                    print(
                        f"Downloaded image {str(k).zfill(len(str(totalnum)))}/{totalnum}"
                    )

        if not quiet:
            print("Saving GeoTIFF. Please wait...")
        img = finish_picture(bigim, base_size, bbox, x0, y0, x1, y1)
        imgbands = len(img.getbands())
        driver = gdal.GetDriverByName("GTiff")

        if "options" not in kwargs:
            kwargs["options"] = [
                "COMPRESS=DEFLATE",
                "PREDICTOR=2",
                "ZLEVEL=9",
                "TILED=YES",
            ]

        gtiff = driver.Create(
            filename,
            img.size[0],
            img.size[1],
            imgbands,
            gdal.GDT_Byte,
            **kwargs,
        )
        xp0, yp0 = from4326_to3857(lat0, lon0)
        xp1, yp1 = from4326_to3857(lat1, lon1)
        pwidth = abs(xp1 - xp0) / img.size[0]
        pheight = abs(yp1 - yp0) / img.size[1]
        gtiff.SetGeoTransform((min(xp0, xp1), pwidth, 0, max(yp0, yp1), 0, -pheight))
        gtiff.SetProjection(WKT_3857)
        for band in range(imgbands):
            array = numpy.array(img.getdata(band), dtype="u8")
            array = array.reshape((img.size[1], img.size[0]))
            band = gtiff.GetRasterBand(band + 1)
            band.WriteArray(array)
        gtiff.FlushCache()

        if not quiet:
            print(f"Image saved to {filename}")
        return img

    try:
        image = draw_tile(
            source, south, west, north, east, zoom, output, quiet, **kwargs
        )
        if return_image:
            return image
        if crs.upper() != "EPSG:3857":
            reproject(output, output, crs, to_cog=to_cog)
        elif to_cog:
            image_to_cog(output, output)
    except Exception as e:
        raise Exception(e)

transform_coords(x, y, src_crs, dst_crs, **kwargs)

Transform coordinates from one CRS to another.

Parameters:

Name Type Description Default
x float

The x coordinate.

required
y float

The y coordinate.

required
src_crs str

The source CRS, e.g., "EPSG:4326".

required
dst_crs str

The destination CRS, e.g., "EPSG:3857".

required

Returns:

Type Description
dict

The transformed coordinates in the format of (x, y)

Source code in samgeo/common.py
def transform_coords(x, y, src_crs, dst_crs, **kwargs):
    """Transform coordinates from one CRS to another.

    Args:
        x (float): The x coordinate.
        y (float): The y coordinate.
        src_crs (str): The source CRS, e.g., "EPSG:4326".
        dst_crs (str): The destination CRS, e.g., "EPSG:3857".

    Returns:
        dict: The transformed coordinates in the format of (x, y)
    """
    transformer = pyproj.Transformer.from_crs(
        src_crs, dst_crs, always_xy=True, **kwargs
    )
    return transformer.transform(x, y)

update_package(out_dir=None, keep=False, **kwargs)

Updates the package from the GitHub repository without the need to use pip or conda.

Parameters:

Name Type Description Default
out_dir str

The output directory. Defaults to None.

None
keep bool

Whether to keep the downloaded package. Defaults to False.

False
**kwargs

Additional keyword arguments to pass to the download_file() function.

{}
Source code in samgeo/common.py
def update_package(out_dir=None, keep=False, **kwargs):
    """Updates the package from the GitHub repository without the need to use pip or conda.

    Args:
        out_dir (str, optional): The output directory. Defaults to None.
        keep (bool, optional): Whether to keep the downloaded package. Defaults to False.
        **kwargs: Additional keyword arguments to pass to the download_file() function.
    """

    import shutil

    try:
        if out_dir is None:
            out_dir = os.getcwd()
        url = (
            "https://github.com/opengeos/segment-geospatial/archive/refs/heads/main.zip"
        )
        filename = "segment-geospatial-main.zip"
        download_file(url, filename, **kwargs)

        pkg_dir = os.path.join(out_dir, "segment-geospatial-main")
        work_dir = os.getcwd()
        os.chdir(pkg_dir)

        if shutil.which("pip") is None:
            cmd = "pip3 install ."
        else:
            cmd = "pip install ."

        os.system(cmd)
        os.chdir(work_dir)

        if not keep:
            shutil.rmtree(pkg_dir)
            try:
                os.remove(filename)
            except:
                pass

        print("Package updated successfully.")

    except Exception as e:
        raise Exception(e)

vector_to_geojson(filename, output=None, **kwargs)

Converts a vector file to a geojson file.

Parameters:

Name Type Description Default
filename str

The vector file path.

required
output str

The output geojson file path. Defaults to None.

None

Returns:

Type Description
dict

The geojson dictionary.

Source code in samgeo/common.py
def vector_to_geojson(filename, output=None, **kwargs):
    """Converts a vector file to a geojson file.

    Args:
        filename (str): The vector file path.
        output (str, optional): The output geojson file path. Defaults to None.

    Returns:
        dict: The geojson dictionary.
    """

    if filename.startswith("http"):
        filename = download_file(filename)

    gdf = gpd.read_file(filename, **kwargs)
    if output is None:
        return gdf.__geo_interface__
    else:
        gdf.to_file(output, driver="GeoJSON")

video_to_images(video_path, output_dir, frame_rate=None, prefix='')

Converts a video into a series of images. Each frame of the video is saved as an image.

Parameters:

Name Type Description Default
video_path str

The path to the video file.

required
output_dir str

The directory where the images will be saved.

required
frame_rate Optional[int]

The number of frames to save per second of video. If None, all frames will be saved. Defaults to None.

None
prefix str

The prefix for the output image filenames. Defaults to 'frame_'.

''

Exceptions:

Type Description
ValueError

If the video file cannot be read or if the output directory is invalid.

Example usage: video_to_images('input_video.mp4', 'output_images', frame_rate=1, prefix='image_')

Source code in samgeo/common.py
def video_to_images(
    video_path: str,
    output_dir: str,
    frame_rate: Optional[int] = None,
    prefix: str = "",
) -> None:
    """
    Converts a video into a series of images. Each frame of the video is saved as an image.

    Args:
        video_path (str): The path to the video file.
        output_dir (str): The directory where the images will be saved.
        frame_rate (Optional[int], optional): The number of frames to save per second of video.
            If None, all frames will be saved. Defaults to None.
        prefix (str, optional): The prefix for the output image filenames. Defaults to 'frame_'.

    Raises:
        ValueError: If the video file cannot be read or if the output directory is invalid.

    Example usage:
        video_to_images('input_video.mp4', 'output_images', frame_rate=1, prefix='image_')
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Open the video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file {video_path}")

    # Get video properties
    video_fps = int(cap.get(cv2.CAP_PROP_FPS))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_rate = (
        frame_rate if frame_rate else video_fps
    )  # Default to original FPS if not provided

    # Calculate the number of digits based on the total frames (e.g., if total frames are 1000, width = 4)
    num_digits = len(str(total_frames))

    print(f"Video FPS: {video_fps}")
    print(f"Total Frames: {total_frames}")
    print(f"Saving every {video_fps // frame_rate} frame(s)")

    frame_count = 0
    saved_frame_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Save frames based on frame_rate
        if frame_count % (video_fps // frame_rate) == 0:
            img_path = os.path.join(
                output_dir, f"{prefix}{saved_frame_count:0{num_digits}d}.jpg"
            )
            cv2.imwrite(img_path, frame)
            saved_frame_count += 1
            # print(f"Saved {img_path}")

        frame_count += 1

    # Release the video capture object
    cap.release()
    print(f"Finished saving {saved_frame_count} images to {output_dir}")