Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pipeline {
sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-OSM.pt'
sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-ESRI.pt'
sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-CLS.pt'
sh 'wget --quiet -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-SAM.pt'
}
}
post {
Expand Down
2 changes: 1 addition & 1 deletion docs/development-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ npm install
npm run build

# Download ml-model weights
wget -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-{OSM,ESRI,CLS}.pt
wget -P weights https://downloads.ohsome.org/sketch-map-tool/weights/SMT-{OSM,ESRI,CLS,SAM}.pt

# Fetch and run database & result store (postgres)
docker run --name smt-postgres -d -p 5432:5432 -e POSTGRES_PASSWORD=smt -e POSTGRES_USER=smt postgres:15
Expand Down
4 changes: 3 additions & 1 deletion docs/model_registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The **Model Registry** maintains a collection of fine-tuned machine learning mod
| Object Detection | YOLO_OSM | 6-Channel Input | Detects sketches on OSM | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-OSM.pt) |
| Object Detection | YOLO_ESRI | 6-Channel Input | Detects sketches on ESRI maps | [download](hhttps://downloads.ohsome.org/sketch-map-tool/weights/SMT-ESRI.pt) |
| Image Classification | YOLO_CLS | Standard RGB | Classifies colors in sketches | [download](https://downloads.ohsome.org/sketch-map-tool/weights/SMT-CLS.pt) |
| Segmentation | SAM2 | tandard RGB | Performs segmentation on sketch | [github](https://github.com/facebookresearch/sam2) |
| Segmentation | SAM2 | Standard RGB | Performs segmentation on on sketches, finetuned from **SAM 2.1 hiera large** | [download](hhttps://downloads.ohsome.org/sketch-map-tool/weights/SMT-SAM.pt) |

## Models in the Registry
### 1. Object Detection Models
Expand Down Expand Up @@ -40,7 +40,9 @@ This model is used to determine the sketch's **color**.
For segmentation tasks, **SAM2 (Segment Anything Model v2)** is utilized.

#### **SAM2 - Segmentation Model**
- **Base Map:** Both OSM and ESRI Satellite Imagery
- **Task:** Performs pixel-wise segmentation to extract regions from sketches.
- **Fine-tuned On:** On a set of manually selected segmented sketches to improve performance on sketch data.


For questions, contact the **SketchMapTool Team**.
Expand Down
3 changes: 2 additions & 1 deletion sketch_map_tool/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ class Config(BaseSettings):
esri_api_key: str = ""
log_level: str = "INFO"
max_nr_simultaneous_uploads: int = 100
model_type_sam: str = "vit_b"
model_type_sam: str = "configs/sam2.1/sam2.1_hiera_l.yaml"
point_area_threshold: float = 0.00047
result_backend: str = "db+postgresql://smt:smt@localhost:5432"
sam_checkpoint: str = "SMT-SAM"
user_agent: str = "sketch-map-tool"
weights_dir: str = str(get_project_root() / "weights") # TODO: make this a Path
wms_layers_esri_world_imagery: str = "world_imagery"
Expand Down
10 changes: 6 additions & 4 deletions sketch_map_tool/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from io import BytesIO

import torch
from celery.result import AsyncResult
from celery.signals import setup_logging, worker_process_init, worker_process_shutdown
from geojson import FeatureCollection
Expand All @@ -27,7 +28,6 @@
from sketch_map_tool.upload_processing.detect_markings import detect_markings
from sketch_map_tool.upload_processing.ml_models import (
init_model,
init_sam2,
select_computation_device,
)
from sketch_map_tool.wms import client as wms_client
Expand Down Expand Up @@ -55,13 +55,15 @@ def init_worker_ml_models(**_):
global yolo_obj_esri
global yolo_cls

path = init_sam2()
device = select_computation_device()
sam2_model = build_sam2(
config_file="sam2_hiera_b+.yaml",
ckpt_path=path,
config_file=CONFIG.model_type_sam,
ckpt_path=None,
device=device,
)
sam2_model.load_state_dict(
torch.load(init_model(CONFIG.sam_checkpoint), map_location=device)
)
sam_predictor = SAM2ImagePredictor(sam2_model)

yolo_obj_osm = YOLO_MB(init_model(CONFIG.yolo_osm_obj))
Expand Down
14 changes: 0 additions & 14 deletions sketch_map_tool/upload_processing/ml_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from pathlib import Path

import requests
import torch
from torch._prims_common import DeviceLikeType

Expand All @@ -17,19 +16,6 @@ def init_model(id: str) -> Path:
return path


def init_sam2(id: str = "sam2_hiera_base_plus") -> Path:
raw = Path(CONFIG.weights_dir) / id
path = raw.with_suffix(".pt")
base_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
url = base_url + id + ".pt"
if not path.is_file():
logging.info(f"Downloading model SAM-2 from fbaipublicfiles.com to {path}.")
response = requests.get(url=url)
with open(path, mode="wb") as file:
file.write(response.content)
return path


def select_computation_device() -> DeviceLikeType:
"""Select computation device (cuda, mps, cpu) for SAM-2"""
if torch.cuda.is_available():
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 6 additions & 4 deletions tests/integration/upload_processing/test_detect_markings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
Expand All @@ -12,7 +13,6 @@
)
from sketch_map_tool.upload_processing.ml_models import (
init_model,
init_sam2,
select_computation_device,
)

Expand All @@ -22,13 +22,15 @@
@pytest.fixture
def sam_predictor():
"""Zero shot segment anything model"""
path = init_sam2()
device = select_computation_device()
sam2_model = build_sam2(
config_file="sam2_hiera_b+.yaml",
ckpt_path=path,
config_file=CONFIG.model_type_sam,
ckpt_path=None,
device=device,
)
sam2_model.load_state_dict(
torch.load(init_model(CONFIG.sam_checkpoint), map_location=device)
)
return SAM2ImagePredictor(sam2_model)


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_get_config_path_set_env(monkeypatch):
assert config.get_config_path() == "/some/absolute/path"


def test_config_user_agent_env(monkeypatch):
def test_config_user_agent_env():
# env takes precedence over file (see pyproject.toml)
assert config.CONFIG.user_agent == "sketch-map-tool-test"

Expand Down