Skip to content

detectree2 module

Tree crown delineation using detectree2.

This module provides a high-level interface for automatic tree crown delineation in aerial RGB imagery using the detectree2 library, which is based on Mask R-CNN (Detectron2 implementation).

Reference

Ball, J.G.C., et al. (2023). Accurate delineation of individual tree crowns in tropical forests from aerial RGB imagery using Mask R-CNN. Remote Sens Ecol Conserv. 9(5):641-655. https://doi.org/10.1002/rse2.332

Repository: https://github.com/PatBall1/detectree2

TreeCrownDelineator

Class for automatic tree crown delineation using detectree2.

This class provides methods for detecting and delineating individual tree crowns in aerial RGB imagery using pre-trained or custom Mask R-CNN models.

Attributes:

Name Type Description
model_path str

Path to the trained model weights.

device str

Device to run inference on ('cuda' or 'cpu').

cfg CfgNode

Detectron2 configuration object.

predictor DefaultPredictor

Detectron2 DefaultPredictor instance.

Example

from samgeo.detectree2 import TreeCrownDelineator
delineator = TreeCrownDelineator()
delineator.predict("orthomosaic.tif", "crowns.gpkg")

Source code in samgeo/detectree2.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
class TreeCrownDelineator:
    """Class for automatic tree crown delineation using detectree2.

    This class provides methods for detecting and delineating individual tree
    crowns in aerial RGB imagery using pre-trained or custom Mask R-CNN models.

    Construction only resolves the model weights (downloading them when no
    ``model_path`` is given). The Detectron2 configuration and predictor are
    built lazily, the first time they are needed.

    Attributes:
        model_path (str): Path to the trained model weights.
        model_name (str): Pre-trained model name given at construction.
        device (str): Device to run inference on ('cuda' or 'cpu').
        confidence_threshold (float): Score threshold written to the
            predictor configuration.
        nms_threshold (float): NMS threshold written to the predictor
            configuration.
        cfg: Detectron2 configuration object (built on first access).
        predictor: Detectron2 DefaultPredictor instance (built on first access).

    Example:
        >>> from samgeo.detectree2 import TreeCrownDelineator
        >>> delineator = TreeCrownDelineator()
        >>> delineator.predict("orthomosaic.tif", "crowns.gpkg")
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        model_name: str = "default",
        device: Optional[str] = None,
        confidence_threshold: float = 0.5,
        nms_threshold: float = 0.3,
    ) -> None:
        """Initialize the TreeCrownDelineator.

        Args:
            model_path: Path to a trained model file (.pth). If None, downloads
                a pre-trained model based on model_name.
            model_name: Name of pre-trained model to use if model_path is None.
                Options: 'paracou', 'sepilok', 'danum', 'default'.
            device: Device for inference ('cuda' or 'cpu'). If None, auto-detects.
            confidence_threshold: Minimum confidence score for predictions (0-1).
            nms_threshold: IoU threshold for non-maximum suppression (0-1).

        Raises:
            ValueError: If model_path is None and model_name is not a known
                pre-trained model.
        """
        _check_detectree2()
        _check_detectron2()

        import torch

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.confidence_threshold = confidence_threshold
        self.nms_threshold = nms_threshold
        self.model_path = model_path
        self.model_name = model_name
        # Both are created together by _setup_predictor() on first use.
        self._predictor = None
        self._cfg = None

        # Download model if not provided
        if self.model_path is None:
            self.model_path = self._download_model(model_name)

        logger.info(f"TreeCrownDelineator initialized with model: {self.model_path}")
        logger.info(f"Using device: {self.device}")

    @property
    def cfg(self) -> Any:
        """Detectron2 configuration object, built on first access."""
        self._setup_predictor()
        return self._cfg

    @property
    def predictor(self) -> Any:
        """Detectron2 DefaultPredictor instance, built on first access."""
        self._setup_predictor()
        return self._predictor

    @staticmethod
    def _resolve_driver(output_format: str) -> str:
        """Map an output format name to the driver name passed to ``to_file``.

        Args:
            output_format: One of 'gpkg', 'shp' or 'geojson' (case-insensitive).

        Returns:
            The matching driver name. Unrecognized formats fall back to
            'GPKG' and a warning is logged so the fallback is not silent.
        """
        driver_map = {
            "gpkg": "GPKG",
            "shp": "ESRI Shapefile",
            "geojson": "GeoJSON",
        }
        driver = driver_map.get(output_format.lower())
        if driver is None:
            logger.warning(
                f"Unknown output_format '{output_format}'; falling back to GPKG"
            )
            driver = "GPKG"
        return driver

    def _download_model(self, model_name: str) -> str:
        """Download a pre-trained model from the model garden.

        The file is stored under ``~/.cache/detectree2`` and reused on later
        calls when it already exists there.

        Args:
            model_name: Name of the model to download.

        Returns:
            Path to the downloaded model file.

        Raises:
            ValueError: If model_name is not a key of PRETRAINED_MODELS.
        """
        from samgeo.common import download_file

        if model_name not in PRETRAINED_MODELS:
            available = ", ".join(PRETRAINED_MODELS.keys())
            raise ValueError(
                f"Unknown model '{model_name}'. Available models: {available}"
            )

        url = PRETRAINED_MODELS[model_name]
        filename = os.path.basename(url)

        # Download to cache directory
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "detectree2")
        os.makedirs(cache_dir, exist_ok=True)
        model_path = os.path.join(cache_dir, filename)

        if not os.path.exists(model_path):
            logger.info(f"Downloading pre-trained model '{model_name}' from {url}")
            try:
                download_file(url, model_path)
            except BaseException:
                # A partial file would be mistaken for a valid cached model on
                # the next run (existence is the only check), so drop it.
                if os.path.exists(model_path):
                    os.remove(model_path)
                raise
            logger.info(f"Model downloaded to {model_path}")
        else:
            logger.info(f"Using cached model: {model_path}")

        return model_path

    def _setup_predictor(self) -> None:
        """Set up the Detectron2 predictor with the model configuration.

        Does nothing when a predictor already exists. Otherwise builds the
        configuration from ``self.model_path`` via detectree2's ``setup_cfg``,
        applies the instance's device and thresholds, and creates the
        ``DefaultPredictor``.
        """
        if self._predictor is not None:
            return

        from detectree2.models.train import setup_cfg
        from detectron2.engine import DefaultPredictor

        self._cfg = setup_cfg(update_model=self.model_path)
        self._cfg.MODEL.DEVICE = self.device
        self._cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = self.confidence_threshold
        self._cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = self.nms_threshold

        self._predictor = DefaultPredictor(self._cfg)
        logger.info("Predictor initialized successfully")

    def predict(
        self,
        image_path: str,
        output_path: str,
        tile_width: int = 40,
        tile_height: int = 40,
        buffer: int = 30,
        simplify_tolerance: float = 0.3,
        min_confidence: float = 0.5,
        iou_threshold: float = 0.6,
        output_format: str = "gpkg",
        cleanup: bool = True,
        **kwargs: Any,
    ) -> "gpd.GeoDataFrame":
        """Detect and delineate tree crowns in an orthomosaic.

        The image is tiled into a temporary directory, each tile is run
        through the predictor, and the per-tile results are projected,
        stitched, cleaned, optionally simplified and written to output_path.

        Args:
            image_path: Path to the input orthomosaic (GeoTIFF).
            output_path: Path for the output crown polygons.
            tile_width: Width of prediction tiles in meters.
            tile_height: Height of prediction tiles in meters.
            buffer: Buffer size around tiles in meters (for edge handling).
            simplify_tolerance: Tolerance for simplifying crown geometries.
                Values <= 0 skip simplification.
            min_confidence: Minimum confidence score to keep predictions;
                applied when cleaning the stitched crowns.
            iou_threshold: IoU threshold for removing overlapping crowns.
            output_format: Output format ('gpkg', 'shp', 'geojson').
                Unrecognized values fall back to GeoPackage with a warning.
            cleanup: Whether to remove temporary files after prediction.
                When False, the temporary directory path is logged.
            **kwargs: Additional arguments passed to tile_data.

        Returns:
            GeoDataFrame containing the detected tree crown polygons.
        """
        from detectree2.models.outputs import (
            clean_crowns,
            project_to_geojson,
            stitch_crowns,
        )
        from detectree2.models.predict import predict_on_data
        from detectree2.preprocessing.tiling import tile_data

        # Resolve the driver first so a bad format is reported before the
        # expensive tiling and inference steps run.
        driver = self._resolve_driver(output_format)

        # Initialize predictor
        self._setup_predictor()

        # Create temporary directory for tiles
        temp_dir = tempfile.mkdtemp(prefix="detectree2_")
        tiles_dir = os.path.join(temp_dir, "tiles")
        pred_dir = os.path.join(temp_dir, "predictions")
        geo_dir = os.path.join(temp_dir, "predictions_geo")

        try:
            logger.info(f"Tiling orthomosaic: {image_path}")
            tile_data(
                image_path,
                tiles_dir,
                buffer=buffer,
                tile_width=tile_width,
                tile_height=tile_height,
                **kwargs,
            )

            logger.info("Running predictions on tiles...")
            predict_on_data(tiles_dir, pred_dir, predictor=self._predictor)

            logger.info("Projecting predictions to geographic coordinates...")
            os.makedirs(geo_dir, exist_ok=True)
            project_to_geojson(tiles_dir, pred_dir, geo_dir)

            logger.info("Stitching and cleaning crown predictions...")
            crowns = stitch_crowns(geo_dir)
            crowns = clean_crowns(crowns, iou_threshold, confidence=min_confidence)

            # Simplify geometries
            if simplify_tolerance > 0:
                crowns = crowns.set_geometry(crowns.simplify(simplify_tolerance))

            crowns.to_file(output_path, driver=driver)
            logger.info(f"Crown polygons saved to: {output_path}")

            return crowns

        finally:
            if not cleanup:
                # Otherwise the caller has no way to locate the kept files.
                logger.info(f"Temporary files kept in: {temp_dir}")
            elif os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
                logger.debug(f"Cleaned up temporary directory: {temp_dir}")

    def predict_tiles(
        self,
        tiles_dir: str,
        output_dir: Optional[str] = None,
    ) -> List[str]:
        """Run predictions on pre-tiled images.

        Args:
            tiles_dir: Directory containing tiled images.
            output_dir: Parent directory for the results; predictions are
                written to its ``predictions`` subfolder. If None, tiles_dir
                is used as the parent.

        Returns:
            Sorted list of paths to the ``*.json`` prediction files, or an
            empty list if the predictions folder does not exist afterwards.
        """
        from detectree2.models.predict import predict_on_data

        self._setup_predictor()

        if output_dir is None:
            output_dir = tiles_dir

        pred_dir = os.path.join(output_dir, "predictions")
        logger.info(f"Running predictions on tiles in: {tiles_dir}")
        predict_on_data(tiles_dir, pred_dir, predictor=self._predictor)

        # Find prediction files
        if os.path.exists(pred_dir):
            # glob order is filesystem-dependent; sort for reproducible output.
            pred_files = sorted(Path(pred_dir).glob("*.json"))
            logger.info(f"Generated {len(pred_files)} prediction files")
            return [str(f) for f in pred_files]

        return []

__init__(model_path=None, model_name='default', device=None, confidence_threshold=0.5, nms_threshold=0.3)

Initialize the TreeCrownDelineator.

Parameters:

Name Type Description Default
model_path Optional[str]

Path to a trained model file (.pth). If None, downloads a pre-trained model based on model_name.

None
model_name str

Name of pre-trained model to use if model_path is None. Options: 'paracou', 'sepilok', 'danum', 'default'.

'default'
device Optional[str]

Device for inference ('cuda' or 'cpu'). If None, auto-detects.

None
confidence_threshold float

Minimum confidence score for predictions (0-1).

0.5
nms_threshold float

IoU threshold for non-maximum suppression (0-1).

0.3
Source code in samgeo/detectree2.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def __init__(
    self,
    model_path: Optional[str] = None,
    model_name: str = "default",
    device: Optional[str] = None,
    confidence_threshold: float = 0.5,
    nms_threshold: float = 0.3,
) -> None:
    """Create a delineator and resolve which model weights it will use.

    Args:
        model_path: Path to a trained model file (.pth). When None, a
            pre-trained model selected by model_name is downloaded instead.
        model_name: Pre-trained model to fetch when model_path is None.
            Options: 'paracou', 'sepilok', 'danum', 'default'.
        device: Inference device ('cuda' or 'cpu'). When not given, 'cuda'
            is chosen if torch reports it as available, otherwise 'cpu'.
        confidence_threshold: Minimum confidence score for predictions (0-1).
        nms_threshold: IoU threshold for non-maximum suppression (0-1).
    """
    _check_detectree2()
    _check_detectron2()

    import torch

    # An explicit (truthy) device wins; otherwise probe for CUDA.
    if device:
        chosen_device = device
    elif torch.cuda.is_available():
        chosen_device = "cuda"
    else:
        chosen_device = "cpu"

    self.device = chosen_device
    self.confidence_threshold = confidence_threshold
    self.nms_threshold = nms_threshold
    self.model_path = model_path
    self.model_name = model_name

    # No predictor or configuration has been built at this point.
    self._predictor = None
    self._cfg = None

    # Fall back to a downloaded pre-trained model when no path was supplied.
    if self.model_path is None:
        self.model_path = self._download_model(model_name)

    logger.info(f"TreeCrownDelineator initialized with model: {self.model_path}")
    logger.info(f"Using device: {self.device}")

predict(image_path, output_path, tile_width=40, tile_height=40, buffer=30, simplify_tolerance=0.3, min_confidence=0.5, iou_threshold=0.6, output_format='gpkg', cleanup=True, **kwargs)

Detect and delineate tree crowns in an orthomosaic.

Parameters:

Name Type Description Default
image_path str

Path to the input orthomosaic (GeoTIFF).

required
output_path str

Path for the output crown polygons.

required
tile_width int

Width of prediction tiles in meters.

40
tile_height int

Height of prediction tiles in meters.

40
buffer int

Buffer size around tiles in meters (for edge handling).

30
simplify_tolerance float

Tolerance for simplifying crown geometries.

0.3
min_confidence float

Minimum confidence score to keep predictions.

0.5
iou_threshold float

IoU threshold for removing overlapping crowns.

0.6
output_format str

Output format ('gpkg', 'shp', 'geojson').

'gpkg'
cleanup bool

Whether to remove temporary files after prediction.

True
**kwargs Any

Additional arguments passed to tile_data.

{}

Returns:

Type Description
GeoDataFrame

GeoDataFrame containing the detected tree crown polygons.

Source code in samgeo/detectree2.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def predict(
    self,
    image_path: str,
    output_path: str,
    tile_width: int = 40,
    tile_height: int = 40,
    buffer: int = 30,
    simplify_tolerance: float = 0.3,
    min_confidence: float = 0.5,
    iou_threshold: float = 0.6,
    output_format: str = "gpkg",
    cleanup: bool = True,
    **kwargs: Any,
) -> "gpd.GeoDataFrame":
    """Detect and delineate tree crowns in an orthomosaic.

    The image is tiled into a temporary directory, each tile is run through
    the predictor, and the per-tile results are projected, stitched,
    cleaned, optionally simplified and written to output_path.

    Args:
        image_path: Path to the input orthomosaic (GeoTIFF).
        output_path: Path for the output crown polygons.
        tile_width: Width of prediction tiles in meters.
        tile_height: Height of prediction tiles in meters.
        buffer: Buffer size around tiles in meters (for edge handling).
        simplify_tolerance: Tolerance for simplifying crown geometries.
            Values <= 0 skip simplification.
        min_confidence: Minimum confidence score to keep predictions;
            applied when cleaning the stitched crowns.
        iou_threshold: IoU threshold for removing overlapping crowns.
        output_format: Output format ('gpkg', 'shp', 'geojson').
            Unrecognized values fall back to GeoPackage with a warning.
        cleanup: Whether to remove temporary files after prediction.
            When False, the temporary directory path is logged.
        **kwargs: Additional arguments passed to tile_data.

    Returns:
        GeoDataFrame containing the detected tree crown polygons.
    """
    from detectree2.models.outputs import (
        clean_crowns,
        project_to_geojson,
        stitch_crowns,
    )
    from detectree2.models.predict import predict_on_data
    from detectree2.preprocessing.tiling import tile_data

    # Resolve the output driver first so a bad format is reported before the
    # expensive tiling and inference steps run.
    driver_map = {
        "gpkg": "GPKG",
        "shp": "ESRI Shapefile",
        "geojson": "GeoJSON",
    }
    driver = driver_map.get(output_format.lower())
    if driver is None:
        logger.warning(
            f"Unknown output_format '{output_format}'; falling back to GPKG"
        )
        driver = "GPKG"

    # Initialize predictor
    self._setup_predictor()

    # Create temporary directory for tiles
    temp_dir = tempfile.mkdtemp(prefix="detectree2_")
    tiles_dir = os.path.join(temp_dir, "tiles")
    pred_dir = os.path.join(temp_dir, "predictions")
    geo_dir = os.path.join(temp_dir, "predictions_geo")

    try:
        logger.info(f"Tiling orthomosaic: {image_path}")
        tile_data(
            image_path,
            tiles_dir,
            buffer=buffer,
            tile_width=tile_width,
            tile_height=tile_height,
            **kwargs,
        )

        logger.info("Running predictions on tiles...")
        predict_on_data(tiles_dir, pred_dir, predictor=self._predictor)

        logger.info("Projecting predictions to geographic coordinates...")
        os.makedirs(geo_dir, exist_ok=True)
        project_to_geojson(tiles_dir, pred_dir, geo_dir)

        logger.info("Stitching and cleaning crown predictions...")
        crowns = stitch_crowns(geo_dir)
        crowns = clean_crowns(crowns, iou_threshold, confidence=min_confidence)

        # Simplify geometries
        if simplify_tolerance > 0:
            crowns = crowns.set_geometry(crowns.simplify(simplify_tolerance))

        crowns.to_file(output_path, driver=driver)
        logger.info(f"Crown polygons saved to: {output_path}")

        return crowns

    finally:
        if not cleanup:
            # Otherwise the caller has no way to locate the kept files.
            logger.info(f"Temporary files kept in: {temp_dir}")
        elif os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            logger.debug(f"Cleaned up temporary directory: {temp_dir}")

predict_tiles(tiles_dir, output_dir=None)

Run predictions on pre-tiled images.

Parameters:

Name Type Description Default
tiles_dir str

Directory containing tiled images.

required
output_dir Optional[str]

Parent directory for the results; predictions are written to its "predictions" subfolder. If None, tiles_dir is used as the parent.

None

Returns:

Type Description
List[str]

List of paths to prediction files.

Source code in samgeo/detectree2.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def predict_tiles(
    self,
    tiles_dir: str,
    output_dir: Optional[str] = None,
) -> List[str]:
    """Run predictions on pre-tiled images.

    Args:
        tiles_dir: Directory containing tiled images.
        output_dir: Parent directory for the results; predictions are
            written to its ``predictions`` subfolder. If None, tiles_dir is
            used as the parent.

    Returns:
        Sorted list of paths to the ``*.json`` prediction files, or an empty
        list if the predictions folder does not exist afterwards.
    """
    from detectree2.models.predict import predict_on_data

    self._setup_predictor()

    if output_dir is None:
        output_dir = tiles_dir

    pred_dir = os.path.join(output_dir, "predictions")
    logger.info(f"Running predictions on tiles in: {tiles_dir}")
    predict_on_data(tiles_dir, pred_dir, predictor=self._predictor)

    # Find prediction files
    if os.path.exists(pred_dir):
        # glob order is filesystem-dependent; sort for reproducible output.
        pred_files = sorted(Path(pred_dir).glob("*.json"))
        logger.info(f"Generated {len(pred_files)} prediction files")
        return [str(f) for f in pred_files]

    return []

download_sample_data(output_dir='./detectree2_sample')

Download sample data for testing detectree2.

Parameters:

Name Type Description Default
output_dir str

Directory to save the sample data.

'./detectree2_sample'

Returns:

Type Description
str

Path to the output directory.

Source code in samgeo/detectree2.py
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
def download_sample_data(output_dir: str = "./detectree2_sample") -> str:
    """Download sample data for testing detectree2.

    The sample archive is downloaded into output_dir as ``sample.zip``,
    extracted in place, and the archive is then deleted.

    Args:
        output_dir: Directory to save the sample data. Created if missing.

    Returns:
        Path to the output directory.
    """
    import zipfile

    from samgeo.common import download_file

    sample_url = "https://zenodo.org/records/8136161/files/Paracou_sample.zip"

    os.makedirs(output_dir, exist_ok=True)
    zip_path = os.path.join(output_dir, "sample.zip")

    logger.info(f"Downloading sample data from {sample_url}")
    download_file(sample_url, zip_path)

    # Extract
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)
    finally:
        # Remove the archive even when extraction fails, so a corrupt or
        # truncated download is not left behind in output_dir.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    logger.info(f"Sample data extracted to: {output_dir}")

    return output_dir

list_pretrained_models()

List available pre-trained models.

Returns:

Type Description
Dict[str, str]

Dictionary mapping model names to their download URLs.

Source code in samgeo/detectree2.py
465
466
467
468
469
470
471
def list_pretrained_models() -> Dict[str, str]:
    """Return the registry of available pre-trained models.

    Returns:
        A copy of PRETRAINED_MODELS (model name -> download URL), so callers
        can modify the result without touching the module-level registry.
    """
    models_snapshot = PRETRAINED_MODELS.copy()
    return models_snapshot

prepare_training_data(image_path, crowns_path, output_dir, tile_width=40, tile_height=40, buffer=30, threshold=0.6, test_fraction=0.15, mode='rgb')

Prepare training and test data for detectree2.

Parameters:

Name Type Description Default
image_path str

Path to the input orthomosaic (GeoTIFF).

required
crowns_path str

Path to manually delineated crown polygons.

required
output_dir str

Directory to save the training data.

required
tile_width int

Width of tiles in meters.

40
tile_height int

Height of tiles in meters.

40
buffer int

Buffer size around tiles in meters.

30
threshold float

Minimum crown coverage to keep a tile.

0.6
test_fraction float

Fraction of data to use for testing (0-1).

0.15
mode str

Image mode ('rgb' or 'ms' for multispectral).

'rgb'

Returns:

Type Description
Tuple[str, str]

Tuple of (train_dir, test_dir) paths.

Source code in samgeo/detectree2.py
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
def prepare_training_data(
    image_path: str,
    crowns_path: str,
    output_dir: str,
    tile_width: int = 40,
    tile_height: int = 40,
    buffer: int = 30,
    threshold: float = 0.6,
    test_fraction: float = 0.15,
    mode: str = "rgb",
) -> Tuple[str, str]:
    """Prepare training and test data for detectree2.

    Tiles the orthomosaic together with the crown polygons via
    tile_orthomosaic, then splits the tiles in output_dir into train and
    test folders.

    Args:
        image_path: Path to the input orthomosaic (GeoTIFF).
        crowns_path: Path to manually delineated crown polygons.
        output_dir: Directory to save the training data.
        tile_width: Width of tiles in meters.
        tile_height: Height of tiles in meters.
        buffer: Buffer size around tiles in meters.
        threshold: Minimum crown coverage to keep a tile.
        test_fraction: Fraction of data to use for testing (0-1).
        mode: Image mode ('rgb' or 'ms' for multispectral).

    Returns:
        Tuple of (train_dir, test_dir) paths.

    Raises:
        ValueError: If test_fraction is not within [0, 1].
    """
    # Validate first so a bad fraction fails before the tiling pass runs.
    # The chained comparison also rejects NaN.
    if not 0 <= test_fraction <= 1:
        raise ValueError(
            f"test_fraction must be between 0 and 1, got {test_fraction}"
        )

    _check_detectree2()

    from detectree2.preprocessing.tiling import to_traintest_folders

    # First tile the data
    tile_orthomosaic(
        image_path,
        output_dir,
        tile_width=tile_width,
        tile_height=tile_height,
        buffer=buffer,
        crowns_path=crowns_path,
        threshold=threshold,
        mode=mode,
    )

    # Split into train/test
    logger.info(f"Splitting data into train/test (test fraction: {test_fraction})")
    to_traintest_folders(output_dir, output_dir, test_frac=test_fraction)

    train_dir = os.path.join(output_dir, "train")
    test_dir = os.path.join(output_dir, "test")

    logger.info(f"Training data: {train_dir}")
    logger.info(f"Test data: {test_dir}")

    return train_dir, test_dir

stitch_predictions(geo_predictions_dir, output_path, iou_threshold=0.6, min_confidence=0.5, simplify_tolerance=0.3, output_format='gpkg')

Stitch and clean tile predictions into a single crown map.

Parameters:

Name Type Description Default
geo_predictions_dir str

Directory containing geo-referenced predictions.

required
output_path str

Path for the output crown polygons.

required
iou_threshold float

IoU threshold for removing overlapping crowns.

0.6
min_confidence float

Minimum confidence score to keep predictions.

0.5
simplify_tolerance float

Tolerance for simplifying crown geometries.

0.3
output_format str

Output format ('gpkg', 'shp', 'geojson').

'gpkg'

Returns:

Type Description
GeoDataFrame

GeoDataFrame containing the stitched and cleaned crown polygons.

Source code in samgeo/detectree2.py
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
def stitch_predictions(
    geo_predictions_dir: str,
    output_path: str,
    iou_threshold: float = 0.6,
    min_confidence: float = 0.5,
    simplify_tolerance: float = 0.3,
    output_format: str = "gpkg",
) -> "gpd.GeoDataFrame":
    """Stitch and clean tile predictions into a single crown map.

    Args:
        geo_predictions_dir: Directory containing geo-referenced predictions.
        output_path: Path for the output crown polygons.
        iou_threshold: IoU threshold for removing overlapping crowns.
        min_confidence: Minimum confidence score to keep predictions.
        simplify_tolerance: Tolerance for simplifying crown geometries.
            Values <= 0 skip simplification.
        output_format: Output format ('gpkg', 'shp', 'geojson').
            Unrecognized values fall back to GeoPackage with a warning.

    Returns:
        GeoDataFrame containing the stitched and cleaned crown polygons.
    """
    _check_detectree2()

    from detectree2.models.outputs import clean_crowns, stitch_crowns

    # Resolve the output driver up front so an unknown format is reported
    # before any stitching work is done.
    driver_map = {
        "gpkg": "GPKG",
        "shp": "ESRI Shapefile",
        "geojson": "GeoJSON",
    }
    driver = driver_map.get(output_format.lower())
    if driver is None:
        logger.warning(
            f"Unknown output_format '{output_format}'; falling back to GPKG"
        )
        driver = "GPKG"

    logger.info(f"Stitching predictions from: {geo_predictions_dir}")
    crowns = stitch_crowns(geo_predictions_dir)

    logger.info("Cleaning overlapping crowns...")
    crowns = clean_crowns(crowns, iou_threshold, confidence=min_confidence)

    if simplify_tolerance > 0:
        crowns = crowns.set_geometry(crowns.simplify(simplify_tolerance))

    crowns.to_file(output_path, driver=driver)
    logger.info(f"Crown polygons saved to: {output_path}")

    return crowns

tile_orthomosaic(image_path, output_dir, tile_width=40, tile_height=40, buffer=30, crowns_path=None, threshold=0.6, mode='rgb', **kwargs)

Tile an orthomosaic for training or prediction.

Parameters:

Name Type Description Default
image_path str

Path to the input orthomosaic (GeoTIFF).

required
output_dir str

Directory to save the tiles.

required
tile_width int

Width of tiles in meters.

40
tile_height int

Height of tiles in meters.

40
buffer int

Buffer size around tiles in meters.

30
crowns_path Optional[str]

Path to crown polygons (for training data preparation).

None
threshold float

Minimum crown coverage to keep a tile (when crowns provided).

0.6
mode str

Image mode ('rgb' or 'ms' for multispectral).

'rgb'
**kwargs Any

Additional arguments passed to tile_data.

{}

Returns:

Type Description
str

Path to the output directory containing tiles.

Source code in samgeo/detectree2.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
def tile_orthomosaic(
    image_path: str,
    output_dir: str,
    tile_width: int = 40,
    tile_height: int = 40,
    buffer: int = 30,
    crowns_path: Optional[str] = None,
    threshold: float = 0.6,
    mode: str = "rgb",
    **kwargs: Any,
) -> str:
    """Split an orthomosaic into tiles for training or prediction.

    When crowns_path is given, the polygons are read and reprojected to the
    CRS of the image before being handed to detectree2's ``tile_data``.

    Args:
        image_path: Path to the input orthomosaic (GeoTIFF).
        output_dir: Directory to save the tiles.
        tile_width: Width of tiles in meters.
        tile_height: Height of tiles in meters.
        buffer: Buffer size around tiles in meters.
        crowns_path: Path to crown polygons (for training data preparation).
        threshold: Minimum crown coverage to keep a tile (when crowns provided).
        mode: Image mode ('rgb' or 'ms' for multispectral).
        **kwargs: Additional arguments passed to tile_data.

    Returns:
        Path to the output directory containing tiles (output_dir itself).
    """
    _check_detectree2()

    import geopandas as gpd
    import rasterio

    from detectree2.preprocessing.tiling import tile_data

    if crowns_path is None:
        crown_polygons = None
    else:
        # The raster is opened only to read its CRS.
        with rasterio.open(image_path) as raster:
            raster_crs = raster.crs
        crown_polygons = gpd.read_file(crowns_path).to_crs(raster_crs)

    logger.info(f"Tiling orthomosaic: {image_path}")
    tile_data(
        image_path,
        output_dir,
        buffer=buffer,
        tile_width=tile_width,
        tile_height=tile_height,
        crowns=crown_polygons,
        threshold=threshold,
        mode=mode,
        **kwargs,
    )
    logger.info(f"Tiles saved to: {output_dir}")

    return output_dir