hq_sam module

Segment remote sensing imagery with HQ-SAM (High Quality Segment Anything Model). See https://github.com/SysCV/sam-hq

SamGeo

The main class for segmenting geospatial data with the Segment Anything Model (SAM). See https://github.com/facebookresearch/segment-anything for details.
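
A minimal instantiation sketch (the checkpoint is downloaded automatically when it is not found locally):

from samgeo.hq_sam import SamGeo

# Automatic mode: segment the entire image with the mask generator.
sam = SamGeo(model_type="vit_h", automatic=True)

# Prompt mode: segment selected objects with point or box prompts.
sam_prompt = SamGeo(model_type="vit_h", automatic=False)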

Source code in samgeo/hq_sam.py
class SamGeo:
    """The main class for segmenting geospatial data with the Segment Anything Model (SAM). See
    https://github.com/facebookresearch/segment-anything for details.
    """

    def __init__(
        self,
        model_type="vit_h",
        automatic=True,
        device=None,
        checkpoint_dir=None,
        hq=False,
        sam_kwargs=None,
        **kwargs,
    ):
        """Initialize the class.

        Args:
            model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b.
                Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
            automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.
                The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
            device (str, optional): The device to use. It can be one of the following: cpu, cuda.
                Defaults to None, which will use cuda if available.
            hq (bool, optional): Whether to use the HQ-SAM model. Defaults to False.
                Note that this module always uses HQ-SAM, so the argument is overridden to True.
            checkpoint_dir (str, optional): The directory for storing the model checkpoint, such as
                sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, or sam_vit_b_01ec64.pth.
                Defaults to None. See https://bit.ly/3VrpxUh for more details.
            sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.
                The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.

                points_per_side: Optional[int] = 32,
                points_per_batch: int = 64,
                pred_iou_thresh: float = 0.88,
                stability_score_thresh: float = 0.95,
                stability_score_offset: float = 1.0,
                box_nms_thresh: float = 0.7,
                crop_n_layers: int = 0,
                crop_nms_thresh: float = 0.7,
                crop_overlap_ratio: float = 512 / 1500,
                crop_n_points_downscale_factor: int = 1,
                point_grids: Optional[List[np.ndarray]] = None,
                min_mask_region_area: int = 0,
                output_mode: str = "binary_mask",

        """

        hq = True  # Using HQ-SAM
        if "checkpoint" in kwargs:
            checkpoint = kwargs["checkpoint"]
            if not os.path.exists(checkpoint):
                checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)
            kwargs.pop("checkpoint")
        else:
            checkpoint = download_checkpoint(model_type, checkpoint_dir, hq)

        # Use cuda if available
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            if device == "cuda":
                torch.cuda.empty_cache()

        self.checkpoint = checkpoint
        self.model_type = model_type
        self.device = device
        self.sam_kwargs = sam_kwargs  # Optional arguments for fine-tuning the SAM model
        self.source = None  # Store the input image path
        self.image = None  # Store the input image as a numpy array
        # Store the masks as a list of dictionaries. Each mask is a dictionary
        # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box
        self.masks = None
        self.objects = None  # Store the mask objects as a numpy array
        # Store the annotations (objects with random color) as a numpy array.
        self.annotations = None

        # Store the predicted masks, iou_predictions, and low_res_masks
        self.prediction = None
        self.scores = None
        self.logits = None

        # Build the SAM model
        self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
        self.sam.to(device=self.device)
        # Use optional arguments for fine-tuning the SAM model
        sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}

        if automatic:
            # Segment the entire image using the automatic mask generator
            self.mask_generator = SamAutomaticMaskGenerator(self.sam, **sam_kwargs)
        else:
            # Segment selected objects using input prompts
            self.predictor = SamPredictor(self.sam, **sam_kwargs)

    def __call__(
        self,
        image,
        foreground=True,
        erosion_kernel=(3, 3),
        mask_multiplier=255,
        **kwargs,
    ):
        """Generate masks for the input tile. This function originates from the segment-anything-eo repository.
            See https://bit.ly/41pwiHw

        Args:
            image (np.ndarray): The input image as a numpy array.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extracting borders. Defaults to (3, 3).
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
        """
        h, w, _ = image.shape

        masks = self.mask_generator.generate(image)

        if foreground:  # Extract foreground objects only
            resulting_mask = np.zeros((h, w), dtype=np.uint8)
        else:
            resulting_mask = np.ones((h, w), dtype=np.uint8)
        resulting_borders = np.zeros((h, w), dtype=np.uint8)

        for m in masks:
            mask = (m["segmentation"] > 0).astype(np.uint8)
            resulting_mask += mask

            # Apply erosion to the mask
            if erosion_kernel is not None:
                mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
                mask_erode = (mask_erode > 0).astype(np.uint8)
                edge_mask = mask - mask_erode
                resulting_borders += edge_mask

        resulting_mask = (resulting_mask > 0).astype(np.uint8)
        resulting_borders = (resulting_borders > 0).astype(np.uint8)
        resulting_mask_with_borders = resulting_mask - resulting_borders
        return resulting_mask_with_borders * mask_multiplier

    def generate(
        self,
        source,
        output=None,
        foreground=True,
        batch=False,
        erosion_kernel=None,
        mask_multiplier=255,
        unique=True,
        **kwargs,
    ):
        """Generate masks for the input image.

        Args:
            source (str | np.ndarray): The path to the input image or the input image as a numpy array.
            output (str, optional): The path to the output image. Defaults to None.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.
            erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extracting borders.
                Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
                The parameter is ignored if unique is True.
            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
                The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.

        """

        if isinstance(source, str):
            if source.startswith("http"):
                source = download_file(source)

            if not os.path.exists(source):
                raise ValueError(f"Input path {source} does not exist.")

            if batch:  # Subdivide the image into tiles and segment each tile
                self.batch = True
                self.source = source
                self.masks = output
                return tiff_to_tiff(
                    source,
                    output,
                    self,
                    foreground=foreground,
                    erosion_kernel=erosion_kernel,
                    mask_multiplier=mask_multiplier,
                    **kwargs,
                )

            image = cv2.imread(source)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif isinstance(source, np.ndarray):
            image = source
            source = None
        else:
            raise ValueError("Input source must be either a path or a numpy array.")

        self.source = source  # Store the input image path
        self.image = image  # Store the input image as a numpy array
        mask_generator = self.mask_generator  # The automatic mask generator
        masks = mask_generator.generate(image)  # Segment the input image
        self.masks = masks  # Store the masks as a list of dictionaries
        self.batch = False

        if output is not None:
            # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
            self.save_masks(
                output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
            )

    def save_masks(
        self,
        output=None,
        foreground=True,
        unique=True,
        erosion_kernel=None,
        mask_multiplier=255,
        **kwargs,
    ):
        """Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.

        Args:
            output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
            erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extracting borders.
                Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.

        """

        if self.masks is None:
            raise ValueError("No masks found. Please run generate() first.")

        h, w, _ = self.image.shape
        masks = self.masks

        # Set output image data type based on the number of objects
        if len(masks) < 255:
            dtype = np.uint8
        elif len(masks) < 65535:
            dtype = np.uint16
        else:
            dtype = np.uint32

        # Generate a mask of objects with unique values
        if unique:
            # Sort the masks by area in ascending order
            sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False)

            # Create an output image with the same size as the input image
            objects = np.zeros(
                (
                    sorted_masks[0]["segmentation"].shape[0],
                    sorted_masks[0]["segmentation"].shape[1],
                )
            )
            # Assign a unique value to each object
            for index, ann in enumerate(sorted_masks):
                m = ann["segmentation"]
                objects[m] = index + 1

        # Generate a binary mask
        else:
            if foreground:  # Extract foreground objects only
                resulting_mask = np.zeros((h, w), dtype=dtype)
            else:
                resulting_mask = np.ones((h, w), dtype=dtype)
            resulting_borders = np.zeros((h, w), dtype=dtype)

            for m in masks:
                mask = (m["segmentation"] > 0).astype(dtype)
                resulting_mask += mask

                # Apply erosion to the mask
                if erosion_kernel is not None:
                    mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
                    mask_erode = (mask_erode > 0).astype(dtype)
                    edge_mask = mask - mask_erode
                    resulting_borders += edge_mask

            resulting_mask = (resulting_mask > 0).astype(dtype)
            resulting_borders = (resulting_borders > 0).astype(dtype)
            objects = resulting_mask - resulting_borders
            objects = objects * mask_multiplier

        objects = objects.astype(dtype)
        self.objects = objects

        if output is not None:  # Save the output image
            array_to_image(self.objects, output, self.source, **kwargs)

    def show_masks(
        self, figsize=(12, 10), cmap="binary_r", axis="off", foreground=True, **kwargs
    ):
        """Show the binary mask or the mask of objects with unique values.

        Args:
            figsize (tuple, optional): The figure size. Defaults to (12, 10).
            cmap (str, optional): The colormap. Defaults to "binary_r".
            axis (str, optional): Whether to show the axis. Defaults to "off".
            foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.
            **kwargs: Other arguments for save_masks().
        """

        import matplotlib.pyplot as plt

        if self.batch:
            self.objects = cv2.imread(self.masks)
        else:
            if self.objects is None:
                self.save_masks(foreground=foreground, **kwargs)

        plt.figure(figsize=figsize)
        plt.imshow(self.objects, cmap=cmap)
        plt.axis(axis)
        plt.show()

    def show_anns(
        self,
        figsize=(12, 10),
        axis="off",
        alpha=0.35,
        output=None,
        blend=True,
        **kwargs,
    ):
        """Show the annotations (objects with random color) on the input image.

        Args:
            figsize (tuple, optional): The figure size. Defaults to (12, 10).
            axis (str, optional): Whether to show the axis. Defaults to "off".
            alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
            output (str, optional): The path to the output image. Defaults to None.
            blend (bool, optional): Whether to show the input image. Defaults to True.
        """

        import matplotlib.pyplot as plt

        anns = self.masks

        if self.image is None:
            print("Please run generate() first.")
            return

        if anns is None or len(anns) == 0:
            return

        plt.figure(figsize=figsize)
        plt.imshow(self.image)

        sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)

        ax = plt.gca()
        ax.set_autoscale_on(False)

        img = np.ones(
            (
                sorted_anns[0]["segmentation"].shape[0],
                sorted_anns[0]["segmentation"].shape[1],
                4,
            )
        )
        img[:, :, 3] = 0
        for ann in sorted_anns:
            m = ann["segmentation"]
            color_mask = np.concatenate([np.random.random(3), [alpha]])
            img[m] = color_mask
        ax.imshow(img)

        if "dpi" not in kwargs:
            kwargs["dpi"] = 100

        if "bbox_inches" not in kwargs:
            kwargs["bbox_inches"] = "tight"

        plt.axis(axis)

        self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)

        if output is not None:
            if blend:
                array = blend_images(
                    self.annotations, self.image, alpha=alpha, show=False
                )
            else:
                array = self.annotations
            array_to_image(array, output, self.source)

    def set_image(self, image, image_format="RGB"):
        """Set the input image as a numpy array.

        Args:
            image (np.ndarray): The input image as a numpy array.
            image_format (str, optional): The image format, can be RGB or BGR. Defaults to "RGB".
        """
        if isinstance(image, str):
            if image.startswith("http"):
                image = download_file(image)

            if not os.path.exists(image):
                raise ValueError(f"Input path {image} does not exist.")

            self.source = image

            image = cv2.imread(image)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            self.image = image
        elif isinstance(image, np.ndarray):
            pass
        else:
            raise ValueError("Input image must be either a path or a numpy array.")

        self.predictor.set_image(image, image_format=image_format)

    def save_prediction(
        self,
        output,
        index=None,
        mask_multiplier=255,
        dtype=np.float32,
        vector=None,
        simplify_tolerance=None,
        **kwargs,
    ):
        """Save the predicted mask to the output path.

        Args:
            output (str): The path to the output image.
            index (int, optional): The index of the mask to save. Defaults to None,
                which will save the mask with the highest score.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            vector (str, optional): The path to the output vector file. Defaults to None.
            dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.

        """
        if self.scores is None:
            raise ValueError("No predictions found. Please run predict() first.")

        if index is None:
            index = self.scores.argmax(axis=0)

        array = self.masks[index] * mask_multiplier
        self.prediction = array
        array_to_image(array, output, self.source, dtype=dtype, **kwargs)

        if vector is not None:
            raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)

    def predict(
        self,
        point_coords=None,
        point_labels=None,
        boxes=None,
        point_crs=None,
        mask_input=None,
        multimask_output=True,
        return_logits=False,
        output=None,
        index=None,
        mask_multiplier=255,
        dtype="float32",
        return_results=False,
        **kwargs,
    ):
        """Predict masks for the given input prompts, using the currently set image.

        Args:
            point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the
                model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON
                dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
            point_labels (list | int | np.ndarray, optional): A length N array of labels for the
                point prompts. 1 indicates a foreground point and 0 indicates a background point.
            point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
            boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the
                model, in XYXY format.
            mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
                coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
            multimask_output (bool, optional): If true, the model will return three masks.
                For ambiguous input prompts (such as a single click), this will often
                produce better masks than a single prediction. If only a single
                mask is needed, the model's predicted quality score can be used
                to select the best mask. For non-ambiguous prompts, such as multiple
                input prompts, multimask_output=False can give better results.
            return_logits (bool, optional): If true, returns un-thresholded masks logits
                instead of a binary mask.
            output (str, optional): The path to the output image. Defaults to None.
            index (int, optional): The index of the mask to save. Defaults to None,
                which will save the mask with the highest score.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
            return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.

        """

        if isinstance(boxes, str):
            gdf = gpd.read_file(boxes)
            if gdf.crs is not None:
                gdf = gdf.to_crs("epsg:4326")
            boxes = gdf.geometry.bounds.values.tolist()
        elif isinstance(boxes, dict):
            import json

            geojson = json.dumps(boxes)
            gdf = gpd.read_file(geojson, driver="GeoJSON")
            boxes = gdf.geometry.bounds.values.tolist()

        if isinstance(point_coords, str):
            point_coords = vector_to_geojson(point_coords)

        if isinstance(point_coords, dict):
            point_coords = geojson_to_coords(point_coords)

        if hasattr(self, "point_coords"):
            point_coords = self.point_coords

        if hasattr(self, "point_labels"):
            point_labels = self.point_labels

        if (point_crs is not None) and (point_coords is not None):
            point_coords = coords_to_xy(self.source, point_coords, point_crs)

        if isinstance(point_coords, list):
            point_coords = np.array(point_coords)

        if point_coords is not None:
            if point_labels is None:
                point_labels = [1] * len(point_coords)
            elif isinstance(point_labels, int):
                point_labels = [point_labels] * len(point_coords)

        if isinstance(point_labels, list):
            if len(point_labels) != len(point_coords):
                if len(point_labels) == 1:
                    point_labels = point_labels * len(point_coords)
                else:
                    raise ValueError(
                        "The length of point_labels must be equal to the length of point_coords."
                    )
            point_labels = np.array(point_labels)

        predictor = self.predictor

        input_boxes = None
        if isinstance(boxes, list) and (point_crs is not None):
            coords = bbox_to_xy(self.source, boxes, point_crs)
            input_boxes = np.array(coords)
            if isinstance(coords[0], int):
                input_boxes = input_boxes[None, :]
            else:
                input_boxes = torch.tensor(input_boxes, device=self.device)
                input_boxes = predictor.transform.apply_boxes_torch(
                    input_boxes, self.image.shape[:2]
                )
        elif isinstance(boxes, list) and (point_crs is None):
            input_boxes = np.array(boxes)
            if isinstance(boxes[0], int):
                input_boxes = input_boxes[None, :]

        self.boxes = input_boxes

        if (
            boxes is None
            or (len(boxes) == 1)
            or (len(boxes) == 4 and isinstance(boxes[0], float))
        ):
            if isinstance(boxes, list) and isinstance(boxes[0], list):
                boxes = boxes[0]
            masks, scores, logits = predictor.predict(
                point_coords,
                point_labels,
                input_boxes,
                mask_input,
                multimask_output,
                return_logits,
            )
        else:
            masks, scores, logits = predictor.predict_torch(
                point_coords=point_coords,
                point_labels=point_labels,
                boxes=input_boxes,
                multimask_output=True,
            )

        self.masks = masks
        self.scores = scores
        self.logits = logits

        if output is not None:
            if boxes is None or (not isinstance(boxes[0], list)):
                self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
            else:
                self.tensor_to_numpy(
                    index, output, mask_multiplier, dtype, save_args=kwargs
                )

        if return_results:
            return masks, scores, logits

    def tensor_to_numpy(
        self, index=None, output=None, mask_multiplier=255, dtype="uint8", save_args={}
    ):
        """Convert the predicted masks from tensors to numpy arrays.

        Args:
            index (int, optional): The index of the mask to save. Defaults to None,
                which uses the mask at index 1.
            output (str, optional): The path to the output image. Defaults to None.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            dtype (np.dtype, optional): The data type of the output image. Defaults to np.uint8.
            save_args (dict, optional): Optional arguments for saving the output image. Defaults to {}.

        Returns:
            np.ndarray: The predicted mask as a numpy array.
        """

        boxes = self.boxes
        masks = self.masks

        image_pil = self.image
        image_np = np.array(image_pil)

        if index is None:
            index = 1

        masks = masks[:, index, :, :]
        masks = masks.squeeze(1)

        if boxes is None or (len(boxes) == 0):  # No "object" instances found
            print("No objects found in the image.")
            return
        else:
            # Create an empty image to store the mask overlays
            mask_overlay = np.zeros_like(
                image_np[..., 0], dtype=dtype
            )  # Adjusted for single channel

            for i, (box, mask) in enumerate(zip(boxes, masks)):
                # Convert tensor to numpy array if necessary and ensure it contains integers
                if isinstance(mask, torch.Tensor):
                    mask = (
                        mask.cpu().numpy().astype(dtype)
                    )  # If mask is on GPU, use .cpu() before .numpy()
                mask_overlay += ((mask > 0) * (i + 1)).astype(
                    dtype
                )  # Assign a unique value for each mask

            # Normalize mask_overlay to be in [0, 255]
            mask_overlay = (
                mask_overlay > 0
            ) * mask_multiplier  # Binary mask in [0, 255]

        if output is not None:
            array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)
        else:
            return mask_overlay

    def show_map(self, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
        """Show the interactive map.

        Args:
            basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
            repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.
            out_dir (str, optional): The path to the output directory. Defaults to None.

        Returns:
            leafmap.Map: The map object.
        """
        return sam_map_gui(
            self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs
        )

    def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
        """Show a canvas to collect foreground and background points.

        Args:
            fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
            bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
            radius (int, optional): The radius of the points. Defaults to 5.

        Returns:
            tuple: A tuple of two lists of foreground and background points.
        """

        if self.image is None:
            raise ValueError("Please run set_image() first.")

        image = self.image
        fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)
        self.fg_points = fg_points
        self.bg_points = bg_points
        point_coords = fg_points + bg_points
        point_labels = [1] * len(fg_points) + [0] * len(bg_points)
        self.point_coords = point_coords
        self.point_labels = point_labels

    def clear_cuda_cache(self):
        """Clear the CUDA cache."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def image_to_image(self, image, **kwargs):
        return image_to_image(image, self, **kwargs)

    def download_tms_as_tiff(self, source, pt1, pt2, zoom, dist):
        image = draw_tile(source, pt1[0], pt1[1], pt2[0], pt2[1], zoom, dist)
        return image

    def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):
        """Save the result to a vector file.

        Args:
            image (str): The path to the image file.
            output (str): The path to the vector file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)

    def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):
        """Convert a tiff file to a gpkg file.

        Args:
            tiff_path (str): The path to the tiff file.
            output (str): The path to the vector file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_vector(
            tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
        )

    def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):
        """Convert a tiff file to a gpkg file.

        Args:
            tiff_path (str): The path to the tiff file.
            output (str): The path to the gpkg file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_gpkg(
            tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
        )

    def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):
        """Convert a tiff file to a shapefile.

        Args:
            tiff_path (str): The path to the tiff file.
            output (str): The path to the shapefile.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_shp(
            tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
        )

    def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):
        """Convert a tiff file to a GeoJSON file.

        Args:
            tiff_path (str): The path to the tiff file.
            output (str): The path to the GeoJSON file.
            simplify_tolerance (float, optional): The maximum allowed geometry displacement.
                The higher this value, the smaller the number of vertices in the resulting geometry.
        """

        raster_to_geojson(
            tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
        )

__call__(image, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255, **kwargs)

Generate masks for the input tile. This function originates from the segment-anything-eo repository. See https://bit.ly/41pwiHw

Parameters:

    image (np.ndarray): The input image as a numpy array. Required.
    foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
    erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extracting borders. Defaults to (3, 3).
    mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
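
A minimal sketch of calling the object directly on an in-memory tile; the image path is a placeholder:

import cv2
import numpy as np
from samgeo.hq_sam import SamGeo

sam = SamGeo(automatic=True)

# Read a tile as an RGB numpy array (placeholder path).
image = cv2.imread("tile.png")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# cv2.erode expects an array kernel, so build one from the (3, 3) shape.
kernel = np.ones((3, 3), np.uint8)

# Returns an (H, W) array with foreground pixels set to mask_multiplier.
mask = sam(image, foreground=True, erosion_kernel=kernel, mask_multiplier=255)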

__init__(model_type='vit_h', automatic=True, device=None, checkpoint_dir=None, hq=False, sam_kwargs=None, **kwargs)

Initialize the class.

Parameters:

    model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_b. Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
    automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True. The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
    device (str, optional): The device to use. It can be one of the following: cpu, cuda. Defaults to None, which will use cuda if available.
    hq (bool, optional): Whether to use the HQ-SAM model. Defaults to False. Note that this module always uses HQ-SAM, so the argument is overridden to True.
    checkpoint_dir (str, optional): The directory for storing the model checkpoint, such as sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, or sam_vit_b_01ec64.pth. Defaults to None. See https://bit.ly/3VrpxUh for more details.
    sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None. The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.

        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
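
A sketch of tuning the automatic mask generator through sam_kwargs; the values below are illustrative, not recommendations:

from samgeo.hq_sam import SamGeo

sam_kwargs = {
    "points_per_side": 32,           # density of the sampling grid
    "pred_iou_thresh": 0.86,         # drop masks with low predicted IoU
    "stability_score_thresh": 0.92,  # drop unstable masks
    "crop_n_layers": 1,              # run extra crops to catch small objects
    "min_mask_region_area": 100,     # remove tiny disconnected regions
}
sam = SamGeo(model_type="vit_h", automatic=True, sam_kwargs=sam_kwargs)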

clear_cuda_cache()

Clear the CUDA cache.


generate(source, output=None, foreground=True, batch=False, erosion_kernel=None, mask_multiplier=255, unique=True, **kwargs)

Generate masks for the input image.

Parameters:

    source (str | np.ndarray): The path to the input image or the input image as a numpy array. Required.
    output (str, optional): The path to the output image. Defaults to None.
    foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
    batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.
    erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extracting borders, such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
    mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. The parameter is ignored if unique is True.
    unique (bool, optional): Whether to assign a unique value to each object. Defaults to True. The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
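
A sketch of the two generate() modes; the file paths are placeholders:

from samgeo.hq_sam import SamGeo

sam = SamGeo(automatic=True)

# Single image: each object is written with a unique integer value.
sam.generate("image.tif", output="objects.tif", unique=True)

# Large image: subdivide into tiles and segment each tile separately.
sam.generate("large_image.tif", output="mask.tif", batch=True)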

predict(point_coords=None, point_labels=None, boxes=None, point_crs=None, mask_input=None, multimask_output=True, return_logits=False, output=None, index=None, mask_multiplier=255, dtype='float32', return_results=False, **kwargs)

Predict masks for the given input prompts, using the currently set image.

Parameters:

    point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
    point_labels (list | int | np.ndarray, optional): A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point. Defaults to None.
    point_crs (str, optional): The coordinate reference system (CRS) of the point prompts. Defaults to None.
    boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the model, in XYXY format. Defaults to None.
    mask_input (np.ndarray, optional): A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256. Defaults to None.
    multimask_output (bool, optional): If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. Defaults to True.
    return_logits (bool, optional): If true, returns un-thresholded mask logits instead of a binary mask. Defaults to False.
    output (str, optional): The path to the output image. Defaults to None.
    index (int, optional): The index of the mask to save. Defaults to None, which will save the mask with the highest score.
    mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1]. Defaults to 255.
    dtype (np.dtype, optional): The data type of the output image. Defaults to 'float32'.
    return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.
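
A sketch of prompting with a geographic point; the coordinates and paths are placeholders:

from samgeo.hq_sam import SamGeo

sam = SamGeo(automatic=False)  # use the prompt-based predictor
sam.set_image("image.tif")     # placeholder GeoTIFF

# One foreground point in lon/lat; point_crs converts it to pixel coordinates.
sam.predict(
    point_coords=[[-122.1497, 37.6311]],
    point_labels=1,
    point_crs="EPSG:4326",
    output="mask.tif",
)
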
Source code in samgeo/hq_sam.py
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
def predict(
    self,
    point_coords=None,
    point_labels=None,
    boxes=None,
    point_crs=None,
    mask_input=None,
    multimask_output=True,
    return_logits=False,
    output=None,
    index=None,
    mask_multiplier=255,
    dtype="float32",
    return_results=False,
    **kwargs,
):
    """Predict masks for the given input prompts, using the currently set image.

    Args:
        point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON
            dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
        point_labels (list | int | np.ndarray, optional): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a background point.
        point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
        boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the
            model, in XYXY format.
        mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
            multimask_output (bool, optional): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
        return_logits (bool, optional): If true, returns un-thresholded masks logits
            instead of a binary mask.
        output (str, optional): The path to the output image. Defaults to None.
        index (index, optional): The index of the mask to save. Defaults to None,
            which will save the mask with the highest score.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
        return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.

    """

    if isinstance(boxes, str):
        gdf = gpd.read_file(boxes)
        if gdf.crs is not None:
            gdf = gdf.to_crs("epsg:4326")
        boxes = gdf.geometry.bounds.values.tolist()
    elif isinstance(boxes, dict):
        import json

        geojson = json.dumps(boxes)
        gdf = gpd.read_file(geojson, driver="GeoJSON")
        boxes = gdf.geometry.bounds.values.tolist()

    if isinstance(point_coords, str):
        point_coords = vector_to_geojson(point_coords)

    if isinstance(point_coords, dict):
        point_coords = geojson_to_coords(point_coords)

    if hasattr(self, "point_coords"):
        point_coords = self.point_coords

    if hasattr(self, "point_labels"):
        point_labels = self.point_labels

    if (point_crs is not None) and (point_coords is not None):
        point_coords = coords_to_xy(self.source, point_coords, point_crs)

    if isinstance(point_coords, list):
        point_coords = np.array(point_coords)

    if point_coords is not None:
        if point_labels is None:
            point_labels = [1] * len(point_coords)
        elif isinstance(point_labels, int):
            point_labels = [point_labels] * len(point_coords)

    if isinstance(point_labels, list):
        if len(point_labels) != len(point_coords):
            if len(point_labels) == 1:
                point_labels = point_labels * len(point_coords)
            else:
                raise ValueError(
                    "The length of point_labels must be equal to the length of point_coords."
                )
        point_labels = np.array(point_labels)

    predictor = self.predictor

    input_boxes = None
    if isinstance(boxes, list) and (point_crs is not None):
        coords = bbox_to_xy(self.source, boxes, point_crs)
        input_boxes = np.array(coords)
        if isinstance(coords[0], int):
            input_boxes = input_boxes[None, :]
        else:
            input_boxes = torch.tensor(input_boxes, device=self.device)
            input_boxes = predictor.transform.apply_boxes_torch(
                input_boxes, self.image.shape[:2]
            )
    elif isinstance(boxes, list) and (point_crs is None):
        input_boxes = np.array(boxes)
        if isinstance(boxes[0], int):
            input_boxes = input_boxes[None, :]

    self.boxes = input_boxes

    if (
        boxes is None
        or (len(boxes) == 1)
        or (len(boxes) == 4 and isinstance(boxes[0], float))
    ):
        if isinstance(boxes, list) and isinstance(boxes[0], list):
            boxes = boxes[0]
        masks, scores, logits = predictor.predict(
            point_coords,
            point_labels,
            input_boxes,
            mask_input,
            multimask_output,
            return_logits,
        )
    else:
        masks, scores, logits = predictor.predict_torch(
            point_coords=point_coords,
            point_labels=point_coords,
            boxes=input_boxes,
            multimask_output=True,
        )

    self.masks = masks
    self.scores = scores
    self.logits = logits

    if output is not None:
        if boxes is None or (not isinstance(boxes[0], list)):
            self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
        else:
            self.tensor_to_numpy(
                index, output, mask_multiplier, dtype, save_args=kwargs
            )

    if return_results:
        return masks, scores, logits
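
Example: a minimal interactive-segmentation sketch. The constructor arguments, file names, and coordinates below are illustrative assumptions, not part of the documented API.

from samgeo.hq_sam import SamGeo

sam = SamGeo(model_type="vit_h", automatic=False)  # promptable (non-automatic) mode
sam.set_image("image.tif")  # hypothetical GeoTIFF

# A single foreground point in EPSG:4326; point_crs triggers the
# reprojection to pixel coordinates shown in the source above.
sam.predict(
    point_coords=[[-122.1419, 37.6383]],
    point_labels=1,
    point_crs="EPSG:4326",
    output="mask.tif",
)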

raster_to_vector(image, output, simplify_tolerance=None, **kwargs)

Save the result to a vector file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | str | The path to the image file. | required |
| output | str | The path to the vector file. | required |
| simplify_tolerance | float | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |

Source code in samgeo/hq_sam.py
def raster_to_vector(self, image, output, simplify_tolerance=None, **kwargs):
    """Save the result to a vector file.

    Args:
        image (str): The path to the image file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """

    raster_to_vector(image, output, simplify_tolerance=simplify_tolerance, **kwargs)
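
Example: a one-line sketch, assuming a mask raster "mask.tif" produced earlier by save_prediction() or save_masks().

sam.raster_to_vector("mask.tif", "mask.gpkg", simplify_tolerance=0.5)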

save_masks(output=None, foreground=True, unique=True, erosion_kernel=None, mask_multiplier=255, **kwargs)

Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| output | str | The path to the output image. Defaults to None, saving the masks to SamGeo.objects. | None |
| foreground | bool | Whether to generate the foreground mask. Defaults to True. | True |
| unique | bool | Whether to assign a unique value to each object. Defaults to True. | True |
| erosion_kernel | tuple | The erosion kernel for filtering object masks and extracting borders, such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None. | None |
| mask_multiplier | int | The mask multiplier for the output mask, which is usually a binary mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. | 255 |

Source code in samgeo/hq_sam.py
def save_masks(
    self,
    output=None,
    foreground=True,
    unique=True,
    erosion_kernel=None,
    mask_multiplier=255,
    **kwargs,
):
    """Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.

    Args:
        output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.
        foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
        unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
        erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extracting borders,
            such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
            You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.

    """

    if self.masks is None:
        raise ValueError("No masks found. Please run generate() first.")

    h, w, _ = self.image.shape
    masks = self.masks

    # Set output image data type based on the number of objects
    if len(masks) < 255:
        dtype = np.uint8
    elif len(masks) < 65535:
        dtype = np.uint16
    else:
        dtype = np.uint32

    # Generate a mask of objects with unique values
    if unique:
        # Sort the masks by area in ascending order
        sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False)

        # Create an output image with the same size as the input image
        objects = np.zeros(
            (
                sorted_masks[0]["segmentation"].shape[0],
                sorted_masks[0]["segmentation"].shape[1],
            )
        )
        # Assign a unique value to each object
        for index, ann in enumerate(sorted_masks):
            m = ann["segmentation"]
            objects[m] = index + 1

    # Generate a binary mask
    else:
        if foreground:  # Extract foreground objects only
            resulting_mask = np.zeros((h, w), dtype=dtype)
        else:
            resulting_mask = np.ones((h, w), dtype=dtype)
        resulting_borders = np.zeros((h, w), dtype=dtype)

        for m in masks:
            mask = (m["segmentation"] > 0).astype(dtype)
            resulting_mask += mask

            # Apply erosion to the mask
            if erosion_kernel is not None:
                mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
                mask_erode = (mask_erode > 0).astype(dtype)
                edge_mask = mask - mask_erode
                resulting_borders += edge_mask

        resulting_mask = (resulting_mask > 0).astype(dtype)
        resulting_borders = (resulting_borders > 0).astype(dtype)
        objects = resulting_mask - resulting_borders
        objects = objects * mask_multiplier

    objects = objects.astype(dtype)
    self.objects = objects

    if output is not None:  # Save the output image
        array_to_image(self.objects, output, self.source, **kwargs)
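
Example: a short sketch, assuming the SamGeo instance was created in automatic mode and generate() has been run on the input raster (both names below are illustrative).

# One unique integer value per object, written as a georeferenced raster.
sam.save_masks(output="objects.tif", unique=True)

# Alternatively, a binary foreground mask scaled to [0, 255].
sam.save_masks(output="binary.tif", unique=False, foreground=True)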

save_prediction(output, index=None, mask_multiplier=255, dtype=np.float32, vector=None, simplify_tolerance=None, **kwargs)

Save the predicted mask to the output path.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| output | str | The path to the output image. | required |
| index | int | The index of the mask to save. Defaults to None, which will save the mask with the highest score. | None |
| mask_multiplier | int | The mask multiplier for the output mask, which is usually a binary mask [0, 1]. | 255 |
| vector | str | The path to the output vector file. Defaults to None. | None |
| dtype | np.dtype | The data type of the output image. Defaults to np.float32. | np.float32 |
| simplify_tolerance | float | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |

Source code in samgeo/hq_sam.py
def save_prediction(
    self,
    output,
    index=None,
    mask_multiplier=255,
    dtype=np.float32,
    vector=None,
    simplify_tolerance=None,
    **kwargs,
):
    """Save the predicted mask to the output path.

    Args:
        output (str): The path to the output image.
        index (int, optional): The index of the mask to save. Defaults to None,
            which will save the mask with the highest score.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        vector (str, optional): The path to the output vector file. Defaults to None.
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.

    """
    if self.scores is None:
        raise ValueError("No predictions found. Please run predict() first.")

    if index is None:
        index = self.scores.argmax(axis=0)

    array = self.masks[index] * mask_multiplier
    self.prediction = array
    array_to_image(array, output, self.source, dtype=dtype, **kwargs)

    if vector is not None:
        raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)
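
Example: a brief sketch, assuming predict() has already been run on the current image; the file names are placeholders.

# Saves the highest-scoring mask and, optionally, a vectorized copy.
sam.save_prediction("mask.tif", vector="mask.gpkg")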

set_image(image, image_format='RGB')

Set the input image as a numpy array.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | np.ndarray | The input image as a numpy array. | required |
| image_format | str | The image format, can be RGB or BGR. Defaults to "RGB". | 'RGB' |

Source code in samgeo/hq_sam.py
def set_image(self, image, image_format="RGB"):
    """Set the input image as a numpy array.

    Args:
        image (np.ndarray): The input image as a numpy array.
        image_format (str, optional): The image format, can be RGB or BGR. Defaults to "RGB".
    """
    if isinstance(image, str):
        if image.startswith("http"):
            image = download_file(image)

        if not os.path.exists(image):
            raise ValueError(f"Input path {image} does not exist.")

        self.source = image

        image = cv2.imread(image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.image = image
    elif isinstance(image, np.ndarray):
        self.image = image  # keep a reference so downstream methods (e.g., show_canvas) can use it
    else:
        raise ValueError("Input image must be either a path or a numpy array.")

    self.predictor.set_image(image, image_format=image_format)
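
Example: a sketch of the two accepted input types; the file name is a placeholder.

sam.set_image("image.tif")  # a local path (URLs are downloaded first)

# Or pass a numpy array directly, declaring the band order.
import cv2
array = cv2.imread("image.tif")  # OpenCV reads in BGR order
sam.set_image(array, image_format="BGR")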

show_anns(figsize=(12, 10), axis='off', alpha=0.35, output=None, blend=True, **kwargs)

Show the annotations (objects with random color) on the input image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| figsize | tuple | The figure size. Defaults to (12, 10). | (12, 10) |
| axis | str | Whether to show the axis. Defaults to "off". | 'off' |
| alpha | float | The alpha value for the annotations. Defaults to 0.35. | 0.35 |
| output | str | The path to the output image. Defaults to None. | None |
| blend | bool | Whether to blend the annotations with the input image when saving. Defaults to True. | True |

Source code in samgeo/hq_sam.py
def show_anns(
    self,
    figsize=(12, 10),
    axis="off",
    alpha=0.35,
    output=None,
    blend=True,
    **kwargs,
):
    """Show the annotations (objects with random color) on the input image.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        axis (str, optional): Whether to show the axis. Defaults to "off".
        alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
        output (str, optional): The path to the output image. Defaults to None.
        blend (bool, optional): Whether to blend the annotations with the input image when saving. Defaults to True.
    """

    import matplotlib.pyplot as plt

    anns = self.masks

    if self.image is None:
        print("Please run generate() first.")
        return

    if anns is None or len(anns) == 0:
        return

    plt.figure(figsize=figsize)
    plt.imshow(self.image)

    sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)

    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones(
        (
            sorted_anns[0]["segmentation"].shape[0],
            sorted_anns[0]["segmentation"].shape[1],
            4,
        )
    )
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann["segmentation"]
        color_mask = np.concatenate([np.random.random(3), [alpha]])
        img[m] = color_mask
    ax.imshow(img)

    if "dpi" not in kwargs:
        kwargs["dpi"] = 100

    if "bbox_inches" not in kwargs:
        kwargs["bbox_inches"] = "tight"

    plt.axis(axis)

    self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)

    if output is not None:
        if blend:
            array = blend_images(
                self.annotations, self.image, alpha=alpha, show=False
            )
        else:
            array = self.annotations
        array_to_image(array, output, self.source)
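
Example: a quick sketch for inspecting automatic-mode results, assuming generate() has been run; the output path is a placeholder.

sam.show_anns(figsize=(10, 8), alpha=0.5, output="annotations.tif", blend=True)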

show_canvas(fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5)

Show a canvas to collect foreground and background points.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fg_color | tuple | The color for the foreground points. Defaults to (0, 255, 0). | (0, 255, 0) |
| bg_color | tuple | The color for the background points. Defaults to (0, 0, 255). | (0, 0, 255) |
| radius | int | The radius of the points. Defaults to 5. | 5 |

Returns:

| Type | Description |
| --- | --- |
| None | The collected foreground and background points are stored in self.point_coords and self.point_labels. |

Source code in samgeo/hq_sam.py
def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
    """Show a canvas to collect foreground and background points.

    Args:
        fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
        bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
        radius (int, optional): The radius of the points. Defaults to 5.

    The collected foreground and background points are stored in self.point_coords
    and self.point_labels for use by the next predict() call.
    """

    if self.image is None:
        raise ValueError("Please run set_image() first.")

    image = self.image
    fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)
    self.fg_points = fg_points
    self.bg_points = bg_points
    point_coords = fg_points + bg_points
    point_labels = [1] * len(fg_points) + [0] * len(bg_points)
    self.point_coords = point_coords
    self.point_labels = point_labels
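
Example: a sketch of the interactive point-collection workflow. Because the collected points are stored on the instance, the next predict() call picks them up automatically.

sam.show_canvas()               # click on the canvas to collect points
sam.predict(output="mask.tif")  # uses self.point_coords / self.point_labels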

show_map(basemap='SATELLITE', repeat_mode=True, out_dir=None, **kwargs)

Show the interactive map.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| basemap | str | The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID. | 'SATELLITE' |
| repeat_mode | bool | Whether to use the repeat mode for draw control. Defaults to True. | True |
| out_dir | str | The path to the output directory. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| leafmap.Map | The map object. |

Source code in samgeo/hq_sam.py
def show_map(self, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
    """Show the interactive map.

    Args:
        basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
        repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.
        out_dir (str, optional): The path to the output directory. Defaults to None.

    Returns:
        leafmap.Map: The map object.
    """
    return sam_map_gui(
        self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs
    )
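
Example: a one-liner for notebook use; the returned leafmap.Map renders when it is the last expression in a cell.

m = sam.show_map(basemap="SATELLITE")
m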

show_masks(figsize=(12, 10), cmap='binary_r', axis='off', foreground=True, **kwargs)

Show the binary mask or the mask of objects with unique values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| figsize | tuple | The figure size. Defaults to (12, 10). | (12, 10) |
| cmap | str | The colormap. Defaults to "binary_r". | 'binary_r' |
| axis | str | Whether to show the axis. Defaults to "off". | 'off' |
| foreground | bool | Whether to show the foreground mask only. Defaults to True. | True |
| **kwargs | | Other arguments for save_masks(). | {} |

Source code in samgeo/hq_sam.py
def show_masks(
    self, figsize=(12, 10), cmap="binary_r", axis="off", foreground=True, **kwargs
):
    """Show the binary mask or the mask of objects with unique values.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        cmap (str, optional): The colormap. Defaults to "binary_r".
        axis (str, optional): Whether to show the axis. Defaults to "off".
        foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.
        **kwargs: Other arguments for save_masks().
    """

    import matplotlib.pyplot as plt

    if self.batch:
        self.objects = cv2.imread(self.masks)
    else:
        if self.objects is None:
            self.save_masks(foreground=foreground, **kwargs)

    plt.figure(figsize=figsize)
    plt.imshow(self.objects, cmap=cmap)
    plt.axis(axis)
    plt.show()
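
Example: a brief sketch; any extra keyword arguments are forwarded to save_masks() when the masks have not been saved yet.

sam.show_masks(cmap="tab20", foreground=True, unique=True)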

tensor_to_numpy(index=None, output=None, mask_multiplier=255, dtype='uint8', save_args={})

Convert the predicted masks from tensors to numpy arrays.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| index | int | The index of the mask to save. Defaults to None, which will save the mask with the highest score. | None |
| output | str | The path to the output image. Defaults to None. | None |
| mask_multiplier | int | The mask multiplier for the output mask, which is usually a binary mask [0, 1]. | 255 |
| dtype | np.dtype | The data type of the output image. Defaults to np.uint8. | 'uint8' |
| save_args | dict | Optional arguments for saving the output image. Defaults to {}. | {} |

Returns:

| Type | Description |
| --- | --- |
| np.ndarray | The predicted mask as a numpy array, returned when output is None. |

Source code in samgeo/hq_sam.py
def tensor_to_numpy(
    self, index=None, output=None, mask_multiplier=255, dtype="uint8", save_args={}
):
    """Convert the predicted masks from tensors to numpy arrays.

    Args:
        index (int, optional): The index of the mask to save. Defaults to None,
            which will save the mask with the highest score.
        output (str, optional): The path to the output image. Defaults to None.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.uint8.
        save_args (dict, optional): Optional arguments for saving the output image. Defaults to {}.

    Returns:
        np.ndarray: The predicted mask as a numpy array.
    """

    boxes = self.boxes
    masks = self.masks

    image_pil = self.image
    image_np = np.array(image_pil)

    if index is None:
        index = 1

    masks = masks[:, index, :, :]
    masks = masks.squeeze(1)

    if boxes is None or (len(boxes) == 0):  # No "object" instances found
        print("No objects found in the image.")
        return
    else:
        # Create an empty image to store the mask overlays
        mask_overlay = np.zeros_like(
            image_np[..., 0], dtype=dtype
        )  # Adjusted for single channel

        for i, (box, mask) in enumerate(zip(boxes, masks)):
            # Convert tensor to numpy array if necessary and ensure it contains integers
            if isinstance(mask, torch.Tensor):
                mask = (
                    mask.cpu().numpy().astype(dtype)
                )  # If mask is on GPU, use .cpu() before .numpy()
            mask_overlay += ((mask > 0) * (i + 1)).astype(
                dtype
            )  # Assign a unique value for each mask

        # Normalize mask_overlay to be in [0, 255]
        mask_overlay = (
            mask_overlay > 0
        ) * mask_multiplier  # Binary mask in [0, 255]

    if output is not None:
        array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)
    else:
        return mask_overlay
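
Example: a sketch for the batched case, where predict() was called with several boxes and stored torch tensors; the file name is a placeholder.

sam.tensor_to_numpy(index=1, output="objects.tif", dtype="uint8")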

tiff_to_geojson(tiff_path, output, simplify_tolerance=None, **kwargs)

Convert a tiff file to a GeoJSON file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tiff_path | str | The path to the tiff file. | required |
| output | str | The path to the GeoJSON file. | required |
| simplify_tolerance | float | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |

Source code in samgeo/hq_sam.py
def tiff_to_geojson(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a GeoJSON file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the GeoJSON file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """

    raster_to_geojson(
        tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
    )
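
Example: a one-line sketch, assuming a segmentation raster "mask.tif" produced earlier in the workflow.

sam.tiff_to_geojson("mask.tif", "mask.geojson", simplify_tolerance=0.5)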

tiff_to_gpkg(tiff_path, output, simplify_tolerance=None, **kwargs)

Convert a tiff file to a gpkg file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tiff_path | str | The path to the tiff file. | required |
| output | str | The path to the gpkg file. | required |
| simplify_tolerance | float | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |

Source code in samgeo/hq_sam.py
def tiff_to_gpkg(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a gpkg file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the gpkg file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """

    raster_to_gpkg(
        tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
    )

tiff_to_shp(tiff_path, output, simplify_tolerance=None, **kwargs)

Convert a tiff file to a shapefile.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tiff_path | str | The path to the tiff file. | required |
| output | str | The path to the shapefile. | required |
| simplify_tolerance | float | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |

Source code in samgeo/hq_sam.py
def tiff_to_shp(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a shapefile.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the shapefile.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """

    raster_to_shp(
        tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
    )

tiff_to_vector(tiff_path, output, simplify_tolerance=None, **kwargs)

Convert a tiff file to a vector file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tiff_path | str | The path to the tiff file. | required |
| output | str | The path to the vector file. | required |
| simplify_tolerance | float | The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry. | None |

Source code in samgeo/hq_sam.py
def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):
    """Convert a tiff file to a gpkg file.

    Args:
        tiff_path (str): The path to the tiff file.
        output (str): The path to the vector file.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
    """

    raster_to_vector(
        tiff_path, output, simplify_tolerance=simplify_tolerance, **kwargs
    )
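
Example: tiff_to_vector(), tiff_to_gpkg(), and tiff_to_shp() share the same signature, so usage is identical apart from the output format; the file names are placeholders.

sam.tiff_to_gpkg("mask.tif", "mask.gpkg")
sam.tiff_to_shp("mask.tif", "mask.shp")
sam.tiff_to_vector("mask.tif", "mask.gpkg", simplify_tolerance=0.5)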