REST API

segment-geospatial includes a built-in REST API powered by FastAPI that allows you to run image segmentation over HTTP. This is useful for integrating segmentation into web applications, pipelines, and non-Python clients.

Installation

Install the API dependencies with the api extra:

pip install "segment-geospatial[api]"

To also install a specific SAM model backend, combine extras:

pip install "segment-geospatial[api,samgeo3]"

Starting the Server

Use the samgeo-api command:

samgeo-api

Options:

samgeo-api --host 0.0.0.0 --port 8000        # Custom host/port
samgeo-api --preload sam2:sam2-hiera-large     # Preload a model at startup
samgeo-api --reload                            # Auto-reload for development

Alternatively, use uvicorn directly:

uvicorn samgeo.api:app --host 0.0.0.0 --port 8000

Once running, interactive API docs (Swagger UI) are available at http://localhost:8000/docs.

Endpoints

Health Check

GET /health

Returns the server status and version.

curl http://localhost:8000/health

Example response:

{"status": "ok", "version": "1.2.3"}
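
When a script starts the server itself, it can help to wait until the health endpoint answers before sending segmentation requests. The following is a minimal sketch using only the standard library. The helper names `fetch_health` and `wait_until_healthy` are hypothetical (they are not part of segment-geospatial), and the only thing assumed about the response is the `"status": "ok"` field shown above.

```python
import json
import time
import urllib.request


def fetch_health(base_url: str) -> dict:
    """GET <base_url>/health and decode the JSON body."""
    with urllib.request.urlopen(f"{base_url}/health", timeout=5) as resp:
        return json.loads(resp.read().decode("utf-8"))


def wait_until_healthy(base_url, timeout=60.0, interval=1.0, fetch=fetch_health):
    """Poll /health until it reports status "ok" or the timeout expires.

    Returns True once the server is healthy, False if the deadline passes.
    `fetch` is injectable so the polling logic can be exercised without a server.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            if fetch(base_url).get("status") == "ok":
                return True
        except (OSError, ValueError):
            # OSError covers refused connections and timeouts (URLError is a
            # subclass); ValueError covers a body that is not valid JSON.
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

A typical call would be `wait_until_healthy("http://localhost:8000")` right after launching `samgeo-api` in a subprocess.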

List Models

GET /models

Returns available model versions/IDs and which models are currently loaded in memory.

curl http://localhost:8000/models
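
According to the `list_models` source in the API reference, the response is a JSON object with a `models` entry and a `loaded` entry, where `loaded` is a list of `[model_version, model_id]` pairs. A client could check whether a model is already in memory with something like the sketch below. The helper names are hypothetical, and the exact shape of the `models` value in the sample payload is an assumption.

```python
def loaded_model_ids(models_response: dict) -> set:
    """Return the (model_version, model_id) pairs reported as loaded."""
    return {tuple(pair) for pair in models_response.get("loaded", [])}


def is_loaded(models_response: dict, model_version: str, model_id: str) -> bool:
    """True if the given model appears in the server's 'loaded' list."""
    return (model_version, model_id) in loaded_model_ids(models_response)


# Sample payload; the layout of "models" is assumed for illustration only.
sample = {
    "models": {"sam2": ["sam2-hiera-large"]},
    "loaded": [["sam2", "sam2-hiera-large"]],
}
print(is_loaded(sample, "sam2", "sam2-hiera-large"))
```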

Clear Models

DELETE /models

Clears the model cache and frees GPU memory.

curl -X DELETE http://localhost:8000/models

Automatic Segmentation

POST /segment/automatic

Runs automatic mask generation on an uploaded image. Supports SAM, SAM2, and SAM3.

Parameters (multipart form):

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| file | file | required | Image file (TIFF, PNG, JPEG) |
| model_version | string | sam2 | One of sam, sam2, sam3 |
| model_id | string | auto | Model identifier (e.g., sam2-hiera-large) |
| output_format | string | geojson | One of geojson, geotiff, png |
| foreground | bool | true | Extract foreground objects only |
| unique | bool | true | Assign unique ID to each object |
| min_size | int | 0 | Minimum mask size in pixels |
| max_size | int | none | Maximum mask size in pixels |
| points_per_side | int | 32 | Points sampled per side (SAM/SAM2) |
| pred_iou_thresh | float | 0.8 | IoU threshold for filtering |
| stability_score_thresh | float | 0.95 | Stability score threshold |

Example:

curl -X POST http://localhost:8000/segment/automatic \
  -F "file=@image.tif" \
  -F "model_version=sam2" \
  -F "output_format=geojson"
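
The same request can be assembled in Python. The sketch below builds the form fields from the parameter table above and validates them on the client side before uploading. `automatic_form` and `post_automatic` are hypothetical helper names, the client-side validation is an added convenience rather than something the API requires, and sending booleans as lowercase `true`/`false` strings assumes the usual FastAPI form parsing.

```python
ALLOWED_VERSIONS = {"sam", "sam2", "sam3"}
ALLOWED_FORMATS = {"geojson", "geotiff", "png"}
OPTIONAL_FIELDS = {
    "model_id", "foreground", "unique", "min_size", "max_size",
    "points_per_side", "pred_iou_thresh", "stability_score_thresh",
}


def automatic_form(model_version="sam2", output_format="geojson", **options):
    """Build the multipart form fields for POST /segment/automatic."""
    if model_version not in ALLOWED_VERSIONS:
        raise ValueError(f"unknown model_version: {model_version!r}")
    if output_format not in ALLOWED_FORMATS:
        raise ValueError(f"unknown output_format: {output_format!r}")
    unknown = set(options) - OPTIONAL_FIELDS
    if unknown:
        raise ValueError(f"unknown form fields: {sorted(unknown)}")

    form = {"model_version": model_version, "output_format": output_format}
    for name, value in options.items():
        if value is None:
            continue  # omit the field so the server default applies
        # Form values travel as strings; booleans become "true"/"false".
        form[name] = str(value).lower() if isinstance(value, bool) else str(value)
    return form


def post_automatic(image_path, base_url="http://localhost:8000", **form_kwargs):
    """Upload an image to /segment/automatic and return the HTTP response."""
    import requests  # third-party; imported here so automatic_form has no dependencies

    with open(image_path, "rb") as f:
        response = requests.post(
            f"{base_url}/segment/automatic",
            files={"file": f},
            data=automatic_form(**form_kwargs),
            timeout=600,
        )
    response.raise_for_status()
    return response
```

For example, `post_automatic("image.tif", min_size=100, foreground=False)` would send the default model and output format plus the two overridden fields.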

Prompt-based Segmentation

POST /segment/predict

Runs segmentation with point or bounding box prompts. Supports SAM and SAM2.

Parameters (multipart form):

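| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| file | file | required | Image file (TIFF, PNG, JPEG) |
| model_version | string | sam2 | One of sam, sam2 |
| model_id | string | auto | Model identifier |
| output_format | string | geojson | One of geojson, geotiff, png |
| point_coords | string | none | JSON array of [[x, y], ...] |
| point_labels | string | none | JSON array of [1, 0, ...] (1=foreground, 0=background) |
| boxes | string | none | JSON array of [[xmin, ymin, xmax, ymax], ...] |
| point_crs | string | none | CRS string (e.g., EPSG:4326) for point coordinates |
| multimask_output | bool | false | Return multiple masks per prompt |
| min_size | int | 0 | Minimum mask size in pixels |
| max_size | int | none | Maximum mask size in pixels |

At least one of point_coords or boxes must be provided; otherwise the endpoint returns a 400 error.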

Example with point prompts:

curl -X POST http://localhost:8000/segment/predict \
  -F "file=@image.tif" \
  -F "point_coords=[[100, 200]]" \
  -F "point_labels=[1]" \
  -F "output_format=geojson"

Example with box prompts:

curl -X POST http://localhost:8000/segment/predict \
  -F "file=@image.tif" \
  -F "boxes=[[10, 20, 300, 400]]" \
  -F "output_format=geotiff"
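
Because the prompt fields are JSON strings inside a multipart form, it is easy to get the quoting wrong when building them by hand. The sketch below encodes Python lists into the string fields this endpoint expects. `prompt_fields` is a hypothetical helper name, and the check that labels match points one-to-one is an assumption about sensible input rather than a documented server rule.

```python
import json


def prompt_fields(points=None, labels=None, boxes=None, point_crs=None):
    """Encode point and box prompts as form fields for POST /segment/predict."""
    # The endpoint rejects requests that carry neither points nor boxes.
    if points is None and boxes is None:
        raise ValueError("provide at least one of points or boxes")
    # Assumed client-side sanity check: one label per point.
    if labels is not None:
        if points is None or len(labels) != len(points):
            raise ValueError("labels must match points one-to-one")

    fields = {}
    if points is not None:
        fields["point_coords"] = json.dumps([list(p) for p in points])
    if labels is not None:
        fields["point_labels"] = json.dumps(list(labels))
    if boxes is not None:
        fields["boxes"] = json.dumps([list(b) for b in boxes])
    if point_crs is not None:
        fields["point_crs"] = point_crs
    return fields


# Mirrors the first curl example above.
print(prompt_fields(points=[(100, 200)], labels=[1]))
```

The returned dictionary can be passed as the `data` argument of `requests.post`, alongside the `files` argument carrying the image.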

Text-prompt Segmentation

POST /segment/text

Runs text-prompt segmentation using SAM3.

Parameters (multipart form):

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| file | file | required | Image file (TIFF, PNG, JPEG) |
| prompt | string | required | Text description (e.g., building, tree) |
| model_id | string | auto | SAM3 model identifier |
| backend | string | meta | One of meta, transformers |
| output_format | string | geojson | One of geojson, geotiff, png |
| confidence_threshold | float | 0.5 | Detection confidence threshold |
| min_size | int | 0 | Minimum mask size in pixels |
| max_size | int | none | Maximum mask size in pixels |

Example:

curl -X POST http://localhost:8000/segment/text \
  -F "file=@image.tif" \
  -F "prompt=building" \
  -F "output_format=geojson"
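
All three segmentation endpoints accept the same `output_format` values, so a client usually needs to write the response body to a file with a matching extension. The following sketch shows one way to do that. `save_result` and the suffix mapping are hypothetical, and it assumes the response body for each format can be written to disk unchanged (for example `response.content` from the `requests` library).

```python
from pathlib import Path

# Assumed mapping from output_format to a conventional file extension.
SUFFIXES = {"geojson": ".geojson", "geotiff": ".tif", "png": ".png"}


def save_result(body: bytes, output_format: str, stem: str = "mask", directory=".") -> Path:
    """Write a response body to <directory>/<stem><suffix> and return the path."""
    try:
        suffix = SUFFIXES[output_format]
    except KeyError:
        raise ValueError(f"unknown output_format: {output_format!r}") from None
    path = Path(directory) / f"{stem}{suffix}"
    path.write_bytes(body)
    return path
```

Given a successful response, `save_result(response.content, "geotiff", stem="buildings")` would write `buildings.tif` in the current directory.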

Caching

The API automatically caches models and image encodings for better performance:

  • Model cache: Models are loaded once and reused across requests. Use DELETE /models to free GPU memory.
  • Image cache: When the same image is sent multiple times (e.g., with different prompts), the expensive image encoding step is skipped. This makes subsequent requests significantly faster.

Example timing with a 13 MB GeoTIFF:

| Request | Description | Time |
| --- | --- | --- |
| 1st | Model load + image encoding | ~7s |
| 2nd | Same image, different prompt | ~0.4s |
| 3rd | Same image, another prompt | ~0.2s |
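
The idea behind the image cache can be illustrated with a small standalone sketch: hash the uploaded bytes, and skip the encoding step when the hash matches the image the model encoded last. This is not the library's implementation (the real helper, `_set_image_cached`, is not shown on this page); the class name, the use of SHA-256, and the `encode` callback are all assumptions made for illustration.

```python
import hashlib


class EncodingCache:
    """Remember which image each model encoded last, keyed by content hash."""

    def __init__(self):
        # model_key -> hex digest of the image that model currently holds
        self._last_hash = {}

    def set_image(self, model_key, image_bytes, encode):
        """Run encode(image_bytes) unless this model already holds the image.

        Returns True if encoding ran, False on a cache hit.
        """
        digest = hashlib.sha256(image_bytes).hexdigest()
        if self._last_hash.get(model_key) == digest:
            return False  # same image as last time: skip the expensive step
        encode(image_bytes)
        self._last_hash[model_key] = digest
        return True

    def clear(self):
        """Forget all hashes, e.g. when the model cache is cleared."""
        self._last_hash.clear()
```

In this sketch a second request with the same bytes returns immediately, which corresponds to the drop from seconds to fractions of a second in the timing table above.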

Python Client Example

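import requests

url = "http://localhost:8000/segment/text"

# Upload the image as multipart form data together with the text prompt.
with open("image.tif", "rb") as f:
    response = requests.post(
        url,
        files={"file": ("image.tif", f, "image/tiff")},
        data={"prompt": "building", "output_format": "geojson"},
        timeout=600,  # the first request also loads the model, so allow extra time
    )

# Fail loudly on 4xx/5xx instead of trying to parse an error body as GeoJSON.
response.raise_for_status()

geojson = response.json()
print(f"Found {len(geojson['features'])} features")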

API Reference

REST API for segment-geospatial.

Provides FastAPI endpoints for image segmentation using SAM, SAM2, and SAM3 models. Install with: pip install segment-geospatial[api]

Usage

samgeo-api                                  # Start on default port 8000
samgeo-api --port 9000                      # Custom port
samgeo-api --preload sam2:sam2-hiera-large  # Preload a model
uvicorn samgeo.api:app                      # Direct uvicorn usage

clear_models()

Clear the model cache and free GPU memory.

Source code in samgeo/api.py
@app.delete("/models")
def clear_models():
    """Clear the model cache and free GPU memory."""
    _model_cache.clear()
    _image_hash_cache.clear()
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass
    return {"status": "cleared"}

get_model(model_version, model_id=None, **kwargs)

Get or create a cached model instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_version | str | One of "sam", "sam2", "sam3". | required |
| model_id | Optional[str] | Specific model identifier. Uses default if None. | None |
| **kwargs | | Additional keyword arguments for model initialization. | {} |

Returns:

| Type | Description |
| --- | --- |
| tuple | (model_instance, threading.Lock) |

Raises:

| Type | Description |
| --- | --- |
| HTTPException | If model_version or model_id is invalid, or dependencies are missing. |

Source code in samgeo/api.py
def get_model(model_version: str, model_id: Optional[str] = None, **kwargs):
    """Get or create a cached model instance.

    Args:
        model_version: One of "sam", "sam2", "sam3".
        model_id: Specific model identifier. Uses default if None.
        **kwargs: Additional keyword arguments for model initialization.

    Returns:
        tuple: (model_instance, threading.Lock)

    Raises:
        HTTPException: If model_version or model_id is invalid, or
            dependencies are missing.
    """
    if model_version not in _DEFAULT_MODEL_IDS:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Invalid model_version '{model_version}'. "
                f"Must be one of: {list(_DEFAULT_MODEL_IDS.keys())}"
            ),
        )

    if model_id is None:
        model_id = _DEFAULT_MODEL_IDS[model_version]

    valid_ids = _AVAILABLE_MODELS[model_version]
    if model_id not in valid_ids:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Invalid model_id '{model_id}' for {model_version}. "
                f"Must be one of: {valid_ids}"
            ),
        )

    key = (model_version, model_id)
    with _model_cache_lock:
        if key in _model_cache:
            logger.info("Model cache hit for %s", key)
            return _model_cache[key]

        logger.info("Loading model %s", key)
        extra = _EXTRAS_MAP.get(model_version, model_version)
        try:
            if model_version == "sam":
                from samgeo.samgeo import SamGeo

                model = SamGeo(model_type=model_id, **kwargs)
            elif model_version == "sam2":
                from samgeo.samgeo2 import SamGeo2

                model = SamGeo2(model_id=model_id, **kwargs)
            elif model_version == "sam3":
                from samgeo.samgeo3 import SamGeo3

                model = SamGeo3(**kwargs)
        except ImportError as e:
            raise HTTPException(
                status_code=503,
                detail=(
                    f"Dependencies for {model_version} are not installed. "
                    f"Install with: pip install segment-geospatial[{extra}]. "
                    f"Error: {e}"
                ),
            )
        _model_cache[key] = (model, threading.Lock())
        return _model_cache[key]
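
The locking pattern in `get_model` (one lock guarding the cache dictionary, plus a separate lock returned with each model so that callers can serialize inference) can be studied in isolation. The sketch below reproduces that pattern with a generic factory instead of real SAM models; `ModelCache` is a hypothetical name used only for this illustration.

```python
import threading


class ModelCache:
    """Cache of (model, lock) pairs, safe to use from multiple threads."""

    def __init__(self):
        self._cache_lock = threading.Lock()  # guards the dictionary itself
        self._entries = {}                   # key -> (model, per-model lock)

    def get(self, key, factory):
        """Return the cached (model, lock) for key, building it on first use."""
        with self._cache_lock:
            if key not in self._entries:
                # Built while holding the cache lock, so concurrent callers
                # never construct the same model twice.
                self._entries[key] = (factory(), threading.Lock())
            return self._entries[key]

    def clear(self):
        """Drop every cached model."""
        with self._cache_lock:
            self._entries.clear()


cache = ModelCache()
model, lock = cache.get(("sam2", "sam2-hiera-large"), factory=object)
with lock:
    pass  # run inference here; other requests for this model wait their turn
```

One consequence visible in the source above is that the cache key is only `(model_version, model_id)`. Keyword arguments are used when the model is first built, so a later call with different keyword arguments appears to receive the instance created by the first call.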

health()

Health check endpoint.

Source code in samgeo/api.py
@app.get("/health")
def health():
    """Health check endpoint."""
    return {"status": "ok", "version": __version__}

list_models()

List available and currently loaded models.

Source code in samgeo/api.py
@app.get("/models")
def list_models():
    """List available and currently loaded models."""
    loaded = [list(key) for key in _model_cache]
    return {"models": _AVAILABLE_MODELS, "loaded": loaded}

main()

Entry point for the samgeo-api console script.

Source code in samgeo/api.py
def main():
    """Entry point for the samgeo-api console script."""
    parser = argparse.ArgumentParser(
        description="Run the samgeo REST API server."
    )
    parser.add_argument(
        "--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to listen on (default: 8000)"
    )
    parser.add_argument(
        "--reload", action="store_true", help="Enable auto-reload for development"
    )
    parser.add_argument(
        "--preload",
        type=str,
        default=None,
        help="Preload a model at startup, e.g. 'sam2:sam2-hiera-large'",
    )
    args = parser.parse_args()

    if args.preload:
        if ":" not in args.preload:
            parser.error(
                "Invalid --preload format. "
                "Expected 'model_version:model_id', e.g. 'sam2:sam2-hiera-large'"
            )
        version, mid = args.preload.split(":", 1)
        get_model(version, mid)

    uvicorn.run("samgeo.api:app", host=args.host, port=args.port, reload=args.reload)
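
The `--preload` value is split on the first colon only, so everything after it is treated as the model ID. A standalone sketch of that parsing step follows; `parse_preload` is a hypothetical helper that raises `ValueError` where the real script calls `parser.error`.

```python
def parse_preload(spec: str):
    """Split 'model_version:model_id' into its two parts."""
    if ":" not in spec:
        raise ValueError(
            "Invalid --preload format. Expected 'model_version:model_id', "
            "e.g. 'sam2:sam2-hiera-large'"
        )
    # maxsplit=1 keeps any further colons inside the model ID.
    version, model_id = spec.split(":", 1)
    return version, model_id


print(parse_preload("sam2:sam2-hiera-large"))
```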

segment_automatic(file=File(...), model_version=Form('sam2'), model_id=Form(None), output_format=Form('geojson'), foreground=Form(True), unique=Form(True), min_size=Form(0), max_size=Form(None), points_per_side=Form(32), pred_iou_thresh=Form(0.8), stability_score_thresh=Form(0.95)) async

Run automatic mask generation on an uploaded image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| file | UploadFile | Image file (TIFF, PNG, JPEG). | File(...) |
| model_version | str | One of "sam", "sam2", "sam3". | Form('sam2') |
| model_id | Optional[str] | Specific model identifier. | Form(None) |
| output_format | str | One of "geojson", "geotiff", "png". | Form('geojson') |
| foreground | bool | Whether to extract foreground objects only. | Form(True) |
| unique | bool | Whether to assign unique IDs to each object. | Form(True) |
| min_size | int | Minimum mask size in pixels. | Form(0) |
| max_size | Optional[int] | Maximum mask size in pixels. | Form(None) |
| points_per_side | int | Number of points sampled per side (SAM/SAM2). | Form(32) |
| pred_iou_thresh | float | IoU threshold for filtering masks. | Form(0.8) |
| stability_score_thresh | float | Stability score threshold for filtering. | Form(0.95) |

Returns:

Segmentation result in the requested format.

Source code in samgeo/api.py
@app.post("/segment/automatic")
async def segment_automatic(
    file: UploadFile = File(...),
    model_version: str = Form("sam2"),
    model_id: Optional[str] = Form(None),
    output_format: str = Form("geojson"),
    foreground: bool = Form(True),
    unique: bool = Form(True),
    min_size: int = Form(0),
    max_size: Optional[int] = Form(None),
    points_per_side: int = Form(32),
    pred_iou_thresh: float = Form(0.8),
    stability_score_thresh: float = Form(0.95),
):
    """Run automatic mask generation on an uploaded image.

    Args:
        file: Image file (TIFF, PNG, JPEG).
        model_version: One of "sam", "sam2", "sam3".
        model_id: Specific model identifier.
        output_format: One of "geojson", "geotiff", "png".
        foreground: Whether to extract foreground objects only.
        unique: Whether to assign unique IDs to each object.
        min_size: Minimum mask size in pixels.
        max_size: Maximum mask size in pixels.
        points_per_side: Number of points sampled per side (SAM/SAM2).
        pred_iou_thresh: IoU threshold for filtering masks.
        stability_score_thresh: Stability score threshold for filtering.

    Returns:
        Segmentation result in the requested format.
    """
    _validate_output_format(output_format)
    max_size = _normalize_max_size(max_size)
    tmpdir = tempfile.mkdtemp()
    try:
        input_path, image_hash = await _save_upload(file, tmpdir)
        output_path = os.path.join(tmpdir, "mask.tif")

        t_start = time.time()
        if model_version == "sam3":
            model, lock = get_model(model_version, model_id)
            model_key = (model_version, model_id or _DEFAULT_MODEL_IDS[model_version])
            with lock:
                _set_image_cached(model, model_key, input_path, image_hash)
                model.generate_masks(
                    prompt="everything",
                    min_size=min_size,
                    max_size=max_size,
                )
                model.save_masks(output=output_path, unique=unique)
        else:
            sam_kwargs = {
                "points_per_side": points_per_side,
                "pred_iou_thresh": pred_iou_thresh,
                "stability_score_thresh": stability_score_thresh,
            }
            if model_version == "sam":
                model, lock = get_model(
                    model_version, model_id, sam_kwargs=sam_kwargs
                )
            else:
                model, lock = get_model(model_version, model_id, **sam_kwargs)

            with lock:
                model.generate(
                    source=input_path,
                    output=output_path,
                    foreground=foreground,
                    unique=unique,
                    min_size=min_size,
                    max_size=max_size,
                )

        t_inference = time.time() - t_start
        logger.info(
            "Automatic segmentation completed in %.2fs (model: %s)",
            t_inference,
            model_version,
        )
        return _format_response(output_path, output_format, tmpdir)
    except HTTPException:
        _cleanup_tmpdir(tmpdir)
        raise
    except Exception as e:
        _cleanup_tmpdir(tmpdir)
        raise HTTPException(status_code=500, detail=str(e))

segment_predict(file=File(...), model_version=Form('sam2'), model_id=Form(None), output_format=Form('geojson'), point_coords=Form(None), point_labels=Form(None), boxes=Form(None), point_crs=Form(None), multimask_output=Form(False), min_size=Form(0), max_size=Form(None)) async

Run prompt-based segmentation with points or bounding boxes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| file | UploadFile | Image file (TIFF, PNG, JPEG). | File(...) |
| model_version | str | One of "sam", "sam2". | Form('sam2') |
| model_id | Optional[str] | Specific model identifier. | Form(None) |
| output_format | str | One of "geojson", "geotiff", "png". | Form('geojson') |
| point_coords | Optional[str] | JSON string of [[x, y], ...] coordinate pairs. | Form(None) |
| point_labels | Optional[str] | JSON string of [1, 0, ...] labels (1=foreground, 0=background). | Form(None) |
| boxes | Optional[str] | JSON string of [[xmin, ymin, xmax, ymax], ...] bounding boxes. | Form(None) |
| point_crs | Optional[str] | CRS string (e.g., "EPSG:4326") for point coordinates. | Form(None) |
| multimask_output | bool | Whether to return multiple masks per prompt. | Form(False) |
| min_size | int | Minimum mask size in pixels. | Form(0) |
| max_size | Optional[int] | Maximum mask size in pixels. | Form(None) |

Returns:

Segmentation result in the requested format.

Source code in samgeo/api.py
@app.post("/segment/predict")
async def segment_predict(
    file: UploadFile = File(...),
    model_version: str = Form("sam2"),
    model_id: Optional[str] = Form(None),
    output_format: str = Form("geojson"),
    point_coords: Optional[str] = Form(None),
    point_labels: Optional[str] = Form(None),
    boxes: Optional[str] = Form(None),
    point_crs: Optional[str] = Form(None),
    multimask_output: bool = Form(False),
    min_size: int = Form(0),
    max_size: Optional[int] = Form(None),
):
    """Run prompt-based segmentation with points or bounding boxes.

    Args:
        file: Image file (TIFF, PNG, JPEG).
        model_version: One of "sam", "sam2".
        model_id: Specific model identifier.
        output_format: One of "geojson", "geotiff", "png".
        point_coords: JSON string of [[x, y], ...] coordinate pairs.
        point_labels: JSON string of [1, 0, ...] labels (1=foreground,
            0=background).
        boxes: JSON string of [[xmin, ymin, xmax, ymax], ...] bounding boxes.
        point_crs: CRS string (e.g., "EPSG:4326") for point coordinates.
        multimask_output: Whether to return multiple masks per prompt.
        min_size: Minimum mask size in pixels.
        max_size: Maximum mask size in pixels.

    Returns:
        Segmentation result in the requested format.
    """
    _validate_output_format(output_format)

    if model_version == "sam3":
        raise HTTPException(
            status_code=400,
            detail="Use /segment/text for SAM3 text-based segmentation.",
        )

    if point_coords is None and boxes is None:
        raise HTTPException(
            status_code=400,
            detail="At least one of point_coords or boxes must be provided.",
        )

    max_size = _normalize_max_size(max_size)
    tmpdir = tempfile.mkdtemp()
    try:
        input_path, image_hash = await _save_upload(file, tmpdir)
        output_path = os.path.join(tmpdir, "mask.tif")

        # Parse JSON prompt fields
        parsed_coords = None
        parsed_labels = None
        parsed_boxes = None

        if point_coords is not None:
            parsed_coords = np.array(json.loads(point_coords))
        if point_labels is not None:
            parsed_labels = np.array(json.loads(point_labels))
        if boxes is not None:
            parsed_boxes = np.array(json.loads(boxes))

        t_start = time.time()
        model, lock = get_model(model_version, model_id, automatic=False)
        model_key = (model_version, model_id or _DEFAULT_MODEL_IDS[model_version])
        with lock:
            _set_image_cached(model, model_key, input_path, image_hash)
            model.predict(
                point_coords=parsed_coords,
                point_labels=parsed_labels,
                boxes=parsed_boxes,
                point_crs=point_crs,
                multimask_output=multimask_output,
                output=output_path,
            )

        t_inference = time.time() - t_start
        logger.info(
            "Prompt segmentation completed in %.2fs (model: %s)",
            t_inference,
            model_version,
        )
        return _format_response(output_path, output_format, tmpdir)
    except HTTPException:
        _cleanup_tmpdir(tmpdir)
        raise
    except json.JSONDecodeError as e:
        _cleanup_tmpdir(tmpdir)
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON in prompt fields: {e}"
        )
    except Exception as e:
        _cleanup_tmpdir(tmpdir)
        raise HTTPException(status_code=500, detail=str(e))
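
The handlers above report problems by raising `HTTPException`, which FastAPI serializes as a JSON body of the form `{"detail": "..."}` with the given status code (400 for invalid input, 500 for unexpected failures, and 503 from `get_model` when dependencies are missing). A client can surface that message instead of a bare status code. The sketch below shows one way to do so; `SegmentationError` and `raise_for_api_error` are hypothetical names, not part of the library.

```python
import json


class SegmentationError(RuntimeError):
    """Raised when the API answers with an error status."""

    def __init__(self, status_code: int, detail: str):
        super().__init__(f"HTTP {status_code}: {detail}")
        self.status_code = status_code
        self.detail = detail


def raise_for_api_error(status_code: int, body: bytes) -> None:
    """Raise SegmentationError for 4xx/5xx responses; do nothing otherwise."""
    if status_code < 400:
        return
    try:
        # FastAPI error bodies look like {"detail": "..."}.
        detail = json.loads(body).get("detail", "")
    except (ValueError, AttributeError):
        # Not JSON, or JSON that is not an object: fall back to the raw text.
        detail = body.decode("utf-8", errors="replace")
    raise SegmentationError(status_code, str(detail))
```

With the `requests` library this would be called as `raise_for_api_error(response.status_code, response.content)` before parsing the result.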

segment_text(file=File(...), prompt=Form(...), model_id=Form(None), backend=Form('meta'), output_format=Form('geojson'), confidence_threshold=Form(0.5), min_size=Form(0), max_size=Form(None)) async

Run text-prompt segmentation using SAM3.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| file | UploadFile | Image file (TIFF, PNG, JPEG). | File(...) |
| prompt | str | Text description of objects to segment (e.g., "building"). | Form(...) |
| model_id | Optional[str] | SAM3 model identifier. | Form(None) |
| backend | str | SAM3 backend, one of "meta" or "transformers". | Form('meta') |
| output_format | str | One of "geojson", "geotiff", "png". | Form('geojson') |
| confidence_threshold | float | Confidence threshold for detections. | Form(0.5) |
| min_size | int | Minimum mask size in pixels. | Form(0) |
| max_size | Optional[int] | Maximum mask size in pixels. | Form(None) |

Returns:

Segmentation result in the requested format.

Source code in samgeo/api.py
@app.post("/segment/text")
async def segment_text(
    file: UploadFile = File(...),
    prompt: str = Form(...),
    model_id: Optional[str] = Form(None),
    backend: str = Form("meta"),
    output_format: str = Form("geojson"),
    confidence_threshold: float = Form(0.5),
    min_size: int = Form(0),
    max_size: Optional[int] = Form(None),
):
    """Run text-prompt segmentation using SAM3.

    Args:
        file: Image file (TIFF, PNG, JPEG).
        prompt: Text description of objects to segment (e.g., "building").
        model_id: SAM3 model identifier.
        backend: SAM3 backend, one of "meta" or "transformers".
        output_format: One of "geojson", "geotiff", "png".
        confidence_threshold: Confidence threshold for detections.
        min_size: Minimum mask size in pixels.
        max_size: Maximum mask size in pixels.

    Returns:
        Segmentation result in the requested format.
    """
    _validate_output_format(output_format)
    max_size = _normalize_max_size(max_size)
    tmpdir = tempfile.mkdtemp()
    try:
        input_path, image_hash = await _save_upload(file, tmpdir)
        output_path = os.path.join(tmpdir, "mask.tif")

        model, lock = get_model(
            "sam3",
            model_id,
            backend=backend,
            confidence_threshold=confidence_threshold,
        )
        t_start = time.time()
        model_key = ("sam3", model_id or _DEFAULT_MODEL_IDS["sam3"])
        with lock:
            _set_image_cached(model, model_key, input_path, image_hash)
            model.generate_masks(
                prompt=prompt,
                min_size=min_size,
                max_size=max_size,
            )
            model.save_masks(output=output_path)

        t_inference = time.time() - t_start
        logger.info(
            "Text segmentation completed in %.2fs (prompt: '%s')",
            t_inference,
            prompt,
        )
        return _format_response(output_path, output_format, tmpdir)
    except HTTPException:
        _cleanup_tmpdir(tmpdir)
        raise
    except Exception as e:
        _cleanup_tmpdir(tmpdir)
        raise HTTPException(status_code=500, detail=str(e))