diff --git a/src/convert.cpp b/src/convert.cpp index 7cae8df0f..3996cf6df 100644 --- a/src/convert.cpp +++ b/src/convert.cpp @@ -99,7 +99,7 @@ bool convert(const char* input_path, model_loader.convert_tensors_name(); } - ggml_type type = (ggml_type)output_type; + ggml_type type = sd_type_to_ggml(output_type); bool output_is_safetensors = ends_with(output_path, ".safetensors"); TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 962f46a23..f428d161e 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -359,9 +359,7 @@ class StableDiffusionGGML { auto& tensor_storage_map = model_loader.get_tensor_storage_map(); LOG_INFO("Version: %s ", model_version_to_str[version]); - ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) - ? (ggml_type)sd_ctx_params->wtype - : GGML_TYPE_COUNT; + ggml_type wtype = sd_type_to_ggml(sd_ctx_params->wtype); std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules); if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) { model_loader.set_wtype_override(wtype, tensor_type_rules); diff --git a/src/util.cpp b/src/util.cpp index 1c2e5e899..271cf7ed7 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -406,6 +406,15 @@ std::vector split_string(const std::string& str, char delimiter) { return result; } +ggml_type sd_type_to_ggml(sd_type_t sdtype) { + const int type_value = static_cast(sdtype); + if (type_value < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT)) { + return static_cast(type_value); + } else { + return GGML_TYPE_COUNT; + } +} + static std::string build_progress_bar(int step, int steps) { std::string progress = " |"; int max_progress = 50; diff --git a/src/util.h b/src/util.h index 9843ae18f..3ceabfe20 100644 --- a/src/util.h +++ b/src/util.h @@ -70,6 +70,8 @@ void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...); +ggml_type sd_type_to_ggml(sd_type_t sdtype); + std::string trim(const std::string& s); std::vector> parse_prompt_attention(const std::string& text);