diff --git a/Jenkinsfile b/Jenkinsfile index 9b08d71b9..38d5e030c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -44,6 +44,7 @@ pipeline { sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-OSM.pt' sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-ESRI.pt' sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-CLS.pt' + sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-SAM.pt' } } post { diff --git a/docs/development-setup.md b/docs/development-setup.md index 2cf7b854d..de91de708 100644 --- a/docs/development-setup.md +++ b/docs/development-setup.md @@ -49,7 +49,7 @@ npm install npm run build # Download ml-model weights -wget -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-{OSM,ESRI,CLS}.pt +wget -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-{OSM,ESRI,CLS,SAM}.pt # Fetch and run database & result store (postgres) docker run --name smt-postgres -d -p 5432:5432 -e POSTGRES_PASSWORD=smt -e POSTGRES_USER=smt postgres:15 diff --git a/docs/model_registry.md b/docs/model_registry.md index 530dc857d..de8ab8e0a 100644 --- a/docs/model_registry.md +++ b/docs/model_registry.md @@ -10,7 +10,7 @@ The **Model Registry** maintains a collection of fine-tuned machine learning mod | Object Detection | YOLO_OSM | 6-Channel Input | Detects sketches on OSM | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-OSM.pt) | | Object Detection | YOLO_ESRI | 6-Channel Input | Detects sketches on ESRI maps | [download](hhttps://downloads.ohsome.org/sketch-map-tool/weights/SMT-ESRI.pt) | | Image Classification | YOLO_CLS | Standard RGB | Classifies colors in sketches | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-CLS.pt) | -| Segmentation | SAM2 | tandard RGB | Performs segmentation on sketch | [github](https://github.com/facebookresearch/sam2) | +| Segmentation | SAM2 | Standard RGB | Performs segmentation on on sketches, finetuned from **SAM 2.1 hiera large** | [download](hhttps://downloads.ohsome.org/sketch-map-tool/weights/SMT-SAM.pt) | ## Models in the Registry ### 1. Object Detection Models @@ -40,7 +40,9 @@ This model is used to determine the sketch's **color**. For segmentation tasks, **SAM2 (Segment Anything Model v2)** is utilized. #### **SAM2 - Segmentation Model** +- **Base Map:** Both OSM and ESRI Satellite Imagery - **Task:** Performs pixel-wise segmentation to extract regions from sketches. +- **Fine-tuned On:** On a set of manually selected segmented sketches to improve performance on sketch data. For questions, contact the **SketchMapTool Team**. diff --git a/sketch_map_tool/config.py b/sketch_map_tool/config.py index f78b49e04..458595ed6 100644 --- a/sketch_map_tool/config.py +++ b/sketch_map_tool/config.py @@ -29,9 +29,10 @@ class Config(BaseSettings): esri_api_key: str = "" log_level: str = "INFO" max_nr_simultaneous_uploads: int = 100 - model_type_sam: str = "vit_b" + model_type_sam: str = "configs/sam2.1/sam2.1_hiera_l.yaml" point_area_threshold: float = 0.00047 result_backend: str = "db+postgresql://smt:smt@localhost:5432" + sam_checkpoint: str = "SMT-SAM" user_agent: str = "sketch-map-tool" weights_dir: str = str(get_project_root() / "weights") # TODO: make this a Path wms_layers_esri_world_imagery: str = "world_imagery" diff --git a/sketch_map_tool/tasks.py b/sketch_map_tool/tasks.py index 447ebdbc4..16aec96e4 100644 --- a/sketch_map_tool/tasks.py +++ b/sketch_map_tool/tasks.py @@ -1,6 +1,7 @@ import logging from io import BytesIO +import torch from celery.result import AsyncResult from celery.signals import setup_logging, worker_process_init, worker_process_shutdown from geojson import FeatureCollection @@ -27,7 +28,6 @@ from sketch_map_tool.upload_processing.detect_markings import detect_markings from sketch_map_tool.upload_processing.ml_models import ( init_model, - init_sam2, select_computation_device, ) from sketch_map_tool.wms import client as wms_client @@ -55,13 +55,15 @@ def init_worker_ml_models(**_): global yolo_obj_esri global yolo_cls - path = init_sam2() device = select_computation_device() sam2_model = build_sam2( - config_file="sam2_hiera_b+.yaml", - ckpt_path=path, + config_file=CONFIG.model_type_sam, + ckpt_path=None, device=device, ) + sam2_model.load_state_dict( + torch.load(init_model(CONFIG.sam_checkpoint), map_location=device) + ) sam_predictor = SAM2ImagePredictor(sam2_model) yolo_obj_osm = YOLO_MB(init_model(CONFIG.yolo_osm_obj)) diff --git a/sketch_map_tool/upload_processing/ml_models.py b/sketch_map_tool/upload_processing/ml_models.py index 3a2d8c8bf..d7f12acdf 100644 --- a/sketch_map_tool/upload_processing/ml_models.py +++ b/sketch_map_tool/upload_processing/ml_models.py @@ -1,7 +1,6 @@ import logging from pathlib import Path -import requests import torch from torch._prims_common import DeviceLikeType @@ -17,19 +16,6 @@ def init_model(id: str) -> Path: return path -def init_sam2(id: str = "sam2_hiera_base_plus") -> Path: - raw = Path(CONFIG.weights_dir) / id - path = raw.with_suffix(".pt") - base_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" - url = base_url + id + ".pt" - if not path.is_file(): - logging.info(f"Downloading model SAM-2 from fbaipublicfiles.com to {path}.") - response = requests.get(url=url) - with open(path, mode="wb") as file: - file.write(response.content) - return path - - def select_computation_device() -> DeviceLikeType: """Select computation device (cuda, mps, cpu) for SAM-2""" if torch.cuda.is_available(): diff --git a/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[1726835278].approved.png b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[1726835278].approved.png new file mode 100644 index 000000000..4d2de3b8a Binary files /dev/null and b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[1726835278].approved.png differ diff --git a/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[2346410719].approved.png b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[2346410719].approved.png new file mode 100644 index 000000000..bd488cc54 Binary files /dev/null and b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[2346410719].approved.png differ diff --git a/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].approved.png b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].approved.png new file mode 100644 index 000000000..3cf27f47d Binary files /dev/null and b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].approved.png differ diff --git a/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].received.png b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].received.png new file mode 100644 index 000000000..1c82cdc24 Binary files /dev/null and b/tests/approvals/integration/upload_processing/test_detect_markings.py--test_detect_markings[68370924].received.png differ diff --git a/tests/integration/upload_processing/test_detect_markings.py b/tests/integration/upload_processing/test_detect_markings.py index 8cf1fbb63..459c67a97 100644 --- a/tests/integration/upload_processing/test_detect_markings.py +++ b/tests/integration/upload_processing/test_detect_markings.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import torch from PIL import Image from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor @@ -12,7 +13,6 @@ ) from sketch_map_tool.upload_processing.ml_models import ( init_model, - init_sam2, select_computation_device, ) @@ -22,13 +22,15 @@ @pytest.fixture def sam_predictor(): """Zero shot segment anything model""" - path = init_sam2() device = select_computation_device() sam2_model = build_sam2( - config_file="sam2_hiera_b+.yaml", - ckpt_path=path, + config_file=CONFIG.model_type_sam, + ckpt_path=None, device=device, ) + sam2_model.load_state_dict( + torch.load(init_model(CONFIG.sam_checkpoint), map_location=device) + ) return SAM2ImagePredictor(sam2_model) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index ff0a5d9a4..4d9414dae 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -47,7 +47,7 @@ def test_get_config_path_set_env(monkeypatch): assert config.get_config_path() == "/some/absolute/path" -def test_config_user_agent_env(monkeypatch): +def test_config_user_agent_env(): # env takes precedence over file (see pyproject.toml) assert config.CONFIG.user_agent == "sketch-map-tool-test"