diff --git a/common/download.cpp b/common/download.cpp index 40f6eb780f41..c3c8ff49bb70 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -997,3 +997,87 @@ std::vector common_list_cached_models() { return result; } + +bool common_download_remove(const std::string & hf_repo_with_tag) { + namespace fs = std::filesystem; + + auto [repo_id, tag] = common_download_split_repo_tag(hf_repo_with_tag); + + if (tag.empty()) { + return hf_cache::remove_cached_repo(repo_id); + } + + std::string tag_upper = tag; + for (char & c : tag_upper) { + c = (char) std::toupper((unsigned char) c); + } + + auto files = hf_cache::get_cached_files(repo_id); + if (files.empty()) { + return false; + } + + // collect snapshot entries whose tag matches + std::vector to_remove; + for (const auto & f : files) { + auto split = get_gguf_split_info(f.path); + if (split.tag == tag_upper) { + to_remove.emplace_back(f.local_path); + } + } + + if (to_remove.empty()) { + return false; + } + + // resolve blob paths from symlinks before deleting snapshot entries + std::vector blobs_to_check; + for (const auto & p : to_remove) { + std::error_code ec; + if (fs::is_symlink(p, ec)) { + auto target = fs::read_symlink(p, ec); + if (!ec) { + blobs_to_check.push_back((p.parent_path() / target).lexically_normal()); + } + } + } + + // remove snapshot entries + for (const auto & p : to_remove) { + std::error_code ec; + fs::remove(p, ec); + if (ec) { + LOG_WRN("%s: failed to remove %s: %s\n", __func__, p.string().c_str(), ec.message().c_str()); + } + } + + if (blobs_to_check.empty()) { + return true; + } + + // collect blobs still referenced by remaining snapshot entries + std::unordered_set still_referenced; + for (const auto & f : hf_cache::get_cached_files(repo_id)) { + fs::path p(f.local_path); + std::error_code ec; + if (fs::is_symlink(p, ec)) { + auto target = fs::read_symlink(p, ec); + if (!ec) { + still_referenced.insert((p.parent_path() / target).lexically_normal().string()); + } + } + } + + // remove orphaned blobs + for (const auto & blob : blobs_to_check) { + if (still_referenced.find(blob.string()) == still_referenced.end()) { + std::error_code ec; + fs::remove(blob, ec); + if (ec) { + LOG_WRN("%s: failed to remove blob %s: %s\n", __func__, blob.string().c_str(), ec.message().c_str()); + } + } + } + + return true; +} diff --git a/common/download.h b/common/download.h index ebeedd6058c7..237179764421 100644 --- a/common/download.h +++ b/common/download.h @@ -115,3 +115,10 @@ int common_download_file_single(const std::string & url, // resolve and download model from Docker registry // return local path to downloaded model file std::string common_docker_resolve_model(const std::string & docker); + +// Remove a cached model from disk +// input format: "user/model" or "user/model:tag" +// - if tag is omitted, removes the entire repo cache directory +// - if tag is present, removes only files matching that tag (and orphaned blobs) +// returns true if anything was removed +bool common_download_remove(const std::string & hf_repo_with_tag); diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp index ba7417a12bb6..f1dacaa4778d 100644 --- a/common/hf-cache.cpp +++ b/common/hf-cache.cpp @@ -495,4 +495,19 @@ std::string finalize_file(const hf_file & file) { return file.final_path; } +bool remove_cached_repo(const std::string & repo_id) { + if (!is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); + return false; + } + fs::path repo_path = get_repo_path(repo_id); + std::error_code ec; + auto removed = fs::remove_all(repo_path, ec); + if (ec) { + LOG_ERR("%s: failed to remove repo cache %s: %s\n", __func__, repo_path.string().c_str(), ec.message().c_str()); + return false; + } + return removed > 0; +} + } // namespace hf_cache diff --git a/common/hf-cache.h b/common/hf-cache.h index 23fa0adb729d..42c9c6ce34f0 100644 --- a/common/hf-cache.h +++ b/common/hf-cache.h @@ -29,4 +29,7 @@ hf_files get_cached_files(const std::string & repo_id = {}); // Create snapshot path (link or move/copy) and return it std::string finalize_file(const hf_file & file); +// Remove the entire cached directory for a repo, returns true if removed +bool remove_cached_repo(const std::string & repo_id); + } // namespace hf_cache diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 0ff334724a39..4c410312398b 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -180,6 +180,24 @@ That requires `JSON.stringify` when formatted to message content: } ``` +### Model management API (router mode) + +Model management API was added via PR [#23976](https://github.com/ggml-org/llama.cpp/pull/23976) + +The main goal of this API is to allow downloading models and/or removing models from the web UI. It relies on the model cache infrastructure under the hood to manage the list of models dynamically. + +Instead of building everything from the ground up (like what most AI agents will do when you ask them to implement a similar feature), we built on top of existing, already well-engineered components inside the codebase: +- Model cache infrastructure as mentioned above (`common/download.h`) +- Server response queue (`server-queue.h`). We use this feature to broadcast events to SSE clients. +- Server router thread management (`server-models.h`). We re-use the same thread model that is used for managing subprocess life cycle, except that we don't create a new subprocess, but launch the download right inside the thread. + +The flow for downloading a new model: +- POST request comes in --> `post_router_models` --> validation +- `server_models::download()` is called + - Sets up a new thread `inst.th` and runs the download inside +- If a stop request comes in, set `stop_download` to `true` +- Otherwise, upon completion, we call `load_models()` to refresh the list of models + ### Notable Related PRs - Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443 diff --git a/tools/server/README.md b/tools/server/README.md index b41491059020..88a507e2c558 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1778,6 +1778,20 @@ The `status` object can be: } ``` +Note: for "downloading" state, there can be multiple files be downloading in parallel + +```json +"status": { + "value": "downloading", + "progress": { + "https://...model.gguf": { + "done": 195963406, + "total": 219307424 + } + } +} +``` + ### POST `/models/load`: Load a model Load a model @@ -1820,6 +1834,107 @@ Response: } ``` +### GET `/models/sse`: Real-time events + +Example events: + +```js +{ + "model": "...", + "event": "model_status", + "data": { + "status": "loading" + } +} + +{ + "model": "...", + "event": "download_progress", + "data": { + // note: there can be multiple files being downloaded in parallel + "https://...model.gguf": { + "done": 195963406, + "total": 219307424 + } + } +} + +{ + "model": "...", + "event": "download_finished", + "data": { + "status": "loading" + } +} + +{ + "model": "...", + "event": "model_remove" +} + +// special event: reload of the list of all models +{ + "model": "*", + "event": "models_reload" +} +``` + +### POST `/models`: Download new model + +Trigger a new download (non-blocking), the progress can be tracked via SSE endpoint `/models/sse` + +To cancel model downloading, send an event to `/models/unload` + +Download procedure: +- Send POST request to `/models` +- Subscribe to `/models/sse` for updates +- On downloading completed, you will receive either `download_finished` or `download_failed` event +- Call GET `/models` to trigger model list update. If the download success, you should see the new model in the list + +Payload: + +```json +{ + "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", +} +``` + +Response (download is started in the background): + +```json +{ + "success": true +} +``` + +Response (error, cannot start the download): + +```json +{ + "error": { + "code": 400, + "message": "model validation failed, unable to download", + "type": "invalid_request_error" + } +} +``` + +### DELETE `/models`: Delete a model from cache + +IMPORTANT: only model stored in cache can be deleted. You cannot delete models in a preset. + +Model name must be passed via query param: `?model={name}` + +If delete success, it will send an SSE event of type `model_remove` + +Response: + +```json +{ + "success": true +} +``` + ## API errors `llama-server` returns errors in the same format as OAI: https://github.com/openai/openai-openapi diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 3c58fcfece89..5defee1f5e03 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -588,6 +588,23 @@ void server_http_context::post(const std::string & path, const server_http_conte }); } +void server_http_context::del(const std::string & path, const server_http_context::handler_t & handler) const { + handlers.emplace(path, handler); + pimpl->srv->Delete(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { + server_http_req_ptr request = std::make_unique(server_http_req{ + get_params(req), + get_headers(req), + req.path, + build_query_string(req), + req.body, + {}, + req.is_connection_closed + }); + server_http_res_ptr response = handler(*request); + process_handler_response(std::move(request), response, res); + }); +} + // // Vertex AI Prediction protocol (AIP_PREDICT_ROUTE) // https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements diff --git a/tools/server/server-http.h b/tools/server/server-http.h index 25c7f10629b3..6b4a4b87a631 100644 --- a/tools/server/server-http.h +++ b/tools/server/server-http.h @@ -86,6 +86,7 @@ struct server_http_context { void get(const std::string & path, const handler_t & handler) const; void post(const std::string & path, const handler_t & handler) const; + void del(const std::string & path, const handler_t & handler) const; // Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict) // Must be called AFTER all other API routes are registered diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 49b0e423f462..ff9a0df12f4b 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -51,6 +52,21 @@ extern char **environ; // ref: https://github.com/ggml-org/llama.cpp/issues/17862 #define CHILD_ADDR "127.0.0.1" +struct server_subproc { + std::optional sproc; // empty while in DOWNLOADING state + std::atomic stop_download{false}; // flag to signal download cancellation + + subprocess_s & get() { + GGML_ASSERT(sproc.has_value() && "subprocess not initialized"); + return sproc.value(); + } + + bool is_alive() { + return sproc.has_value() && subprocess_alive(&sproc.value()); + } +}; + + static std::filesystem::path get_server_exec_path() { #if defined(_WIN32) wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths @@ -272,12 +288,25 @@ void server_models::add_model(server_model_meta && meta) { meta.update_caps(); std::string name = meta.name; mapping[name] = instance_t{ - /* subproc */ std::make_shared(), + /* subproc */ std::make_shared(), /* th */ std::thread(), /* meta */ std::move(meta) }; } +void server_models::notify_sse(const std::string & event, const std::string & model_id, const json & data) { + std::unique_ptr result = std::make_unique(); + result->data = { + {"model", model_id}, + {"event", event}, + }; + if (!data.is_null()) { + result->data["data"] = data; + } + SRV_DBG("notifying SSE clients about event '%s' for model '%s': %s\n", event.c_str(), model_id.c_str(), safe_json_to_str(result->data).c_str()); + sse.broadcast(std::move(result)); +} + void server_models::load_models() { // Phase 1: load presets from all sources — pure I/O, no lock needed // 1. cached models @@ -304,19 +333,27 @@ void server_models::load_models() { // note: if a model exists in both cached and local, local takes precedence common_presets final_presets; - for (const auto & [name, preset] : cached_models) final_presets[name] = preset; - for (const auto & [name, preset] : local_models) final_presets[name] = preset; + std::unordered_map source_map; + for (const auto & [name, preset] : cached_models) { + final_presets[name] = preset; + source_map[name] = SERVER_MODEL_SOURCE_CACHE; + } + for (const auto & [name, preset] : local_models) { + final_presets[name] = preset; + source_map[name] = SERVER_MODEL_SOURCE_MODELS_DIR; + } for (const auto & [name, custom] : custom_presets) { if (final_presets.find(name) != final_presets.end()) { final_presets[name].merge(custom); } else { final_presets[name] = custom; } + source_map[name] = SERVER_MODEL_SOURCE_PRESET; } - // server base preset from CLI args takes highest precedence - for (auto & [name, preset] : final_presets) { - preset.merge(base_preset); - } + + auto get_source = [&](const std::string & name) { + return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET; + }; // Helpers that read `mapping` — must be called while holding the lock. std::unordered_set custom_names; @@ -366,12 +403,15 @@ void server_models::load_models() { // (unload, load) or when joining threads (the monitoring thread calls update_status // which locks the mutex, so joining while holding it would deadlock). std::unique_lock lk(mutex); + + need_reload = false; bool is_first_load = mapping.empty(); if (is_first_load) { // FIRST LOAD: add all models, then unlock for autoloading for (const auto & [name, preset] : final_presets) { server_model_meta meta{ + /* source */ get_source(name), /* preset */ preset, /* name */ name, /* aliases */ {}, @@ -384,7 +424,7 @@ void server_models::load_models() { /* exit_code */ 0, /* stop_timeout */ DEFAULT_STOP_TIMEOUT, /* multimodal */ mtmd_caps{false, false}, - /* need_download */ false, + // /* need_download */ false, }; add_model(std::move(meta)); } @@ -453,6 +493,9 @@ void server_models::load_models() { } } for (auto & [name, inst] : mapping) { + if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { + continue; // downloading models are not from config sources, leave them alone + } if (final_presets.find(name) == final_presets.end() && !inst.meta.is_running() && inst.th.joinable()) { threads_to_join.push_back(std::move(inst.th)); } @@ -465,7 +508,15 @@ void server_models::load_models() { // erase models no longer in any source for (auto it = mapping.begin(); it != mapping.end(); ) { - if (final_presets.find(it->first) == final_presets.end()) { + if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { + ++it; // download thread is still busy, skip + } else if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADED) { + // download finished, safe to erase + if (it->second.th.joinable()) { + it->second.th.join(); + } + it = mapping.erase(it); + } else if (final_presets.find(it->first) == final_presets.end()) { SRV_INF("(reload) removing model name=%s (no longer in source)\n", it->first.c_str()); GGML_ASSERT(!it->second.th.joinable()); // must have been joined above it = mapping.erase(it); @@ -526,6 +577,7 @@ void server_models::load_models() { for (const auto & [name, preset] : final_presets) { if (mapping.find(name) == mapping.end()) { server_model_meta meta{ + /* source */ get_source(name), /* preset */ preset, /* name */ name, /* aliases */ {}, @@ -538,7 +590,7 @@ void server_models::load_models() { /* exit_code */ 0, /* stop_timeout */ DEFAULT_STOP_TIMEOUT, /* multimodal */ mtmd_caps{false, false}, - /* need_download */ false, + // /* need_download */ false, }; add_model(std::move(meta)); newly_added.push_back(name); @@ -571,6 +623,8 @@ void server_models::load_models() { SRV_INF("(reload) loading new model %s\n", name.c_str()); load(name); } + + notify_sse("models_reload", "*"); } } @@ -597,7 +651,13 @@ bool server_models::has_model(const std::string & name) { } std::optional server_models::get_meta(const std::string & name) { - std::lock_guard lk(mutex); + std::unique_lock lk(mutex); + if (need_reload) { + lk.unlock(); + load_models(); + lk.lock(); + } + auto it = mapping.find(name); if (it != mapping.end()) { return it->second.meta; @@ -683,7 +743,13 @@ static std::vector to_char_ptr_array(const std::vector & ve } std::vector server_models::get_all_meta() { - std::lock_guard lk(mutex); + std::unique_lock lk(mutex); + if (need_reload) { + lk.unlock(); + load_models(); + lk.lock(); + } + std::vector result; result.reserve(mapping.size()); for (const auto & [name, inst] : mapping) { @@ -770,7 +836,7 @@ void server_models::load(const std::string & name) { throw std::runtime_error("failed to get a port number"); } - inst.subproc = std::make_shared(); + inst.subproc = std::make_shared(); { SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); @@ -792,19 +858,20 @@ void server_models::load(const std::string & name) { // TODO @ngxson : maybe separate stdout and stderr in the future // so that we can use stdout for commands and stderr for logging int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr; - int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get()); + inst.subproc->sproc.emplace(); + int result = subprocess_create_ex(argv.data(), options, envp.data(), &inst.subproc->get()); if (result != 0) { throw std::runtime_error("failed to spawn server instance"); } - inst.stdin_file = subprocess_stdin(inst.subproc.get()); + inst.stdin_file = subprocess_stdin(&inst.subproc->get()); } // start a thread to manage the child process // captured variables are guaranteed to be destroyed only after the thread is joined inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() { - FILE * stdin_file = subprocess_stdin(child_proc.get()); - FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr + FILE * stdin_file = subprocess_stdin(&child_proc->get()); + FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr std::thread log_thread([&]() { // read stdout/stderr and forward to main server log @@ -834,14 +901,14 @@ void server_models::load(const std::string & name) { return this->stopping_models.find(name) != this->stopping_models.end(); }; auto should_wake = [&]() { - return is_stopping() || !subprocess_alive(child_proc.get()); + return is_stopping() || !child_proc->is_alive(); }; { std::unique_lock lk(this->mutex); this->cv_stop.wait(lk, should_wake); } // child may have already exited (e.g. crashed) — skip shutdown sequence - if (!subprocess_alive(child_proc.get())) { + if (!child_proc->is_alive()) { return; } SRV_INF("stopping model instance name=%s\n", name.c_str()); @@ -859,7 +926,7 @@ void server_models::load(const std::string & name) { if (elapsed >= stop_timeout * 1000) { // timeout, force kill SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout); - subprocess_terminate(child_proc.get()); + subprocess_terminate(&child_proc->get()); return; } this->cv_stop.wait_for(lk, std::chrono::seconds(1)); @@ -884,8 +951,8 @@ void server_models::load(const std::string & name) { // get the exit code int exit_code = 0; - subprocess_join(child_proc.get(), &exit_code); - subprocess_destroy(child_proc.get()); + subprocess_join(&child_proc->get(), &exit_code); + subprocess_destroy(&child_proc->get()); // update status and exit code this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code); @@ -896,30 +963,118 @@ void server_models::load(const std::string & name) { { auto & old_instance = mapping[name]; // old process should have exited already, but just in case, we clean it up here - if (subprocess_alive(old_instance.subproc.get())) { + if (old_instance.subproc->is_alive()) { SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str()); - subprocess_terminate(old_instance.subproc.get()); // force kill + subprocess_terminate(&old_instance.subproc->get()); // force kill } if (old_instance.th.joinable()) { old_instance.th.join(); } } + notify_sse("model_status", name, { + {"status", server_model_status_to_string(inst.meta.status)}, + }); + mapping[name] = std::move(inst); cv.notify_all(); } +// callback for model downloading functionality +struct server_models_download_res : public common_download_callback { + common_params_model model; + common_download_opts opts; + + std::function should_stop; + std::function on_progress; + + bool is_ok = false; + + bool run() { + try { + common_download_model(model, opts); + is_ok = true; + } catch (const std::exception & e) { + SRV_ERR("download failed for model name=%s: %s\n", model.name.c_str(), e.what()); + is_ok = false; + } + return is_ok; + } + void on_start(const common_download_progress & p) override { + on_progress(p); + } + void on_update(const common_download_progress & p) override { + on_progress(p); + } + void on_done(const common_download_progress &, bool ok) override { + is_ok = ok; + } + bool is_cancelled() const override { + return should_stop(); + } +}; + +void server_models::download(common_params_model && model, common_download_opts && opts) { + std::string name = model.name; + GGML_ASSERT(name == model.hf_repo); + + std::unique_lock lk(mutex); + if (mapping.find(name) != mapping.end()) { + throw std::runtime_error("model name=" + name + " already exists"); + } + + instance_t inst; + inst.meta.name = name; + inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING; + inst.subproc = std::make_shared(); + + auto dl = std::make_unique(); + dl->model = model; // copy + dl->opts = opts; // copy + + dl->should_stop = [sp = inst.subproc]() { + return sp->stop_download.load(std::memory_order_relaxed); + }; + + dl->on_progress = [this, name](const common_download_progress & p) { + update_download_progress(name, p, false); + }; + + inst.th = std::thread([this, dl = std::move(dl)]() { + dl->opts.callback = dl.get(); + bool ok = dl->run(); + SRV_INF("download finished for model name=%s with status=%s\n", + dl->model.name.c_str(), ok ? "success" : "failure"); + update_download_progress(dl->model.name, {}, true, ok); + // need_reload is set inside update_download_progress under the mutex; + // the next load_models() call will clean up this instance + }); + + mapping[name] = std::move(inst); + notify_sse("status_update", name, { + {"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)}, + }); + cv.notify_all(); +} + void server_models::unload(const std::string & name) { - std::lock_guard lk(mutex); + std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { - if (it->second.meta.is_running()) { + if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { + SRV_INF("cancelling download for model name=%s\n", name.c_str()); + it->second.subproc->stop_download.store(true, std::memory_order_relaxed); + // for convenience, we wait the status change here + wait(lk, name, [](const server_model_meta & new_meta) { + return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING; + }); + } else if (it->second.meta.is_running()) { SRV_INF("stopping model instance name=%s\n", name.c_str()); stopping_models.insert(name); if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) { // special case: if model is in loading state, unloading means force-killing it SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str()); - subprocess_terminate(it->second.subproc.get()); + subprocess_terminate(&it->second.subproc->get()); } cv_stop.notify_all(); // status change will be handled by the managing thread @@ -934,7 +1089,10 @@ void server_models::unload_all() { { std::lock_guard lk(mutex); for (auto & [name, inst] : mapping) { - if (inst.meta.is_running()) { + if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { + SRV_INF("cancelling download for model name=%s\n", name.c_str()); + inst.subproc->stop_download.store(true, std::memory_order_relaxed); + } else if (inst.meta.is_running()) { SRV_INF("stopping model instance name=%s\n", name.c_str()); stopping_models.insert(name); cv_stop.notify_all(); @@ -959,6 +1117,17 @@ void server_models::update_status(const std::string & name, server_model_status meta.status = status; meta.exit_code = exit_code; } + // broadcast status change to SSE + { + json data = { + {"status", server_model_status_to_string(status)}, + }; + if (status == SERVER_MODEL_STATUS_UNLOADED) { + data["exit_code"] = exit_code; + } + // note: notify_sse doesn't acquire the lock, so no deadlock here + notify_sse("status_change", name, data); + } cv.notify_all(); } @@ -985,12 +1154,82 @@ void server_models::update_loaded_info(const std::string & name, std::string & r cv.notify_all(); } -void server_models::wait_until_loading_finished(const std::string & name) { +void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) { + json curr; + { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + if (done) { + // mark the instance to be erased on next load_models() call + it->second.meta.status = SERVER_MODEL_STATUS_DOWNLOADED; + need_reload = true; + } else { + json & info = it->second.meta.loaded_info; + if (!info.contains("progress")) { + info["progress"] = json{}; + } + info["progress"][progress.url] = { + {"done", progress.downloaded}, + {"total", progress.total}, + }; + curr = it->second.meta.loaded_info; // copy + } + } + } + if (done) { + cv.notify_all(); // notify in case unload() is waiting for download to be cancelled + notify_sse(ok ? "download_finished" : "download_failed", name, {}); + } else { + notify_sse("download_progress", name, curr); + } +} + +bool server_models::remove(const std::string & name) { + auto meta = get_meta(name); + + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + if (meta->source != SERVER_MODEL_SOURCE_CACHE) { + throw std::runtime_error("model name=" + name + " is not removable (not from cache)"); + } + + unload(name); // cancel download or stop running instance + { + std::unique_lock lk(mutex); + // a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED + wait(lk, name, [](const server_model_meta & new_meta) { + return new_meta.status == SERVER_MODEL_STATUS_UNLOADED + || new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED; + }); + // join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no + // longer acquires this mutex, so joining while holding it is safe + if (mapping[name].th.joinable()) { + mapping[name].th.join(); + } + // remove the model from disk (hold lock to prevent concurrent load) + bool ok = common_download_remove(name); + if (ok) { + mapping.erase(name); + } + SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed"); + notify_sse("model_remove", name, {}); + return ok; + } +} + +void server_models::wait(const std::string & name, std::function predicate) { std::unique_lock lk(mutex); - cv.wait(lk, [this, &name]() { + wait(lk, name, predicate); +} + +void server_models::wait(std::unique_lock & lk, const std::string & name, std::function predicate) { + cv.wait(lk, [this, &name, &predicate]() { auto it = mapping.find(name); if (it != mapping.end()) { - return it->second.meta.status != SERVER_MODEL_STATUS_LOADING; + return predicate(it->second.meta); + } return false; }); @@ -1014,10 +1253,15 @@ bool server_models::ensure_model_ready(const std::string & name) { // wait for loading to complete SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); - wait_until_loading_finished(name); + wait(name, [&meta](const server_model_meta & new_meta) { + if (new_meta.status != SERVER_MODEL_STATUS_LOADING) { + meta = new_meta; // update meta for final check after wait + return true; + } + return false; + }); // check final status - meta = get_meta(name); if (!meta.has_value() || meta->is_failed()) { throw std::runtime_error("model name=" + name + " failed to load"); } @@ -1111,6 +1355,42 @@ void server_models::notify_router_sleeping_state(bool is_sleeping) { // server_models_routes // +// RAII wrapper similar to server_response_reader, but doesn't use server_queue +static std::atomic sse_client_id_counter = 0; +struct server_models_sse_client { + server_response & queue_results; + int client_id; + server_models_sse_client(server_response & q) + : queue_results(q), client_id(sse_client_id_counter.fetch_add(1, std::memory_order_relaxed)) { + SRV_DBG("new SSE client connected, assigned client_id=%d\n", client_id); + queue_results.add_waiting_task_id(client_id); + } + ~server_models_sse_client() { + SRV_DBG("SSE client disconnected, removing client_id=%d\n", client_id); + queue_results.remove_waiting_task_id(client_id); + } + + // return nullptr if should_stop() is true before receiving a result + // note: if one error is received, it will stop further processing and return error result + server_task_result_ptr next(const std::function & should_stop) { + while (true) { + static const int http_polling_seconds = 1; // check should_stop every 1 second + server_task_result_ptr result = queue_results.recv_with_timeout({client_id}, http_polling_seconds); + if (result == nullptr) { + // timeout, check stop condition + if (should_stop()) { + return nullptr; + } + // continue waiting otherwise + } else { + SRV_DBG("recv result for client_id=%d: %s\n", client_id, safe_json_to_str(result->to_json()).c_str()); + return result; + } + } + // should not reach here + } +}; + static void res_ok(std::unique_ptr & res, const json & response_data) { res->status = 200; res->data = safe_json_to_str(response_data); @@ -1274,7 +1554,9 @@ void server_models_routes::init_routes() { {"created", t}, // for OAI-compat {"status", status}, {"architecture", architecture}, - {"need_download", meta.need_download}, + {"source", server_model_source_to_string(meta.source)}, + {"can_remove", meta.source == SERVER_MODEL_SOURCE_CACHE}, + // {"need_download", meta.need_download}, // TODO: add other fields, may require reading GGUF metadata }; @@ -1312,6 +1594,87 @@ void server_models_routes::init_routes() { res_ok(res, {{"success", true}}); return res; }; + + this->get_router_models_sse = [this](const server_http_req & req) { + auto res = std::make_unique(); + res->status = 200; + res->content_type = "text/event-stream"; + auto sse_client = std::make_shared(models.sse); + res->next = [this, sse_client, &req](std::string & output) -> bool { + auto result = sse_client->next([&]() { + return stopping.load(std::memory_order_relaxed) || req.should_stop(); + }); + if (result == nullptr) { + return false; // client disconnected or should_stop + } + output = "data: " + safe_json_to_str(result->to_json()) + "\n\n"; + return true; // listen for the next event + }; + return res; + }; + + this->post_router_models = [this](const server_http_req & req) { + auto res = std::make_unique(); + + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + if (name.empty()) { + throw std::invalid_argument("model must be a non-empty string"); + } + + common_params_model model; + common_download_opts opts; + + model.name = name; + model.hf_repo = name; + opts.bearer_token = params.hf_token; + opts.download_mmproj = true; + opts.download_mtp = true; + + // first, only check if the model is valid and can be downloaded + opts.skip_download = true; + bool ok = false; + try { + auto validation = common_download_model(model, opts); + ok = !validation.model_path.empty(); + } catch (const common_skip_download_exception &) { + // model is valid and will be downloaded + ok = true; + } catch (...) { + SRV_ERR("unknown error while validating model '%s'\n", name.c_str()); + // other exceptions will be handled by the outer ex_wrapper() + throw; + } + + if (!ok) { + throw std::invalid_argument("model validation failed, unable to download"); + } + + // then, proceed with the actual download + opts.skip_download = false; + SRV_INF("starting download for model '%s'\n", name.c_str()); + models.download(std::move(model), std::move(opts)); + + res_ok(res, {{"success", true}}); + return res; + }; + + this->del_router_models = [this](const server_http_req & req) { + auto res = std::make_unique(); + + std::string name = req.get_param("model"); + if (name.empty()) { + throw std::invalid_argument("model must be a non-empty string"); + } + + bool ok = models.remove(name); + if (!ok) { + throw std::runtime_error("failed to remove model '" + name + "'"); + } + + res_ok(res, {{"success", true}}); + return res; + }; } diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 2198589a7aa2..319c4352e2e9 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -1,9 +1,11 @@ #pragma once #include "common.h" +#include "download.h" #include "preset.h" #include "server-common.h" #include "server-http.h" +#include "server-queue.h" #include #include @@ -14,6 +16,8 @@ /** * state diagram: * + * DOWNLOADING ──► DOWNLOADED ──► (replaced by new instance) + * * UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING * ▲ │ │ ▲ * └───failed───┘ │ │ @@ -22,39 +26,43 @@ */ enum server_model_status { // TODO: also add downloading state when the logic is added + SERVER_MODEL_STATUS_DOWNLOADING, + SERVER_MODEL_STATUS_DOWNLOADED, SERVER_MODEL_STATUS_UNLOADED, SERVER_MODEL_STATUS_LOADING, SERVER_MODEL_STATUS_LOADED, SERVER_MODEL_STATUS_SLEEPING }; -static server_model_status server_model_status_from_string(const std::string & status_str) { - if (status_str == "unloaded") { - return SERVER_MODEL_STATUS_UNLOADED; - } - if (status_str == "loading") { - return SERVER_MODEL_STATUS_LOADING; - } - if (status_str == "loaded") { - return SERVER_MODEL_STATUS_LOADED; - } - if (status_str == "sleeping") { - return SERVER_MODEL_STATUS_SLEEPING; - } - throw std::runtime_error("invalid server model status"); -} +enum server_model_source { + SERVER_MODEL_SOURCE_PRESET, + SERVER_MODEL_SOURCE_MODELS_DIR, + SERVER_MODEL_SOURCE_CACHE, +}; static std::string server_model_status_to_string(server_model_status status) { switch (status) { - case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; - case SERVER_MODEL_STATUS_LOADING: return "loading"; - case SERVER_MODEL_STATUS_LOADED: return "loaded"; - case SERVER_MODEL_STATUS_SLEEPING: return "sleeping"; - default: return "unknown"; + case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading"; + case SERVER_MODEL_STATUS_DOWNLOADED: return "downloaded"; + case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; + case SERVER_MODEL_STATUS_LOADING: return "loading"; + case SERVER_MODEL_STATUS_LOADED: return "loaded"; + case SERVER_MODEL_STATUS_SLEEPING: return "sleeping"; + default: return "unknown"; + } +} + +static std::string server_model_source_to_string(server_model_source source) { + switch (source) { + case SERVER_MODEL_SOURCE_PRESET: return "preset"; + case SERVER_MODEL_SOURCE_MODELS_DIR: return "models_dir"; + case SERVER_MODEL_SOURCE_CACHE: return "cache"; + default: return "unknown"; } } struct server_model_meta { + server_model_source source = SERVER_MODEL_SOURCE_CACHE; common_preset preset; std::string name; std::set aliases; // additional names that resolve to this model @@ -63,11 +71,11 @@ struct server_model_meta { server_model_status status = SERVER_MODEL_STATUS_UNLOADED; int64_t last_used = 0; // for LRU unloading std::vector args; // args passed to the model instance, will be populated by render_args() - json loaded_info; // info to be reflected via /v1/models endpoint + json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown mtmd_caps multimodal; // multimodal capabilities - bool need_download = false; // whether the model needs to be downloaded before loading + // bool need_download = false; // whether the model needs to be downloaded before loading // TODO @ngxson: implement this bool is_ready() const { return status == SERVER_MODEL_STATUS_LOADED; @@ -85,12 +93,15 @@ struct server_model_meta { void update_caps(); }; -struct subprocess_s; +struct server_models_routes; +struct server_subproc; // defined in server-models.cpp struct server_models { + friend struct server_models_routes; + private: struct instance_t { - std::shared_ptr subproc; // shared between main thread and monitoring thread + std::shared_ptr subproc; // shared between main thread and monitoring thread std::thread th; server_model_meta meta; FILE * stdin_file = nullptr; @@ -107,6 +118,9 @@ struct server_models { // set to true while load_models() is executing a reload; load() will wait until clear bool is_reloading = false; + // if true, the next get_meta() will trigger a reload of model list + bool need_reload = false; + common_preset_context ctx_preset; common_params base_params; @@ -122,9 +136,14 @@ struct server_models { // not thread-safe, caller must hold mutex void add_model(server_model_meta && meta); + // notify SSE clients + void notify_sse(const std::string & event, const std::string & model_id, const json & data = nullptr); + public: server_models(const common_params & params, int argc, char ** argv); + server_response sse; // for real-time updates via SSE endpoint + // (re-)load the list of models from various sources and prepare the metadata mapping // - if this is called the first time, simply populate the metadata // - if this is called subsequently (e.g. when refreshing from disk): @@ -147,13 +166,24 @@ struct server_models { void unload(const std::string & name); void unload_all(); + // download a new model, progress is reported via SSE + // to stop the download, call unload() + void download(common_params_model && model, common_download_opts && opts); + // update the status of a model instance (thread-safe) void update_status(const std::string & name, server_model_status status, int exit_code); void update_loaded_info(const std::string & name, std::string & raw_info); + void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true); + + // remove a cache model from disk and update the list (thread-safe) + // note: only cache models can be removed; returns false if the model doesn't exist or is not a cache model + bool remove(const std::string & name); // wait until the model instance is fully loaded (thread-safe) + // note: predicate is called while holding the lock // return when the model no longer in "loading" state - void wait_until_loading_finished(const std::string & name); + void wait(const std::string & name, std::function predicate); + void wait(std::unique_lock & lk, const std::string & name, std::function predicate); // ensure the model is in ready state (thread-safe) // return false if model is ready @@ -176,8 +206,9 @@ struct server_models { struct server_models_routes { common_params params; - json ui_settings = json::object(); // Primary: new name - json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat) + json ui_settings = json::object(); // Primary: new name + json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat) + std::atomic stopping = false; // for graceful disconnecting SSE clients during shutdown server_models models; server_models_routes(const common_params & params, int argc, char ** argv) : params(params), models(params, argc, argv) { @@ -206,6 +237,10 @@ struct server_models_routes { server_http_context::handler_t get_router_models; server_http_context::handler_t post_router_models_load; server_http_context::handler_t post_router_models_unload; + // management API + server_http_context::handler_t get_router_models_sse; + server_http_context::handler_t post_router_models; + server_http_context::handler_t del_router_models; }; /** diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 32cfe7830c35..5d37c34536e3 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -331,6 +331,17 @@ void server_response::send(server_task_result_ptr && result) { } } +void server_response::broadcast(server_task_result_ptr && result) { + std::unique_lock lock(mutex_results); + for (const auto & id_task : waiting_task_ids) { + RES_DBG("task id = %d pushed to result queue\n", id_task); + server_task_result_ptr res_copy(result->clone()); + res_copy->id = id_task; // override id with target task id + queue_results.emplace_back(std::move(res_copy)); + } + condition_results.notify_all(); +} + void server_response::terminate() { running = false; condition_results.notify_all(); diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 35f010401fcb..0b674d6ff0f9 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -154,11 +154,15 @@ struct server_response { // Send a new result to a waiting id_task void send(server_task_result_ptr && result); + // broadcast a new result to all waiting tasks + // (used by router mode) + void broadcast(server_task_result_ptr && result); + // terminate the waiting loop void terminate(); }; -// utility class to make working with server_queue and server_response easier +// RAII wrapper to make working with server_queue and server_response easier // it provides a generator-like API for server responses // support pooling connection state and aggregating multiple results struct server_response_reader { diff --git a/tools/server/server-task.h b/tools/server/server-task.h index bdadcff76527..1a03d5f26604 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -312,6 +312,9 @@ struct server_task_result { } virtual json to_json() = 0; virtual ~server_task_result() = default; + virtual server_task_result * clone() const { + GGML_ABORT("not implemented for this task type"); + } }; // using shared_ptr for polymorphism of server_task_result @@ -649,3 +652,12 @@ struct server_prompt_cache { void update(); }; + +// used exclusively by router mode +struct server_task_result_router : server_task_result { + json data; + virtual json to_json() override { return data; } + virtual server_task_result * clone() const override { + return new server_task_result_router(*this); + } +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index da635c625668..0364d7d3b7ca 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -174,8 +174,11 @@ int llama_server(int argc, char ** argv) { routes.get_props = models_routes->get_router_props; routes.get_models = models_routes->get_router_models; + ctx_http.post("/models", ex_wrapper(models_routes->post_router_models)); ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load)); ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload)); + ctx_http.get ("/models/sse", ex_wrapper(models_routes->get_router_models_sse)); + ctx_http.del ("/models", ex_wrapper(models_routes->del_router_models)); } ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) @@ -261,6 +264,7 @@ int llama_server(int argc, char ** argv) { clean_up = [&models_routes]() { SRV_INF("%s: cleaning up before exit...\n", __func__); if (models_routes.has_value()) { + models_routes->stopping.store(true); // maybe redundant, but just to be safe models_routes->models.unload_all(); } llama_backend_free(); @@ -274,6 +278,10 @@ int llama_server(int argc, char ** argv) { ctx_http.is_ready.store(true); shutdown_handler = [&](int) { + if (models_routes.has_value()) { + // important to disconnect any SSE clients + models_routes->stopping.store(true); + } ctx_http.stop(); }; diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index c93b92b0b2ee..11c77ca7aa15 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -1,3 +1,4 @@ +import threading import pytest from utils import * @@ -253,3 +254,98 @@ def test_router_reload_models(): assert "model-reload-c" in ids, "newly added model should appear" finally: os.remove(preset_path) + + +MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16" +MODEL_DOWNLOAD_TIMEOUT = 300 + + +def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event): + """Collect /models/sse events into `collected` until `stop` is set.""" + url = f"http://{server.server_host}:{server.server_port}/models/sse" + try: + with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp: + for line_bytes in resp.iter_lines(): + if stop.is_set(): + break + line = line_bytes.decode("utf-8") + if line.startswith("data: "): + collected.append(json.loads(line[6:])) + except Exception: + pass + + +def _wait_for_sse_event(collected: list, event_type: str, model: str, timeout: int) -> bool: + deadline = time.time() + timeout + while time.time() < deadline: + if any(e.get("event") == event_type and e.get("model") == model for e in collected): + return True + time.sleep(0.5) + return False + + +def test_router_download_model(): + """Case 1: download a model, verify SSE events and GET /models.""" + global server + server.start() + + # Ensure the model is not present before we start + server.make_request("DELETE", f"/models?model={MODEL_DOWNLOAD_ID}") + + sse_events: list = [] + stop = threading.Event() + sse_thread = threading.Thread( + target=_listen_sse, args=(server, sse_events, stop), daemon=True + ) + sse_thread.start() + + # Trigger the download + res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) + assert res.status_code == 200 + assert res.body.get("success") is True + + # Wait for download_finished SSE event + finished = _wait_for_sse_event( + sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT + ) + stop.set() + + assert finished, "Never received download_finished SSE event" + assert any( + e.get("event") == "download_progress" and e.get("model") == MODEL_DOWNLOAD_ID + for e in sse_events + ), "No download_progress events received" + + # Model should now appear in GET /models + ids = _get_model_ids(is_reload=False) + assert MODEL_DOWNLOAD_ID in ids, f"{MODEL_DOWNLOAD_ID} not found in /models after download" + + +def test_router_delete_model(): + """Case 2: delete the downloaded model, verify it disappears from GET /models.""" + global server + server.start() + + # Ensure the model exists (download it if needed) + if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False): + res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) + assert res.status_code == 200 + sse_events: list = [] + stop = threading.Event() + threading.Thread( + target=_listen_sse, args=(server, sse_events, stop), daemon=True + ).start() + finished = _wait_for_sse_event( + sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT + ) + stop.set() + assert finished, "Model did not finish downloading before delete test" + + # Delete the model + del_res = server.make_request("DELETE", f"/models?model={MODEL_DOWNLOAD_ID}") + assert del_res.status_code == 200 + assert del_res.body.get("success") is True + + # Model should no longer appear in GET /models + ids = _get_model_ids(is_reload=False) + assert MODEL_DOWNLOAD_ID not in ids, f"{MODEL_DOWNLOAD_ID} still present after deletion" diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index c5dba1c139fc..c50c9a0f5a71 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -340,6 +340,9 @@ def make_request( elif method == "POST": response = requests.post(url, headers=headers, json=data, timeout=timeout) parse_body = True + elif method == "DELETE": + response = requests.delete(url, headers=headers, timeout=timeout) + parse_body = True elif method == "OPTIONS": response = requests.options(url, headers=headers, timeout=timeout) else: