diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 88e726c60cde..9ceb9d4f9a09 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -2877,6 +2877,49 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + # sherpa-onnx CPU + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64,linux/arm64' + tag-latest: 'auto' + tag-suffix: '-cpu-sherpa-onnx' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "sherpa-onnx" + dockerfile: "./backend/Dockerfile.golang" + context: "./" + ubuntu-version: '2404' + # sherpa-onnx CUDA 12 + - build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "8" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-12-sherpa-onnx' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "sherpa-onnx" + dockerfile: "./backend/Dockerfile.golang" + context: "./" + ubuntu-version: '2404' + # sherpa-onnx CUDA 13 — requires onnxruntime 1.24.x+ for the + # gpu_cuda13 tarball; sherpa-onnx SHERPA_COMMIT pins to v1.12.39. 
+ - build-type: 'cublas' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-13-sherpa-onnx' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "sherpa-onnx" + dockerfile: "./backend/Dockerfile.golang" + context: "./" + ubuntu-version: '2404' backend-jobs-darwin: uses: ./.github/workflows/backend_build_darwin.yml strategy: diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml index 4c2a52fb87e1..67ab1693869b 100644 --- a/.github/workflows/test-extra.yml +++ b/.github/workflows/test-extra.yml @@ -40,6 +40,7 @@ jobs: kokoros: ${{ steps.detect.outputs.kokoros }} insightface: ${{ steps.detect.outputs.insightface }} speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }} + sherpa-onnx: ${{ steps.detect.outputs.sherpa-onnx }} steps: - name: Checkout repository uses: actions/checkout@v6 @@ -506,6 +507,72 @@ jobs: - name: Build llama-cpp backend image and run audio transcription gRPC e2e tests run: | make test-extra-backend-llama-cpp-transcription + # Realtime e2e with sherpa-onnx driving VAD + STT + TTS against a mocked LLM. + # Builds the sherpa-onnx Docker image, extracts the rootfs so the e2e suite + # can discover the backend binary + shared libs, downloads the three model + # bundles (silero-vad, omnilingual-asr, vits-ljs) and drives the realtime + # websocket spec end-to-end. 
+ tests-sherpa-onnx-realtime: + needs: detect-changes + if: needs.detect-changes.outputs.sherpa-onnx == 'true' || needs.detect-changes.outputs.run-all == 'true' + runs-on: ubuntu-latest + timeout-minutes: 90 + steps: + - name: Clone + uses: actions/checkout@v6 + with: + submodules: true + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.25.4' + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '22' + - name: Build sherpa-onnx backend image and run realtime e2e tests + run: | + make test-extra-e2e-realtime-sherpa + # Streaming ASR via the sherpa-onnx online recognizer (zipformer + # transducer). Exercises both AudioTranscription (buffered) and + # AudioTranscriptionStream (real-time deltas) on the e2e-backends + # harness. + tests-sherpa-onnx-grpc-transcription: + needs: detect-changes + if: needs.detect-changes.outputs.sherpa-onnx == 'true' || needs.detect-changes.outputs.run-all == 'true' + runs-on: ubuntu-latest + timeout-minutes: 90 + steps: + - name: Clone + uses: actions/checkout@v6 + with: + submodules: true + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.25.4' + - name: Build sherpa-onnx backend image and run streaming ASR gRPC e2e tests + run: | + make test-extra-backend-sherpa-onnx-transcription + # VITS TTS via the sherpa-onnx backend. Drives both TTS (file write) and + # TTSStream (PCM chunks) on the e2e-backends harness. 
+ tests-sherpa-onnx-grpc-tts: + needs: detect-changes + if: needs.detect-changes.outputs.sherpa-onnx == 'true' || needs.detect-changes.outputs.run-all == 'true' + runs-on: ubuntu-latest + timeout-minutes: 90 + steps: + - name: Clone + uses: actions/checkout@v6 + with: + submodules: true + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.25.4' + - name: Build sherpa-onnx backend image and run TTS gRPC e2e tests + run: | + make test-extra-backend-sherpa-onnx-tts tests-ik-llama-cpp-grpc: needs: detect-changes if: needs.detect-changes.outputs.ik-llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true' diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 51bae1cb1854..2885f1ce319d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -195,7 +195,7 @@ jobs: run: go version - name: Dependencies run: | - brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm opus + brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm opus ffmpeg pip install --user --no-cache-dir grpcio-tools grpcio - name: Setup Node.js uses: actions/setup-node@v6 diff --git a/Makefile b/Makefile index 1d93e61a51eb..4d64068f03fc 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine 
backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/turboquant backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/insightface backends/speaker-recognition backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/sglang backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization backends/kokoros backends/sam3-cpp backends/qwen3-tts-cpp backends/tinygrad backends/sherpa-onnx GOCMD=go GOTEST=$(GOCMD) test @@ -750,6 +750,44 @@ test-extra-backend-speaker-recognition-ecapa: docker-build-speaker-recognition test-extra-backend-speaker-recognition-all: \ test-extra-backend-speaker-recognition-ecapa +## Realtime e2e with sherpa-onnx driving VAD + STT + TTS against a mocked +## LLM. Extracts the sherpa-onnx Docker image rootfs, downloads the three +## gallery-referenced model bundles (silero-vad, omnilingual-asr, vits-ljs), +## writes the corresponding model config YAMLs, and runs the realtime +## websocket spec in tests/e2e with REALTIME_* env vars wiring the sherpa +## slots into the pipeline. 
The LLM slot stays on the in-repo mock-backend +## registered unconditionally by tests/e2e/e2e_suite_test.go. See +## tests/e2e/run-realtime-sherpa.sh for the full orchestration. +test-extra-e2e-realtime-sherpa: build-mock-backend docker-build-sherpa-onnx protogen-go react-ui + bash tests/e2e/run-realtime-sherpa.sh + +## Streaming ASR via the sherpa-onnx online recognizer. Uses the streaming +## zipformer English model (encoder/decoder/joiner int8 + tokens) from the +## sherpa-onnx gallery entry. Drives both AudioTranscription and +## AudioTranscriptionStream via the e2e-backends gRPC harness; streaming +## emits real partial deltas during decode. Each file is renamed on download +## to the shape sherpa-onnx's online loader expects (encoder.int8.onnx etc.). +test-extra-backend-sherpa-onnx-transcription: docker-build-sherpa-onnx + BACKEND_IMAGE=local-ai-backend:sherpa-onnx \ + BACKEND_TEST_MODEL_URL='https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/encoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx#encoder.int8.onnx' \ + BACKEND_TEST_EXTRA_FILES='https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/decoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx#decoder.int8.onnx|https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/joiner-epoch-99-avg-1-chunk-16-left-128.int8.onnx#joiner.int8.onnx|https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/tokens.txt' \ + BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \ + BACKEND_TEST_CAPS=health,load,transcription \ + BACKEND_TEST_OPTIONS=subtype=online \ + $(MAKE) test-extra-backend + +## VITS TTS via the sherpa-onnx backend. Pulls the individual files from +## HuggingFace (the vits-ljs release tarball lives on the k2-fsa github +## but is also mirrored as discrete files on HF). 
Exercises both +## TTS (write-to-file) and TTSStream (PCM chunks + WAV header) via the +## e2e-backends gRPC harness. +test-extra-backend-sherpa-onnx-tts: docker-build-sherpa-onnx + BACKEND_IMAGE=local-ai-backend:sherpa-onnx \ + BACKEND_TEST_MODEL_URL='https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx#vits-ljs.onnx' \ + BACKEND_TEST_EXTRA_FILES='https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt|https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt' \ + BACKEND_TEST_CAPS=health,load,tts \ + $(MAKE) test-extra-backend + ## sglang mirrors the vllm setup: HuggingFace model id, same tiny Qwen, ## tool-call extraction via sglang's native qwen parser. CPU builds use ## sglang's upstream pyproject_cpu.toml recipe (see backend/python/sglang/install.sh). @@ -887,6 +925,7 @@ BACKEND_VOXTRAL = voxtral|golang|.|false|true BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true BACKEND_QWEN3_TTS_CPP = qwen3-tts-cpp|golang|.|false|true BACKEND_OPUS = opus|golang|.|false|true +BACKEND_SHERPA_ONNX = sherpa-onnx|golang|.|false|true # Python backends with root context BACKEND_RERANKERS = rerankers|python|.|false|true @@ -999,12 +1038,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION))) $(eval $(call generate-docker-build-target,$(BACKEND_TINYGRAD))) $(eval $(call generate-docker-build-target,$(BACKEND_KOKOROS))) $(eval $(call generate-docker-build-target,$(BACKEND_SAM3_CPP))) +$(eval $(call generate-docker-build-target,$(BACKEND_SHERPA_ONNX))) # Pattern rule for docker-save targets docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox 
docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-insightface docker-build-speaker-recognition +docker-build-backends: docker-build-llama-cpp docker-build-ik-llama-cpp docker-build-turboquant docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-sglang docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization docker-build-tinygrad docker-build-kokoros docker-build-sam3-cpp docker-build-qwen3-tts-cpp docker-build-insightface docker-build-speaker-recognition docker-build-sherpa-onnx ######################################################## ### Mock Backend for E2E Tests diff --git a/backend/go/sherpa-onnx/.gitignore b/backend/go/sherpa-onnx/.gitignore new file mode 100644 index 000000000000..f2afac64c04a --- /dev/null +++ b/backend/go/sherpa-onnx/.gitignore @@ -0,0 +1,11 @@ +.cache/ +sources/ +build*/ +package/ +backend-assets/ +sherpa-onnx +*.so +compile_commands.json +sherpa-onnx-whisper-* +vits-ljs/ +streaming-zipformer-en/ diff --git a/backend/go/sherpa-onnx/Makefile b/backend/go/sherpa-onnx/Makefile new file mode 100644 index 
000000000000..aa4c6f748f8d --- /dev/null +++ b/backend/go/sherpa-onnx/Makefile @@ -0,0 +1,120 @@ +CURRENT_DIR=$(abspath ./) +GOCMD=go + +ONNX_VERSION?=1.24.4 +# v1.12.39 — includes upstream's onnxruntime 1.24.4 bump (#3501). Earlier +# pinned commits only support onnxruntime 1.23.2, which has no CUDA 13 +# pre-built tarball, blocking the -gpu-nvidia-cuda-13 build matrix entry. +SHERPA_COMMIT?=7288d15e3e31a7bd589b2ba88828d521e7a6b140 +ONNX_ARCH?=x64 +ONNX_OS?=linux + +ifneq (,$(findstring aarch64,$(shell uname -m))) + ONNX_ARCH=aarch64 +endif + +ifeq ($(OS),Darwin) + ONNX_OS=osx + ifneq (,$(findstring aarch64,$(shell uname -m))) + ONNX_ARCH=arm64 + else ifneq (,$(findstring arm64,$(shell uname -m))) + ONNX_ARCH=arm64 + else + ONNX_ARCH=x86_64 + endif +endif + +# Upstream onnxruntime ships CUDA 12 and CUDA 13 variants under different +# names: -gpu-.tgz for CUDA 12, -gpu_cuda13-.tgz for CUDA 13 +# (note underscore vs dash). CUDA 13 tarballs only exist from 1.24.x onward. +ifeq ($(BUILD_TYPE),cublas) + SHERPA_GPU=ON + ONNX_PROVIDER=cuda + ifeq ($(CUDA_MAJOR_VERSION),13) + ONNX_VARIANT=-gpu_cuda13 + else + ONNX_VARIANT=-gpu + endif +else + ONNX_VARIANT= + SHERPA_GPU=OFF + ONNX_PROVIDER=cpu +endif + +JOBS?=$(shell nproc --ignore=1 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4) + +sources/onnxruntime: + mkdir -p sources/onnxruntime + curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)$(ONNX_VARIANT)-$(ONNX_VERSION).tgz \ + -o sources/onnxruntime/onnxruntime.tgz + cd sources/onnxruntime && tar -xf onnxruntime.tgz --strip-components=1 && rm onnxruntime.tgz + +sources/sherpa-onnx: sources/onnxruntime + git clone https://github.com/k2-fsa/sherpa-onnx.git sources/sherpa-onnx + cd sources/sherpa-onnx && git checkout $(SHERPA_COMMIT) + mkdir -p sources/sherpa-onnx/build + # sherpa-onnx's cmake detects a pre-installed onnxruntime via the + # SHERPA_ONNXRUNTIME_{INCLUDE,LIB}_DIR env vars (not via -D 
flags). + # Point them at our locally-downloaded Microsoft tarball — without + # this, sherpa-onnx falls through to download_onnxruntime() which + # fetches from csukuangfj/onnxruntime-libs. For the GPU 1.24.4 + # build that release mirror publishes `-patched.zip` instead of the + # expected `.tgz`, so the download 404s and the build fails. + cd sources/sherpa-onnx/build && \ + SHERPA_ONNXRUNTIME_INCLUDE_DIR=$(CURRENT_DIR)/sources/onnxruntime/include \ + SHERPA_ONNXRUNTIME_LIB_DIR=$(CURRENT_DIR)/sources/onnxruntime/lib \ + cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_FLAGS="-Wno-error=format-security" \ + -DCMAKE_CXX_FLAGS="-Wno-error=format-security" \ + -DSHERPA_ONNX_ENABLE_GPU=$(SHERPA_GPU) \ + -DSHERPA_ONNX_ENABLE_TTS=ON \ + -DSHERPA_ONNX_ENABLE_BINARY=OFF \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_C_API=ON \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE=ON \ + .. + cd sources/sherpa-onnx/build && make -j$(JOBS) + +backend-assets/lib: sources/sherpa-onnx sources/onnxruntime + mkdir -p backend-assets/lib + cp -rfLv sources/onnxruntime/lib/* backend-assets/lib/ + cp -rfLv sources/sherpa-onnx/build/lib/*.so* backend-assets/lib/ 2>/dev/null || true + cp -rfLv sources/sherpa-onnx/build/lib/*.dylib backend-assets/lib/ 2>/dev/null || true + +# libsherpa-shim wraps sherpa-onnx's nested config structs and TTS +# callback plumbing behind a purego-friendly API: opaque handles plus +# fixed-signature setters/getters/trampoline. Plain C compile — no cgo. 
+SHIM_EXT=so +ifeq ($(OS),Darwin) + SHIM_EXT=dylib +endif + +backend-assets/lib/libsherpa-shim.$(SHIM_EXT): csrc/shim.c csrc/shim.h backend-assets/lib + $(CC) -shared -fPIC -O2 \ + -I$(CURRENT_DIR)/sources/sherpa-onnx/sherpa-onnx/c-api \ + -o $@ csrc/shim.c \ + -L$(CURRENT_DIR)/backend-assets/lib \ + -lsherpa-onnx-c-api \ + -Wl,-rpath,'$$ORIGIN' + +sherpa-onnx: backend-assets/lib backend-assets/lib/libsherpa-shim.$(SHIM_EXT) + CGO_ENABLED=0 $(GOCMD) build \ + -ldflags "$(LD_FLAGS) -X main.onnxProvider=$(ONNX_PROVIDER)" \ + -tags "$(GO_TAGS)" -o sherpa-onnx ./ + +package: + bash package.sh + +build: sherpa-onnx package + +clean: + rm -rf sherpa-onnx sources/ backend-assets/ package/ vits-ljs/ sherpa-onnx-whisper-*/ + +test: sherpa-onnx + LD_LIBRARY_PATH=$(CURRENT_DIR)/backend-assets/lib \ + bash test.sh + +.PHONY: build package clean test diff --git a/backend/go/sherpa-onnx/backend.go b/backend/go/sherpa-onnx/backend.go new file mode 100644 index 000000000000..5d858357f73b --- /dev/null +++ b/backend/go/sherpa-onnx/backend.go @@ -0,0 +1,1249 @@ +package main + +import ( + "bytes" + "encoding/binary" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "unsafe" + + "github.com/ebitengine/purego" + laudio "github.com/mudler/LocalAI/pkg/audio" + "github.com/mudler/LocalAI/pkg/grpc/base" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/utils" +) + +type SherpaBackend struct { + base.SingleThread + tts uintptr + recognizer uintptr + onlineRecognizer uintptr + vad uintptr + vadSampleRate int + vadWindowSize int + ttsSpeed float32 + onlineChunkSamples int +} + +var onnxProvider = "cpu" + +// ============================================================= +// purego bindings +// ============================================================= + +// libsherpa-shim — config builders, setters, result accessors, +// create wrappers, TTS callback trampoline. 
+var ( + // VAD config + shimVadConfigNew func() uintptr + shimVadConfigFree func(uintptr) + shimVadConfigSetSileroModel func(uintptr, string) + shimVadConfigSetSileroThreshold func(uintptr, float32) + shimVadConfigSetSileroMinSilenceDuration func(uintptr, float32) + shimVadConfigSetSileroMinSpeechDuration func(uintptr, float32) + shimVadConfigSetSileroWindowSize func(uintptr, int32) + shimVadConfigSetSileroMaxSpeechDuration func(uintptr, float32) + shimVadConfigSetSampleRate func(uintptr, int32) + shimVadConfigSetNumThreads func(uintptr, int32) + shimVadConfigSetProvider func(uintptr, string) + shimVadConfigSetDebug func(uintptr, int32) + shimCreateVad func(uintptr, float32) uintptr + + // TTS (offline, VITS) config + shimTtsConfigNew func() uintptr + shimTtsConfigFree func(uintptr) + shimTtsConfigSetVitsModel func(uintptr, string) + shimTtsConfigSetVitsTokens func(uintptr, string) + shimTtsConfigSetVitsLexicon func(uintptr, string) + shimTtsConfigSetVitsDataDir func(uintptr, string) + shimTtsConfigSetVitsNoiseScale func(uintptr, float32) + shimTtsConfigSetVitsNoiseScaleW func(uintptr, float32) + shimTtsConfigSetVitsLengthScale func(uintptr, float32) + shimTtsConfigSetNumThreads func(uintptr, int32) + shimTtsConfigSetDebug func(uintptr, int32) + shimTtsConfigSetProvider func(uintptr, string) + shimTtsConfigSetMaxNumSentences func(uintptr, int32) + shimCreateOfflineTts func(uintptr) uintptr + + // Offline recognizer config + shimOfflineRecogConfigNew func() uintptr + shimOfflineRecogConfigFree func(uintptr) + shimOfflineRecogConfigSetNumThreads func(uintptr, int32) + shimOfflineRecogConfigSetDebug func(uintptr, int32) + shimOfflineRecogConfigSetProvider func(uintptr, string) + shimOfflineRecogConfigSetTokens func(uintptr, string) + shimOfflineRecogConfigSetFeatSampleRate func(uintptr, int32) + shimOfflineRecogConfigSetFeatFeatureDim func(uintptr, int32) + shimOfflineRecogConfigSetDecodingMethod func(uintptr, string) + shimOfflineRecogConfigSetWhisperEncoder 
func(uintptr, string) + shimOfflineRecogConfigSetWhisperDecoder func(uintptr, string) + shimOfflineRecogConfigSetWhisperLanguage func(uintptr, string) + shimOfflineRecogConfigSetWhisperTask func(uintptr, string) + shimOfflineRecogConfigSetWhisperTailPaddings func(uintptr, int32) + shimOfflineRecogConfigSetParaformerModel func(uintptr, string) + shimOfflineRecogConfigSetSenseVoiceModel func(uintptr, string) + shimOfflineRecogConfigSetSenseVoiceLanguage func(uintptr, string) + shimOfflineRecogConfigSetSenseVoiceUseITN func(uintptr, int32) + shimOfflineRecogConfigSetOmnilingualModel func(uintptr, string) + shimCreateOfflineRecognizer func(uintptr) uintptr + + // Online recognizer config + shimOnlineRecogConfigNew func() uintptr + shimOnlineRecogConfigFree func(uintptr) + shimOnlineRecogConfigSetTransducerEncoder func(uintptr, string) + shimOnlineRecogConfigSetTransducerDecoder func(uintptr, string) + shimOnlineRecogConfigSetTransducerJoiner func(uintptr, string) + shimOnlineRecogConfigSetTokens func(uintptr, string) + shimOnlineRecogConfigSetNumThreads func(uintptr, int32) + shimOnlineRecogConfigSetDebug func(uintptr, int32) + shimOnlineRecogConfigSetProvider func(uintptr, string) + shimOnlineRecogConfigSetFeatSampleRate func(uintptr, int32) + shimOnlineRecogConfigSetFeatFeatureDim func(uintptr, int32) + shimOnlineRecogConfigSetDecodingMethod func(uintptr, string) + shimOnlineRecogConfigSetEnableEndpoint func(uintptr, int32) + shimOnlineRecogConfigSetRule1MinTrailingSilence func(uintptr, float32) + shimOnlineRecogConfigSetRule2MinTrailingSilence func(uintptr, float32) + shimOnlineRecogConfigSetRule3MinUtteranceLength func(uintptr, float32) + shimCreateOnlineRecognizer func(uintptr) uintptr + + // Result accessors. Pointer returns use unsafe.Pointer so Go's + // vet checker doesn't flag them — the returned memory is C-owned, + // not subject to Go GC motion. 
+ shimWaveSampleRate func(uintptr) int32 + shimWaveNumSamples func(uintptr) int32 + shimWaveSamples func(uintptr) unsafe.Pointer + shimOfflineResultText func(uintptr) unsafe.Pointer + shimOnlineResultText func(uintptr) unsafe.Pointer + shimGeneratedAudioSampleRate func(uintptr) int32 + shimGeneratedAudioN func(uintptr) int32 + shimGeneratedAudioSamples func(uintptr) unsafe.Pointer + shimSpeechSegmentStart func(uintptr) int32 + shimSpeechSegmentN func(uintptr) int32 + + // TTS streaming callback trampoline + shimTtsGenerateWithCallback func(tts uintptr, text string, sid int32, speed float32, cb uintptr, ud uintptr) uintptr +) + +// libsherpa-onnx-c-api pass-throughs — called directly from Go via purego. +// Sample-pointer args (`samples unsafe.Pointer`) accept either a raw C +// pointer returned by the shim or `unsafe.Pointer(&slice[0])` from Go. +var ( + // VAD + sherpaVadAcceptWaveform func(vad uintptr, samples unsafe.Pointer, n int32) + sherpaVadReset func(vad uintptr) + sherpaVadFlush func(vad uintptr) + sherpaVadEmpty func(vad uintptr) int32 + sherpaVadFront func(vad uintptr) uintptr + sherpaVadPop func(vad uintptr) + sherpaDestroySpeechSegment func(seg uintptr) + + // Wave IO + sherpaReadWave func(filename string) uintptr + sherpaFreeWave func(wave uintptr) + sherpaWriteWave func(samples unsafe.Pointer, n int32, sampleRate int32, filename string) int32 + + // Offline ASR + sherpaCreateOfflineStream func(rec uintptr) uintptr + sherpaDestroyOfflineStream func(stream uintptr) + sherpaAcceptWaveformOffline func(stream uintptr, sr int32, samples unsafe.Pointer, n int32) + sherpaDecodeOfflineStream func(rec uintptr, stream uintptr) + sherpaGetOfflineStreamResult func(stream uintptr) uintptr + sherpaDestroyOfflineRecognizerResult func(result uintptr) + + // Online ASR + sherpaCreateOnlineStream func(rec uintptr) uintptr + sherpaDestroyOnlineStream func(stream uintptr) + sherpaOnlineStreamAcceptWaveform func(stream uintptr, sr int32, samples unsafe.Pointer, n int32) + 
sherpaIsOnlineStreamReady func(rec uintptr, stream uintptr) int32 + sherpaDecodeOnlineStream func(rec uintptr, stream uintptr) + sherpaGetOnlineStreamResult func(rec uintptr, stream uintptr) uintptr + sherpaDestroyOnlineRecognizerResult func(result uintptr) + sherpaOnlineStreamIsEndpoint func(rec uintptr, stream uintptr) int32 + sherpaOnlineStreamReset func(rec uintptr, stream uintptr) + sherpaOnlineStreamInputFinished func(stream uintptr) + + // TTS + sherpaOfflineTtsGenerate func(tts uintptr, text string, sid int32, speed float32) uintptr + sherpaDestroyOfflineTtsGeneratedAudio func(audio uintptr) + sherpaOfflineTtsSampleRate func(tts uintptr) int32 +) + +var ( + loadLibsOnce sync.Once + loadLibsErr error +) + +// loadSherpaLibs dlopens libsherpa-shim and libsherpa-onnx-c-api (and any +// deps via RTLD_GLOBAL) and registers every function pointer above. +// Idempotent — safe to call from both main and test TestMain. +func loadSherpaLibs() error { + loadLibsOnce.Do(func() { + loadLibsErr = loadSherpaLibsOnce() + }) + return loadLibsErr +} + +func loadSherpaLibsOnce() error { + shimLib := os.Getenv("SHERPA_SHIM_LIBRARY") + if shimLib == "" { + shimLib = "libsherpa-shim.so" + } + capiLib := os.Getenv("SHERPA_ONNX_LIBRARY") + if capiLib == "" { + capiLib = "libsherpa-onnx-c-api.so" + } + + shim, err := purego.Dlopen(shimLib, purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + return fmt.Errorf("dlopen %s: %w", shimLib, err) + } + capi, err := purego.Dlopen(capiLib, purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + return fmt.Errorf("dlopen %s: %w", capiLib, err) + } + + // --- shim registrations --- + for _, r := range []struct { + ptr any + name string + }{ + {&shimVadConfigNew, "sherpa_shim_vad_config_new"}, + {&shimVadConfigFree, "sherpa_shim_vad_config_free"}, + {&shimVadConfigSetSileroModel, "sherpa_shim_vad_config_set_silero_model"}, + {&shimVadConfigSetSileroThreshold, "sherpa_shim_vad_config_set_silero_threshold"}, + 
{&shimVadConfigSetSileroMinSilenceDuration, "sherpa_shim_vad_config_set_silero_min_silence_duration"}, + {&shimVadConfigSetSileroMinSpeechDuration, "sherpa_shim_vad_config_set_silero_min_speech_duration"}, + {&shimVadConfigSetSileroWindowSize, "sherpa_shim_vad_config_set_silero_window_size"}, + {&shimVadConfigSetSileroMaxSpeechDuration, "sherpa_shim_vad_config_set_silero_max_speech_duration"}, + {&shimVadConfigSetSampleRate, "sherpa_shim_vad_config_set_sample_rate"}, + {&shimVadConfigSetNumThreads, "sherpa_shim_vad_config_set_num_threads"}, + {&shimVadConfigSetProvider, "sherpa_shim_vad_config_set_provider"}, + {&shimVadConfigSetDebug, "sherpa_shim_vad_config_set_debug"}, + {&shimCreateVad, "sherpa_shim_create_vad"}, + + {&shimTtsConfigNew, "sherpa_shim_tts_config_new"}, + {&shimTtsConfigFree, "sherpa_shim_tts_config_free"}, + {&shimTtsConfigSetVitsModel, "sherpa_shim_tts_config_set_vits_model"}, + {&shimTtsConfigSetVitsTokens, "sherpa_shim_tts_config_set_vits_tokens"}, + {&shimTtsConfigSetVitsLexicon, "sherpa_shim_tts_config_set_vits_lexicon"}, + {&shimTtsConfigSetVitsDataDir, "sherpa_shim_tts_config_set_vits_data_dir"}, + {&shimTtsConfigSetVitsNoiseScale, "sherpa_shim_tts_config_set_vits_noise_scale"}, + {&shimTtsConfigSetVitsNoiseScaleW, "sherpa_shim_tts_config_set_vits_noise_scale_w"}, + {&shimTtsConfigSetVitsLengthScale, "sherpa_shim_tts_config_set_vits_length_scale"}, + {&shimTtsConfigSetNumThreads, "sherpa_shim_tts_config_set_num_threads"}, + {&shimTtsConfigSetDebug, "sherpa_shim_tts_config_set_debug"}, + {&shimTtsConfigSetProvider, "sherpa_shim_tts_config_set_provider"}, + {&shimTtsConfigSetMaxNumSentences, "sherpa_shim_tts_config_set_max_num_sentences"}, + {&shimCreateOfflineTts, "sherpa_shim_create_offline_tts"}, + + {&shimOfflineRecogConfigNew, "sherpa_shim_offline_recog_config_new"}, + {&shimOfflineRecogConfigFree, "sherpa_shim_offline_recog_config_free"}, + {&shimOfflineRecogConfigSetNumThreads, "sherpa_shim_offline_recog_config_set_num_threads"}, + 
{&shimOfflineRecogConfigSetDebug, "sherpa_shim_offline_recog_config_set_debug"}, + {&shimOfflineRecogConfigSetProvider, "sherpa_shim_offline_recog_config_set_provider"}, + {&shimOfflineRecogConfigSetTokens, "sherpa_shim_offline_recog_config_set_tokens"}, + {&shimOfflineRecogConfigSetFeatSampleRate, "sherpa_shim_offline_recog_config_set_feat_sample_rate"}, + {&shimOfflineRecogConfigSetFeatFeatureDim, "sherpa_shim_offline_recog_config_set_feat_feature_dim"}, + {&shimOfflineRecogConfigSetDecodingMethod, "sherpa_shim_offline_recog_config_set_decoding_method"}, + {&shimOfflineRecogConfigSetWhisperEncoder, "sherpa_shim_offline_recog_config_set_whisper_encoder"}, + {&shimOfflineRecogConfigSetWhisperDecoder, "sherpa_shim_offline_recog_config_set_whisper_decoder"}, + {&shimOfflineRecogConfigSetWhisperLanguage, "sherpa_shim_offline_recog_config_set_whisper_language"}, + {&shimOfflineRecogConfigSetWhisperTask, "sherpa_shim_offline_recog_config_set_whisper_task"}, + {&shimOfflineRecogConfigSetWhisperTailPaddings, "sherpa_shim_offline_recog_config_set_whisper_tail_paddings"}, + {&shimOfflineRecogConfigSetParaformerModel, "sherpa_shim_offline_recog_config_set_paraformer_model"}, + {&shimOfflineRecogConfigSetSenseVoiceModel, "sherpa_shim_offline_recog_config_set_sense_voice_model"}, + {&shimOfflineRecogConfigSetSenseVoiceLanguage, "sherpa_shim_offline_recog_config_set_sense_voice_language"}, + {&shimOfflineRecogConfigSetSenseVoiceUseITN, "sherpa_shim_offline_recog_config_set_sense_voice_use_itn"}, + {&shimOfflineRecogConfigSetOmnilingualModel, "sherpa_shim_offline_recog_config_set_omnilingual_model"}, + {&shimCreateOfflineRecognizer, "sherpa_shim_create_offline_recognizer"}, + + {&shimOnlineRecogConfigNew, "sherpa_shim_online_recog_config_new"}, + {&shimOnlineRecogConfigFree, "sherpa_shim_online_recog_config_free"}, + {&shimOnlineRecogConfigSetTransducerEncoder, "sherpa_shim_online_recog_config_set_transducer_encoder"}, + {&shimOnlineRecogConfigSetTransducerDecoder, 
"sherpa_shim_online_recog_config_set_transducer_decoder"}, + {&shimOnlineRecogConfigSetTransducerJoiner, "sherpa_shim_online_recog_config_set_transducer_joiner"}, + {&shimOnlineRecogConfigSetTokens, "sherpa_shim_online_recog_config_set_tokens"}, + {&shimOnlineRecogConfigSetNumThreads, "sherpa_shim_online_recog_config_set_num_threads"}, + {&shimOnlineRecogConfigSetDebug, "sherpa_shim_online_recog_config_set_debug"}, + {&shimOnlineRecogConfigSetProvider, "sherpa_shim_online_recog_config_set_provider"}, + {&shimOnlineRecogConfigSetFeatSampleRate, "sherpa_shim_online_recog_config_set_feat_sample_rate"}, + {&shimOnlineRecogConfigSetFeatFeatureDim, "sherpa_shim_online_recog_config_set_feat_feature_dim"}, + {&shimOnlineRecogConfigSetDecodingMethod, "sherpa_shim_online_recog_config_set_decoding_method"}, + {&shimOnlineRecogConfigSetEnableEndpoint, "sherpa_shim_online_recog_config_set_enable_endpoint"}, + {&shimOnlineRecogConfigSetRule1MinTrailingSilence, "sherpa_shim_online_recog_config_set_rule1_min_trailing_silence"}, + {&shimOnlineRecogConfigSetRule2MinTrailingSilence, "sherpa_shim_online_recog_config_set_rule2_min_trailing_silence"}, + {&shimOnlineRecogConfigSetRule3MinUtteranceLength, "sherpa_shim_online_recog_config_set_rule3_min_utterance_length"}, + {&shimCreateOnlineRecognizer, "sherpa_shim_create_online_recognizer"}, + + {&shimWaveSampleRate, "sherpa_shim_wave_sample_rate"}, + {&shimWaveNumSamples, "sherpa_shim_wave_num_samples"}, + {&shimWaveSamples, "sherpa_shim_wave_samples"}, + {&shimOfflineResultText, "sherpa_shim_offline_result_text"}, + {&shimOnlineResultText, "sherpa_shim_online_result_text"}, + {&shimGeneratedAudioSampleRate, "sherpa_shim_generated_audio_sample_rate"}, + {&shimGeneratedAudioN, "sherpa_shim_generated_audio_n"}, + {&shimGeneratedAudioSamples, "sherpa_shim_generated_audio_samples"}, + {&shimSpeechSegmentStart, "sherpa_shim_speech_segment_start"}, + {&shimSpeechSegmentN, "sherpa_shim_speech_segment_n"}, + {&shimTtsGenerateWithCallback, 
"sherpa_shim_tts_generate_with_callback"}, + } { + purego.RegisterLibFunc(r.ptr, shim, r.name) + } + + // --- sherpa-onnx-c-api registrations --- + for _, r := range []struct { + ptr any + name string + }{ + {&sherpaVadAcceptWaveform, "SherpaOnnxVoiceActivityDetectorAcceptWaveform"}, + {&sherpaVadReset, "SherpaOnnxVoiceActivityDetectorReset"}, + {&sherpaVadFlush, "SherpaOnnxVoiceActivityDetectorFlush"}, + {&sherpaVadEmpty, "SherpaOnnxVoiceActivityDetectorEmpty"}, + {&sherpaVadFront, "SherpaOnnxVoiceActivityDetectorFront"}, + {&sherpaVadPop, "SherpaOnnxVoiceActivityDetectorPop"}, + {&sherpaDestroySpeechSegment, "SherpaOnnxDestroySpeechSegment"}, + + {&sherpaReadWave, "SherpaOnnxReadWave"}, + {&sherpaFreeWave, "SherpaOnnxFreeWave"}, + {&sherpaWriteWave, "SherpaOnnxWriteWave"}, + + {&sherpaCreateOfflineStream, "SherpaOnnxCreateOfflineStream"}, + {&sherpaDestroyOfflineStream, "SherpaOnnxDestroyOfflineStream"}, + {&sherpaAcceptWaveformOffline, "SherpaOnnxAcceptWaveformOffline"}, + {&sherpaDecodeOfflineStream, "SherpaOnnxDecodeOfflineStream"}, + {&sherpaGetOfflineStreamResult, "SherpaOnnxGetOfflineStreamResult"}, + {&sherpaDestroyOfflineRecognizerResult, "SherpaOnnxDestroyOfflineRecognizerResult"}, + + {&sherpaCreateOnlineStream, "SherpaOnnxCreateOnlineStream"}, + {&sherpaDestroyOnlineStream, "SherpaOnnxDestroyOnlineStream"}, + {&sherpaOnlineStreamAcceptWaveform, "SherpaOnnxOnlineStreamAcceptWaveform"}, + {&sherpaIsOnlineStreamReady, "SherpaOnnxIsOnlineStreamReady"}, + {&sherpaDecodeOnlineStream, "SherpaOnnxDecodeOnlineStream"}, + {&sherpaGetOnlineStreamResult, "SherpaOnnxGetOnlineStreamResult"}, + {&sherpaDestroyOnlineRecognizerResult, "SherpaOnnxDestroyOnlineRecognizerResult"}, + {&sherpaOnlineStreamIsEndpoint, "SherpaOnnxOnlineStreamIsEndpoint"}, + {&sherpaOnlineStreamReset, "SherpaOnnxOnlineStreamReset"}, + {&sherpaOnlineStreamInputFinished, "SherpaOnnxOnlineStreamInputFinished"}, + + {&sherpaOfflineTtsGenerate, "SherpaOnnxOfflineTtsGenerate"}, + 
{&sherpaDestroyOfflineTtsGeneratedAudio, "SherpaOnnxDestroyOfflineTtsGeneratedAudio"}, + {&sherpaOfflineTtsSampleRate, "SherpaOnnxOfflineTtsSampleRate"}, + } { + purego.RegisterLibFunc(r.ptr, capi, r.name) + } + + // Register the TTS streaming callback once. The callback pointer is + // stable for the lifetime of the process; user_data maps a particular + // TTSStream invocation to its Go state via ttsStates. + ttsCallbackPtr = purego.NewCallback(ttsStreamCallback) + return nil +} + +// ============================================================= +// Helpers +// ============================================================= + +// goStringFromCPtr reads a NUL-terminated C string into a Go string. +// Used to consume const char* returns from shim getters (which return +// unsafe.Pointer). Returns "" for nil. +func goStringFromCPtr(p unsafe.Pointer) string { + if p == nil { + return "" + } + n := 0 + for *(*byte)(unsafe.Add(p, n)) != 0 { + n++ + } + return string(unsafe.Slice((*byte)(p), n)) +} + +// sliceBasePtr returns an unsafe.Pointer to the first element of s, or +// nil for empty slices. Caller must keep the slice alive (runtime.KeepAlive) +// while the pointer is in use — purego passes it through without a copy. +func sliceBasePtr[T any](s []T) unsafe.Pointer { + if len(s) == 0 { + return nil + } + return unsafe.Pointer(&s[0]) +} + +func isASRType(t string) bool { + t = strings.ToLower(t) + return t == "asr" || t == "transcription" || t == "transcribe" +} + +func isVADType(t string) bool { + t = strings.ToLower(t) + return t == "vad" +} + +// Model-options prefixes recognised by this backend. Kept as typed +// constants so the asrFamily / loadWhisperASR / loadGenericASR paths +// can all speak the same vocabulary. +const ( + optionSubtype = "subtype=" + optionLanguage = "language=" + + // VAD (Silero) — see upstream sherpa-onnx SherpaOnnxVadModelConfig. 
+ optionVadThreshold = "vad.threshold=" + optionVadMinSilence = "vad.min_silence=" + optionVadMinSpeech = "vad.min_speech=" + optionVadWindowSize = "vad.window_size=" + optionVadMaxSpeech = "vad.max_speech=" + optionVadSampleRate = "vad.sample_rate=" + optionVadBufferSize = "vad.buffer_size=" + + // TTS (VITS) — see upstream SherpaOnnxOfflineTtsVitsModelConfig. + optionTtsNoiseScale = "tts.noise_scale=" + optionTtsNoiseScaleW = "tts.noise_scale_w=" + optionTtsLengthScale = "tts.length_scale=" + optionTtsMaxNumSentences = "tts.max_num_sentences=" + optionTtsSpeed = "tts.speed=" + + // Offline ASR — shared across whisper/paraformer/sense_voice/omnilingual, + // and reused for online ASR feat_config below. + optionAsrSampleRate = "asr.sample_rate=" + optionAsrFeatureDim = "asr.feature_dim=" + optionAsrDecodingMethod = "asr.decoding_method=" + optionAsrWhisperTask = "asr.whisper.task=" + optionAsrWhisperTailPaddings = "asr.whisper.tail_paddings=" + optionAsrSenseVoiceUseITN = "asr.sense_voice.use_itn=" + + // Online/streaming ASR (zipformer transducer) — endpoint rules and + // chunking. `online.chunk_samples` is a LocalAI-only knob (drives + // how much audio runOnlineASR feeds per decode call). + optionOnlineEnableEndpoint = "online.enable_endpoint=" + optionOnlineRule1 = "online.rule1_min_trailing_silence=" + optionOnlineRule2 = "online.rule2_min_trailing_silence=" + optionOnlineRule3 = "online.rule3_min_utterance_length=" + optionOnlineChunkSamples = "online.chunk_samples=" +) + +func hasOption(opts *pb.ModelOptions, prefix string) bool { + for _, o := range opts.Options { + if strings.HasPrefix(o, prefix) { + return true + } + } + return false +} + +// findOptionValue returns the first option value matching prefix, or +// the default if no such option is present. Used for parsing +// `subtype=xxx`, `language=yyy` etc. 
func findOptionValue(opts *pb.ModelOptions, prefix, defaultValue string) string {
	// First match wins; options are scanned in the order the caller
	// supplied them.
	for _, o := range opts.Options {
		if strings.HasPrefix(o, prefix) {
			return strings.TrimPrefix(o, prefix)
		}
	}
	return defaultValue
}

// Typed option lookups. Parse failure falls back to the default —
// badly formed options shouldn't prevent model load.

// findOptionFloat parses the matching option value as a float32,
// returning def when the option is absent or unparsable.
func findOptionFloat(opts *pb.ModelOptions, prefix string, def float32) float32 {
	raw := findOptionValue(opts, prefix, "")
	if raw == "" {
		return def
	}
	v, err := strconv.ParseFloat(raw, 32)
	if err != nil {
		return def
	}
	return float32(v)
}

// findOptionInt parses the matching option value as a base-10 int32,
// returning def when the option is absent or unparsable.
func findOptionInt(opts *pb.ModelOptions, prefix string, def int32) int32 {
	raw := findOptionValue(opts, prefix, "")
	if raw == "" {
		return def
	}
	v, err := strconv.ParseInt(raw, 10, 32)
	if err != nil {
		return def
	}
	return int32(v)
}

// findOptionBool returns 0 or 1. Accepts "0"/"1", "true"/"false",
// "yes"/"no", "on"/"off" (case-insensitive). Sherpa's C API takes int32,
// not bool, so the return type mirrors that. Any other value falls
// back to def.
func findOptionBool(opts *pb.ModelOptions, prefix string, def int32) int32 {
	raw := findOptionValue(opts, prefix, "")
	if raw == "" {
		return def
	}
	switch strings.ToLower(raw) {
	case "1", "true", "yes", "on":
		return 1
	case "0", "false", "no", "off":
		return 0
	}
	return def
}

// Load dispatches to the VAD, ASR or TTS loader. VAD requires an
// explicit Type; ASR is selected by Type or by a subtype= option;
// everything else falls through to TTS.
func (s *SherpaBackend) Load(opts *pb.ModelOptions) error {
	if isVADType(opts.Type) {
		return s.loadVAD(opts)
	}
	// An explicit `subtype=...` option routes to ASR even when Type is
	// unset — handy for the e2e-backends harness, which doesn't know
	// about ModelOptions.Type.
	if isASRType(opts.Type) || hasOption(opts, optionSubtype) {
		return s.loadASR(opts)
	}
	return s.loadTTS(opts)
}

// =============================================================
// VAD
// =============================================================

// loadVAD builds the silero VAD configuration from ModelFile and the
// vad.* options, then creates the detector. Idempotent: a second call
// on an already-loaded backend is a no-op.
func (s *SherpaBackend) loadVAD(opts *pb.ModelOptions) error {
	if s.vad != 0 {
		return nil
	}

	cfg := shimVadConfigNew()
	defer shimVadConfigFree(cfg)

	windowSize := findOptionInt(opts, optionVadWindowSize, 512)
	sampleRate := findOptionInt(opts, optionVadSampleRate, 16000)

	shimVadConfigSetSileroModel(cfg, opts.ModelFile)
	shimVadConfigSetSileroThreshold(cfg, findOptionFloat(opts, optionVadThreshold, 0.5))
	shimVadConfigSetSileroMinSilenceDuration(cfg, findOptionFloat(opts, optionVadMinSilence, 0.5))
	shimVadConfigSetSileroMinSpeechDuration(cfg, findOptionFloat(opts, optionVadMinSpeech, 0.25))
	shimVadConfigSetSileroWindowSize(cfg, windowSize)
	shimVadConfigSetSileroMaxSpeechDuration(cfg, findOptionFloat(opts, optionVadMaxSpeech, 20.0))
	shimVadConfigSetSampleRate(cfg, sampleRate)

	threads := int32(1)
	if opts.Threads != 0 {
		threads = opts.Threads
	}
	shimVadConfigSetNumThreads(cfg, threads)
	shimVadConfigSetProvider(cfg, onnxProvider)
	shimVadConfigSetDebug(cfg, 0)

	vad := shimCreateVad(cfg, findOptionFloat(opts, optionVadBufferSize, 60.0))
	if vad == 0 {
		return fmt.Errorf("failed to create sherpa-onnx VAD from %s", opts.ModelFile)
	}
	s.vad = vad
	// Cached so VAD() can window input and convert sample offsets to
	// seconds without re-reading the options.
	s.vadSampleRate = int(sampleRate)
	s.vadWindowSize = int(windowSize)
	return nil
}

// VAD feeds req.Audio through the detector in fixed-size windows and
// returns detected speech segments with start/end in seconds. A final
// partial window, if any, is zero-padded to windowSize. The detector
// is reset first, so each call sees only its own request's audio.
// Always returns a non-nil (possibly empty) Segments slice.
func (s *SherpaBackend) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
	if s.vad == 0 {
		return pb.VADResponse{}, fmt.Errorf("sherpa-onnx VAD not loaded (model must be loaded with type=vad)")
	}

	audio := req.Audio
	if len(audio) == 0 {
		return pb.VADResponse{Segments: []*pb.VADSegment{}}, nil
	}

	sherpaVadReset(s.vad)

	windowSize := s.vadWindowSize
	for i := 0; i+windowSize <= len(audio); i += windowSize {
		sherpaVadAcceptWaveform(s.vad, sliceBasePtr(audio[i:i+windowSize]), int32(windowSize))
	}
	if remaining := len(audio) % windowSize; remaining > 0 {
		// Zero-pad the tail so the detector always sees full windows.
		padded := make([]float32, windowSize)
		copy(padded, audio[len(audio)-remaining:])
		sherpaVadAcceptWaveform(s.vad, sliceBasePtr(padded), int32(windowSize))
	}
	sherpaVadFlush(s.vad)

	// Drain the detector's segment queue, freeing each segment as we go.
	var segments []*pb.VADSegment
	for sherpaVadEmpty(s.vad) == 0 {
		seg := sherpaVadFront(s.vad)
		if seg == 0 {
			break
		}
		start := shimSpeechSegmentStart(seg)
		n := shimSpeechSegmentN(seg)
		startSec := float32(start) / float32(s.vadSampleRate)
		endSec := float32(start+n) / float32(s.vadSampleRate)
		segments = append(segments, &pb.VADSegment{Start: startSec, End: endSec})
		sherpaDestroySpeechSegment(seg)
		sherpaVadPop(s.vad)
	}

	if segments == nil {
		segments = []*pb.VADSegment{}
	}
	return pb.VADResponse{Segments: segments}, nil
}

// =============================================================
// TTS
// =============================================================

// loadTTS builds a VITS TTS engine from ModelFile. tokens.txt,
// lexicon.txt and espeak-ng-data are picked up from the model's
// directory when present. Idempotent on an already-loaded backend.
func (s *SherpaBackend) loadTTS(opts *pb.ModelOptions) error {
	if s.tts != 0 {
		return nil
	}

	modelFile := opts.ModelFile
	modelDir := filepath.Dir(modelFile)

	cfg := shimTtsConfigNew()
	defer shimTtsConfigFree(cfg)

	shimTtsConfigSetVitsModel(cfg, modelFile)

	// Optional companion files — only configured when present on disk.
	if tokensPath := filepath.Join(modelDir, "tokens.txt"); fileExists(tokensPath) {
		shimTtsConfigSetVitsTokens(cfg, tokensPath)
	}
	if lexiconPath := filepath.Join(modelDir, "lexicon.txt"); fileExists(lexiconPath) {
		shimTtsConfigSetVitsLexicon(cfg, lexiconPath)
	}
	if dataDir := filepath.Join(modelDir, "espeak-ng-data"); dirExists(dataDir) {
		shimTtsConfigSetVitsDataDir(cfg, dataDir)
	}

	shimTtsConfigSetVitsNoiseScale(cfg, findOptionFloat(opts, optionTtsNoiseScale, 0.667))
	shimTtsConfigSetVitsNoiseScaleW(cfg, findOptionFloat(opts, optionTtsNoiseScaleW, 0.8))
	shimTtsConfigSetVitsLengthScale(cfg, findOptionFloat(opts, optionTtsLengthScale, 1.0))

	threads := int32(1)
	if opts.Threads != 0 {
		threads = opts.Threads
	}
	shimTtsConfigSetNumThreads(cfg, threads)
	shimTtsConfigSetDebug(cfg, 0)
	shimTtsConfigSetProvider(cfg, onnxProvider)
	shimTtsConfigSetMaxNumSentences(cfg, findOptionInt(opts, optionTtsMaxNumSentences, 1))

	// Speed is applied per-generate call (TTS / TTSStream), not baked
	// into the engine config.
	s.ttsSpeed = findOptionFloat(opts, optionTtsSpeed, 1.0)

	tts := shimCreateOfflineTts(cfg)
	if tts == 0 {
		return fmt.Errorf("failed to create sherpa-onnx TTS engine from %s", modelFile)
	}
	s.tts = tts
	return nil
}

// fileExists reports whether p exists and is a regular (non-directory) entry.
func fileExists(p string) bool {
	info, err := os.Stat(p)
	return err == nil && !info.IsDir()
}

// dirExists reports whether p exists and is a directory.
func dirExists(p string) bool {
	info, err := os.Stat(p)
	return err == nil && info.IsDir()
}

// findTokens locates a token file in modelDir. Any `*-tokens.txt`
// entries found in the directory are preferred over the plain
// `tokens.txt` fallback. Returns "" when nothing suitable exists.
func findTokens(modelDir string) string {
	candidates := []string{"tokens.txt"}
	if entries, err := os.ReadDir(modelDir); err == nil {
		for _, e := range entries {
			if strings.HasSuffix(e.Name(), "-tokens.txt") {
				// Prepend so prefixed token files win over tokens.txt.
				candidates = append([]string{e.Name()}, candidates...)
			}
		}
	}
	for _, c := range candidates {
		p := filepath.Join(modelDir, c)
		if fileExists(p) {
			return p
		}
	}
	return ""
}

// findWhisperPair scans modelDir for *-encoder.onnx / *-decoder.onnx pairs,
// preferring int8 variants. Returns encoder, decoder paths or empty strings.
+func findWhisperPair(modelDir string) (string, string) { + entries, err := os.ReadDir(modelDir) + if err != nil { + return "", "" + } + + var encoderInt8, decoderInt8 string + var encoderFP, decoderFP string + + for _, e := range entries { + name := e.Name() + switch { + case strings.HasSuffix(name, "-encoder.int8.onnx"): + encoderInt8 = filepath.Join(modelDir, name) + case strings.HasSuffix(name, "-decoder.int8.onnx"): + decoderInt8 = filepath.Join(modelDir, name) + case strings.HasSuffix(name, "-encoder.onnx"): + encoderFP = filepath.Join(modelDir, name) + case strings.HasSuffix(name, "-decoder.onnx"): + decoderFP = filepath.Join(modelDir, name) + case name == "encoder.onnx": + encoderFP = filepath.Join(modelDir, name) + case name == "decoder.onnx": + decoderFP = filepath.Join(modelDir, name) + } + } + + if encoderInt8 != "" && decoderInt8 != "" { + return encoderInt8, decoderInt8 + } + return encoderFP, decoderFP +} + +// ============================================================= +// ASR +// ============================================================= + +type asrFamilyT string + +const ( + familyParaformer asrFamilyT = "paraformer" + familySensevoice asrFamilyT = "sensevoice" + familyOmnilingual asrFamilyT = "omnilingual" + familyOnline asrFamilyT = "online" +) + +// asrFamily classifies the ASR model family from an explicit option or a +// path-substring heuristic. Sherpa-onnx's factory picks an impl from the +// first non-empty model field in OfflineModelConfig, and wrong-family +// metadata reads inside that impl call SHERPA_ONNX_EXIT(-1) which kills the +// whole process. So we must commit to one family before calling Create. 
+func asrFamily(opts *pb.ModelOptions) asrFamilyT { + if v := findOptionValue(opts, optionSubtype, ""); v != "" { + return asrFamilyT(strings.ToLower(v)) + } + if enc, dec, join := findZipformerTriple(filepath.Dir(opts.ModelFile)); enc != "" && dec != "" && join != "" { + return familyOnline + } + lower := strings.ToLower(opts.ModelFile) + switch { + case strings.Contains(lower, "omnilingual"): + return familyOmnilingual + case strings.Contains(lower, "paraformer"): + return familyParaformer + case strings.Contains(lower, "sense-voice"), strings.Contains(lower, "sense_voice"), strings.Contains(lower, "sensevoice"): + return familySensevoice + case strings.Contains(lower, "streaming"), strings.Contains(lower, "online"): + return familyOnline + default: + return familyParaformer + } +} + +// findZipformerTriple returns the encoder, decoder and joiner paths for a +// streaming zipformer transducer, preferring the int8 variants over fp. +func findZipformerTriple(dir string) (enc, dec, join string) { + entries, err := os.ReadDir(dir) + if err != nil { + return "", "", "" + } + var encInt8, decInt8, joinInt8 string + var encFP, decFP, joinFP string + for _, e := range entries { + name := e.Name() + lower := strings.ToLower(name) + path := filepath.Join(dir, name) + switch { + case strings.Contains(lower, "encoder") && strings.HasSuffix(lower, ".int8.onnx"): + encInt8 = path + case strings.Contains(lower, "decoder") && strings.HasSuffix(lower, ".int8.onnx"): + decInt8 = path + case strings.Contains(lower, "joiner") && strings.HasSuffix(lower, ".int8.onnx"): + joinInt8 = path + case strings.Contains(lower, "encoder") && strings.HasSuffix(lower, ".onnx"): + encFP = path + case strings.Contains(lower, "decoder") && strings.HasSuffix(lower, ".onnx"): + decFP = path + case strings.Contains(lower, "joiner") && strings.HasSuffix(lower, ".onnx"): + joinFP = path + } + } + if encInt8 != "" && decInt8 != "" && joinInt8 != "" { + return encInt8, decInt8, joinInt8 + } + return encFP, 
decFP, joinFP +} + +func (s *SherpaBackend) loadASR(opts *pb.ModelOptions) error { + if s.recognizer != 0 || s.onlineRecognizer != 0 { + return nil + } + + // Streaming zipformer models take a different C API (online recognizer) + // and dispatch before we touch the offline config. Triggered explicitly + // by `subtype=online` or heuristically by detecting encoder/decoder/joiner + // triples in the model directory. + if asrFamily(opts) == familyOnline { + return s.loadOnlineASR(opts) + } + + cfg := shimOfflineRecogConfigNew() + defer shimOfflineRecogConfigFree(cfg) + + threads := int32(1) + if opts.Threads != 0 { + threads = opts.Threads + } + shimOfflineRecogConfigSetNumThreads(cfg, threads) + shimOfflineRecogConfigSetDebug(cfg, 0) + shimOfflineRecogConfigSetProvider(cfg, onnxProvider) + + modelFile := opts.ModelFile + modelDir := filepath.Dir(modelFile) + if tokensPath := findTokens(modelDir); tokensPath != "" { + shimOfflineRecogConfigSetTokens(cfg, tokensPath) + } + + shimOfflineRecogConfigSetFeatSampleRate(cfg, findOptionInt(opts, optionAsrSampleRate, 16000)) + shimOfflineRecogConfigSetFeatFeatureDim(cfg, findOptionInt(opts, optionAsrFeatureDim, 80)) + shimOfflineRecogConfigSetDecodingMethod(cfg, findOptionValue(opts, optionAsrDecodingMethod, "greedy_search")) + + // Detect model type from files in the model directory. + // Whisper models have separate encoder/decoder files (e.g. tiny.en-encoder.onnx). 
+ encoderPath, decoderPath := findWhisperPair(modelDir) + if encoderPath != "" && decoderPath != "" { + return s.loadWhisperASR(cfg, opts, encoderPath, decoderPath) + } + + if fileExists(modelFile) { + return s.loadGenericASR(cfg, opts) + } + return fmt.Errorf("no recognizable ASR model found in %s", modelDir) +} + +func (s *SherpaBackend) loadWhisperASR(cfg uintptr, opts *pb.ModelOptions, encoderPath, decoderPath string) error { + shimOfflineRecogConfigSetWhisperEncoder(cfg, encoderPath) + shimOfflineRecogConfigSetWhisperDecoder(cfg, decoderPath) + shimOfflineRecogConfigSetWhisperLanguage(cfg, findOptionValue(opts, optionLanguage, "en")) + shimOfflineRecogConfigSetWhisperTask(cfg, findOptionValue(opts, optionAsrWhisperTask, "transcribe")) + shimOfflineRecogConfigSetWhisperTailPaddings(cfg, findOptionInt(opts, optionAsrWhisperTailPaddings, -1)) + + rec := shimCreateOfflineRecognizer(cfg) + if rec == 0 { + return fmt.Errorf("failed to create sherpa-onnx whisper recognizer from %s", filepath.Dir(encoderPath)) + } + s.recognizer = rec + return nil +} + +func (s *SherpaBackend) loadGenericASR(cfg uintptr, opts *pb.ModelOptions) error { + switch asrFamily(opts) { + case familyOmnilingual: + shimOfflineRecogConfigSetOmnilingualModel(cfg, opts.ModelFile) + case familySensevoice: + shimOfflineRecogConfigSetSenseVoiceModel(cfg, opts.ModelFile) + shimOfflineRecogConfigSetSenseVoiceLanguage(cfg, findOptionValue(opts, optionLanguage, "auto")) + // Upstream defaults ITN off; LocalAI enables it so transcription + // output is formatted ("100" not "one hundred"). Users who want + // raw tokens can set asr.sense_voice.use_itn=0. 
+ shimOfflineRecogConfigSetSenseVoiceUseITN(cfg, findOptionBool(opts, optionAsrSenseVoiceUseITN, 1)) + default: // paraformer + shimOfflineRecogConfigSetParaformerModel(cfg, opts.ModelFile) + } + + rec := shimCreateOfflineRecognizer(cfg) + if rec == 0 { + return fmt.Errorf("failed to create sherpa-onnx recognizer from %s", opts.ModelFile) + } + s.recognizer = rec + return nil +} + +func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error { + modelDir := filepath.Dir(opts.ModelFile) + enc, dec, join := findZipformerTriple(modelDir) + if enc == "" || dec == "" || join == "" { + return fmt.Errorf("streaming zipformer requires encoder/decoder/joiner .onnx files in %s", modelDir) + } + tokens := findTokens(modelDir) + if tokens == "" { + return fmt.Errorf("tokens.txt not found next to streaming zipformer model in %s", modelDir) + } + + cfg := shimOnlineRecogConfigNew() + defer shimOnlineRecogConfigFree(cfg) + + shimOnlineRecogConfigSetTransducerEncoder(cfg, enc) + shimOnlineRecogConfigSetTransducerDecoder(cfg, dec) + shimOnlineRecogConfigSetTransducerJoiner(cfg, join) + shimOnlineRecogConfigSetTokens(cfg, tokens) + + threads := int32(1) + if opts.Threads != 0 { + threads = opts.Threads + } + shimOnlineRecogConfigSetNumThreads(cfg, threads) + shimOnlineRecogConfigSetDebug(cfg, 0) + shimOnlineRecogConfigSetProvider(cfg, onnxProvider) + + shimOnlineRecogConfigSetFeatSampleRate(cfg, findOptionInt(opts, optionAsrSampleRate, 16000)) + shimOnlineRecogConfigSetFeatFeatureDim(cfg, findOptionInt(opts, optionAsrFeatureDim, 80)) + shimOnlineRecogConfigSetDecodingMethod(cfg, findOptionValue(opts, optionAsrDecodingMethod, "greedy_search")) + + // Endpoint detection. Upstream sherpa defaults to off; LocalAI leaves + // it on because streaming ASR consumers (realtime pipeline, raw gRPC + // clients) need segment boundaries to know when utterances end. + // Disable via online.enable_endpoint=0 when pairing with an external + // endpointer. 
+ shimOnlineRecogConfigSetEnableEndpoint(cfg, findOptionBool(opts, optionOnlineEnableEndpoint, 1)) + shimOnlineRecogConfigSetRule1MinTrailingSilence(cfg, findOptionFloat(opts, optionOnlineRule1, 2.4)) + shimOnlineRecogConfigSetRule2MinTrailingSilence(cfg, findOptionFloat(opts, optionOnlineRule2, 1.2)) + shimOnlineRecogConfigSetRule3MinUtteranceLength(cfg, findOptionFloat(opts, optionOnlineRule3, 20.0)) + + rec := shimCreateOnlineRecognizer(cfg) + if rec == 0 { + return fmt.Errorf("failed to create sherpa-onnx online recognizer from %s", modelDir) + } + s.onlineRecognizer = rec + s.onlineChunkSamples = int(findOptionInt(opts, optionOnlineChunkSamples, 1600)) + return nil +} + +// ============================================================= +// Transcription +// ============================================================= + +func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) { + if s.onlineRecognizer != 0 { + return s.runOnlineASR(req, nil) + } + if s.recognizer == 0 { + return pb.TranscriptResult{}, fmt.Errorf("sherpa-onnx ASR not loaded (model must be loaded with type=asr)") + } + + dir, err := os.MkdirTemp("", "sherpa-asr") + if err != nil { + return pb.TranscriptResult{}, fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(dir) + + wavPath := filepath.Join(dir, "input.wav") + if err := utils.AudioToWav(req.Dst, wavPath); err != nil { + return pb.TranscriptResult{}, fmt.Errorf("failed to convert audio to wav: %w", err) + } + + wave := sherpaReadWave(wavPath) + if wave == 0 { + return pb.TranscriptResult{}, fmt.Errorf("failed to read wav file %s", wavPath) + } + defer sherpaFreeWave(wave) + + stream := sherpaCreateOfflineStream(s.recognizer) + if stream == 0 { + return pb.TranscriptResult{}, fmt.Errorf("failed to create offline stream") + } + defer sherpaDestroyOfflineStream(stream) + + sr := shimWaveSampleRate(wave) + samples := shimWaveSamples(wave) + nSamples := shimWaveNumSamples(wave) + 
sherpaAcceptWaveformOffline(stream, sr, samples, nSamples) + sherpaDecodeOfflineStream(s.recognizer, stream) + + result := sherpaGetOfflineStreamResult(stream) + if result == 0 { + return pb.TranscriptResult{}, fmt.Errorf("failed to get recognition result") + } + defer sherpaDestroyOfflineRecognizerResult(result) + + text := strings.TrimSpace(goStringFromCPtr(shimOfflineResultText(result))) + + return pb.TranscriptResult{ + Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}}, + Text: text, + }, nil +} + +// AudioTranscriptionStream drives sherpa-onnx's online recognizer and emits +// incremental `delta` events on the response channel as new tokens are +// produced, then one `final_result`. Only implemented for online-loaded +// recognizers — offline models can't stream partial decode results. +// Closes `results` before returning so the server wrapper's reader +// goroutine can exit. +func (s *SherpaBackend) AudioTranscriptionStream( + req *pb.TranscriptRequest, + results chan *pb.TranscriptStreamResponse, +) error { + defer close(results) + if s.onlineRecognizer == 0 { + return fmt.Errorf("sherpa-onnx streaming transcription requires an online model (load with options: subtype=online)") + } + emitDelta := func(delta string) { + if delta == "" { + return + } + results <- &pb.TranscriptStreamResponse{Delta: delta} + } + result, err := s.runOnlineASR(req, emitDelta) + if err != nil { + return err + } + results <- &pb.TranscriptStreamResponse{FinalResult: &result} + return nil +} + +// runOnlineASR feeds a request's audio through sherpa-onnx's online +// recognizer in ~100ms chunks and assembles a TranscriptResult. When +// emitDelta is non-nil, it's called with the newly-appended text each +// time the decoded transcript grows — this is what drives the streaming +// `delta` events for AudioTranscriptionStream. +// +// Endpoint detection is configured when the recognizer is created, so +// multi-utterance inputs emit multiple segments. 
The returned result +// concatenates all segments into a single `Text` field, matching the +// offline path's contract. +func (s *SherpaBackend) runOnlineASR( + req *pb.TranscriptRequest, + emitDelta func(string), +) (pb.TranscriptResult, error) { + dir, err := os.MkdirTemp("", "sherpa-online") + if err != nil { + return pb.TranscriptResult{}, fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(dir) + + wavPath := filepath.Join(dir, "input.wav") + if err := utils.AudioToWav(req.Dst, wavPath); err != nil { + return pb.TranscriptResult{}, fmt.Errorf("failed to convert audio to wav: %w", err) + } + + wave := sherpaReadWave(wavPath) + if wave == 0 { + return pb.TranscriptResult{}, fmt.Errorf("failed to read wav file %s", wavPath) + } + defer sherpaFreeWave(wave) + + stream := sherpaCreateOnlineStream(s.onlineRecognizer) + if stream == 0 { + return pb.TranscriptResult{}, fmt.Errorf("failed to create online stream") + } + defer sherpaDestroyOnlineStream(stream) + + total := int(shimWaveNumSamples(wave)) + sr := shimWaveSampleRate(wave) + basePtr := shimWaveSamples(wave) + // Chunk size is a sample count set via online.chunk_samples at + // model-load time (default 1600 = 100 ms @ 16 kHz). + chunkSamples := s.onlineChunkSamples + + // Endpoint-aware decoding emits one segment per utterance, each starting + // fresh from currentText="". We track segments and a running total of + // all emitted delta text — the TranscriptStreamResponse contract + // requires concat(deltas) == final.Text, so we keep emitted verbatim. 
+ var segments []*pb.TranscriptSegment + var currentText string + var emittedAll strings.Builder + + emit := func(delta string) { + if delta == "" { + return + } + emittedAll.WriteString(delta) + if emitDelta != nil { + emitDelta(delta) + } + } + + commit := func() { + if currentText != "" { + segments = append(segments, &pb.TranscriptSegment{ + Id: int32(len(segments)), + Text: currentText, + }) + currentText = "" + } + } + + advance := func() { + for sherpaIsOnlineStreamReady(s.onlineRecognizer, stream) == 1 { + sherpaDecodeOnlineStream(s.onlineRecognizer, stream) + } + res := sherpaGetOnlineStreamResult(s.onlineRecognizer, stream) + if res == 0 { + return + } + defer sherpaDestroyOnlineRecognizerResult(res) + + text := goStringFromCPtr(shimOnlineResultText(res)) + if text != currentText { + if strings.HasPrefix(text, currentText) { + emit(text[len(currentText):]) + } else { + // Recognizer backtracked or rewrote the partial — emit + // the whole new text. Rare; happens during rescoring. 
+ emit(text) + } + currentText = text + } + + if sherpaOnlineStreamIsEndpoint(s.onlineRecognizer, stream) == 1 { + commit() + sherpaOnlineStreamReset(s.onlineRecognizer, stream) + } + } + + for off := 0; off < total; off += chunkSamples { + n := chunkSamples + if off+n > total { + n = total - off + } + if n <= 0 { + break + } + chunkPtr := unsafe.Add(basePtr, off*4) // float32 = 4 bytes + sherpaOnlineStreamAcceptWaveform(stream, sr, chunkPtr, int32(n)) + advance() + } + + sherpaOnlineStreamInputFinished(stream) + advance() + commit() + + return pb.TranscriptResult{ + Text: emittedAll.String(), + Segments: segments, + }, nil +} + +// ============================================================= +// TTS (non-streaming) +// ============================================================= + +func (s *SherpaBackend) TTS(req *pb.TTSRequest) error { + if s.tts == 0 { + return fmt.Errorf("sherpa-onnx TTS not loaded") + } + + sid := int32(0) + if req.Voice != "" { + if id, err := strconv.Atoi(req.Voice); err == nil { + sid = int32(id) + } + } + + audio := sherpaOfflineTtsGenerate(s.tts, req.Text, sid, s.ttsSpeed) + if audio == 0 { + return fmt.Errorf("failed to generate audio") + } + defer sherpaDestroyOfflineTtsGeneratedAudio(audio) + + n := shimGeneratedAudioN(audio) + if n <= 0 { + return fmt.Errorf("generated audio has no samples") + } + samples := shimGeneratedAudioSamples(audio) + sr := shimGeneratedAudioSampleRate(audio) + + if sherpaWriteWave(samples, n, sr, req.Dst) == 0 { + return fmt.Errorf("failed to write audio to %s", req.Dst) + } + return nil +} + +// ============================================================= +// TTS streaming +// ============================================================= + +// ttsStreamState wraps the destination channel for the purego-registered +// callback. ttsStates maps a uint64 user_data value back to this struct +// so the trampoline can recover it without cgo.Handle (which requires +// cgo). 
type ttsStreamState struct {
	// output receives the int16 LE PCM chunks produced by
	// ttsStreamCallback for one TTSStream invocation.
	output chan []byte
}

var (
	ttsStates      sync.Map      // uint64 → *ttsStreamState; one entry per in-flight TTSStream
	ttsNextID      atomic.Uint64 // monotonically increasing user_data IDs
	ttsCallbackPtr uintptr       // purego.NewCallback return; registered in loadSherpaLibs
)

// ttsStreamCallback is invoked by sherpa-onnx for each PCM chunk VITS
// produces. The callback's `samples` is a float32 pointer to [-1,1]
// values; we convert to int16 LE PCM and push on the state channel.
// Return 1 to keep generating; 0 to stop (state gone → consumer
// disconnected).
func ttsStreamCallback(samplesPtr unsafe.Pointer, n int32, userData uintptr) int32 {
	// userData is the id TTSStream stored in ttsStates; a missing entry
	// means the invocation already finished/aborted — tell sherpa to stop.
	v, ok := ttsStates.Load(uint64(userData))
	if !ok {
		return 0
	}
	state := v.(*ttsStreamState)

	nSamples := int(n)
	if nSamples <= 0 {
		return 1
	}
	// Saturating truncation matches what SherpaOnnxWriteWave does
	// internally and avoids math.Round/float64 per-sample overhead.
	samples := unsafe.Slice((*float32)(samplesPtr), nSamples)
	buf := make([]byte, nSamples*2) // 2 bytes per int16 sample
	for i, f := range samples {
		v := int32(f * 32767)
		if v > 32767 {
			v = 32767
		} else if v < -32768 {
			v = -32768
		}
		binary.LittleEndian.PutUint16(buf[2*i:], uint16(int16(v)))
	}
	// NOTE(review): this send blocks until the consumer reads; generation
	// is therefore paced by the channel reader.
	state.output <- buf
	return 1
}

// streamingWAVHeader builds a minimal WAV header with unknown-size
// chunks (0xFFFFFFFF) so HTTP clients can start playing before the
// full PCM has arrived.
func streamingWAVHeader(sampleRate uint32) []byte {
	const streamingSize = 0xFFFFFFFF
	h := laudio.NewWAVHeaderWithRate(streamingSize, sampleRate)
	h.ChunkSize = streamingSize
	var buf bytes.Buffer
	_ = h.Write(&buf)
	return buf.Bytes()
}

// TTSStream generates speech via sherpa-onnx's callback-driven TTS API
// and emits a WAV header followed by int16 LE PCM chunks on `results`.
// Closes `results` before returning (per the backend interface
// convention used by PredictStream etc) so the server wrapper's
// goroutine exits.
func (s *SherpaBackend) TTSStream(req *pb.TTSRequest, results chan []byte) error {
	defer close(results)
	if s.tts == 0 {
		return fmt.Errorf("sherpa-onnx TTS not loaded")
	}

	// Speaker id: numeric req.Voice selects it; anything else keeps 0.
	sid := int32(0)
	if req.Voice != "" {
		if id, err := strconv.Atoi(req.Voice); err == nil {
			sid = int32(id)
		}
	}

	sampleRate := uint32(sherpaOfflineTtsSampleRate(s.tts))
	// First chunk: streaming WAV header. The TTS HTTP handler that
	// owns the response writer stitches this + PCM into a valid
	// on-the-fly WAV stream.
	results <- streamingWAVHeader(sampleRate)

	// Register this invocation's channel under a fresh id; the
	// registered C callback looks it up via user_data (see
	// ttsStreamCallback). Deleted on return so the callback stops.
	id := ttsNextID.Add(1)
	ttsStates.Store(id, &ttsStreamState{output: results})
	defer ttsStates.Delete(id)

	// All PCM is delivered through the callback; the returned audio
	// handle only needs to be freed.
	audio := shimTtsGenerateWithCallback(s.tts, req.Text, sid, s.ttsSpeed, ttsCallbackPtr, uintptr(id))
	if audio != 0 {
		sherpaDestroyOfflineTtsGeneratedAudio(audio)
	}
	return nil
}
diff --git a/backend/go/sherpa-onnx/backend_test.go b/backend/go/sherpa-onnx/backend_test.go
new file mode 100644
index 000000000000..46ad6d3a283f
--- /dev/null
+++ b/backend/go/sherpa-onnx/backend_test.go
@@ -0,0 +1,169 @@
package main

import (
	"os"
	"path/filepath"
	"testing"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

// TestSherpaBackend is the go-test entry point that hands control to
// the Ginkgo spec runner below.
func TestSherpaBackend(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Sherpa-ONNX Backend Suite")
}

// Load libsherpa-shim + libsherpa-onnx-c-api via purego before any spec
// runs — otherwise any Load/TTS/VAD/AudioTranscription call hits a nil
// function pointer. LD_LIBRARY_PATH must contain the directory holding
// both .so files; test.sh sets this.
+var _ = BeforeSuite(func() { + Expect(loadSherpaLibs()).To(Succeed()) +}) + +var _ = Describe("Sherpa-ONNX", func() { + Context("lifecycle", func() { + It("is locking (C API is not thread safe)", func() { + Expect((&SherpaBackend{}).Locking()).To(BeTrue()) + }) + + It("errors loading a non-existent model", func() { + tmpDir, err := os.MkdirTemp("", "sherpa-test-nonexistent") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tmpDir) + + err = (&SherpaBackend{}).Load(&pb.ModelOptions{ + ModelFile: filepath.Join(tmpDir, "non-existent-model.onnx"), + }) + Expect(err).To(HaveOccurred()) + }) + + It("errors loading a non-existent ASR model", func() { + tmpDir, err := os.MkdirTemp("", "sherpa-test-asr") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tmpDir) + + err = (&SherpaBackend{}).Load(&pb.ModelOptions{ + ModelFile: filepath.Join(tmpDir, "model.onnx"), + Type: "asr", + }) + Expect(err).To(HaveOccurred()) + }) + + It("dispatches Load by Type", func() { + tmpDir, err := os.MkdirTemp("", "sherpa-test-dispatch") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tmpDir) + + modelFile := filepath.Join(tmpDir, "model.onnx") + for _, typ := range []string{"", "asr", "vad"} { + err := (&SherpaBackend{}).Load(&pb.ModelOptions{ModelFile: modelFile, Type: typ}) + Expect(err).To(HaveOccurred(), "Type=%q", typ) + } + }) + }) + + Context("method errors without loaded model", func() { + It("rejects TTS", func() { + tmpDir, err := os.MkdirTemp("", "sherpa-test-tts") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tmpDir) + + err = (&SherpaBackend{}).TTS(&pb.TTSRequest{ + Text: "should fail — no model loaded", + Dst: filepath.Join(tmpDir, "output.wav"), + }) + Expect(err).To(HaveOccurred()) + }) + + It("rejects AudioTranscription", func() { + _, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{ + Dst: "/tmp/nonexistent.wav", + }) + Expect(err).To(HaveOccurred()) + }) + + It("rejects VAD", func() { + _, err := 
(&SherpaBackend{}).VAD(&pb.VADRequest{ + Audio: []float32{0.1, 0.2, 0.3}, + }) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("type detection", func() { + DescribeTable("isASRType", + func(input string, want bool) { + Expect(isASRType(input)).To(Equal(want)) + }, + Entry("asr", "asr", true), + Entry("ASR", "ASR", true), + Entry("Asr", "Asr", true), + Entry("transcription", "transcription", true), + Entry("Transcription", "Transcription", true), + Entry("transcribe", "transcribe", true), + Entry("Transcribe", "Transcribe", true), + Entry("tts", "tts", false), + Entry("empty", "", false), + Entry("other", "other", false), + Entry("vad", "vad", false), + ) + + DescribeTable("isVADType", + func(input string, want bool) { + Expect(isVADType(input)).To(Equal(want)) + }, + Entry("vad", "vad", true), + Entry("VAD", "VAD", true), + Entry("Vad", "Vad", true), + Entry("asr", "asr", false), + Entry("tts", "tts", false), + Entry("empty", "", false), + Entry("other", "other", false), + ) + }) + + Context("option parsing", func() { + It("parses float options with fallback on bad input", func() { + opts := &pb.ModelOptions{Options: []string{ + "vad.threshold=0.3", + "tts.length_scale=1.25", + "bad.number=not-a-float", + }} + Expect(findOptionFloat(opts, "vad.threshold=", 0.5)).To(BeNumerically("~", 0.3, 1e-6)) + Expect(findOptionFloat(opts, "tts.length_scale=", 1.0)).To(BeNumerically("~", 1.25, 1e-6)) + Expect(findOptionFloat(opts, "missing.key=", 0.7)).To(BeNumerically("~", 0.7, 1e-6)) + Expect(findOptionFloat(opts, "bad.number=", 9.9)).To(BeNumerically("~", 9.9, 1e-6)) + }) + + It("parses int options with fallback on bad input", func() { + opts := &pb.ModelOptions{Options: []string{ + "asr.sample_rate=22050", + "online.chunk_samples=800", + "bad.int=4.2", + }} + Expect(findOptionInt(opts, "asr.sample_rate=", 16000)).To(Equal(int32(22050))) + Expect(findOptionInt(opts, "online.chunk_samples=", 1600)).To(Equal(int32(800))) + Expect(findOptionInt(opts, "missing.key=", 
42)).To(Equal(int32(42)))
+			Expect(findOptionInt(opts, "bad.int=", 100)).To(Equal(int32(100)))
+		})
+
+		It("parses bool options (0/1, true/false, yes/no, on/off)", func() {
+			opts := &pb.ModelOptions{Options: []string{
+				"online.enable_endpoint=0",
+				"asr.sense_voice.use_itn=True",
+				"feature.on=yes",
+				"feature.off=Off",
+				"feature.bad=maybe",
+			}}
+			Expect(findOptionBool(opts, "online.enable_endpoint=", 1)).To(Equal(int32(0)))
+			Expect(findOptionBool(opts, "asr.sense_voice.use_itn=", 0)).To(Equal(int32(1)))
+			Expect(findOptionBool(opts, "feature.on=", 0)).To(Equal(int32(1)))
+			Expect(findOptionBool(opts, "feature.off=", 1)).To(Equal(int32(0)))
+			Expect(findOptionBool(opts, "feature.bad=", 1)).To(Equal(int32(1)))
+			Expect(findOptionBool(opts, "missing.key=", 1)).To(Equal(int32(1)))
+		})
+	})
+})
diff --git a/backend/go/sherpa-onnx/csrc/shim.c b/backend/go/sherpa-onnx/csrc/shim.c
new file mode 100644
index 000000000000..c09a449033e5
--- /dev/null
+++ b/backend/go/sherpa-onnx/csrc/shim.c
@@ -0,0 +1,325 @@
+#include "shim.h"
+#include "c-api.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+// Replace the char* field pointed to by `slot` with a strdup of `s`
+// (or NULL if s is NULL). Frees any prior value. Silently no-ops when
+// strdup fails — the caller will see a Create* failure downstream.
+static void shim_set_str(const char **slot, const char *s) {
+  free((char *)*slot);
+  *slot = s ?
strdup(s) : NULL; +} + +// ================================================================== +// VAD config +// ================================================================== + +void *sherpa_shim_vad_config_new(void) { + return calloc(1, sizeof(SherpaOnnxVadModelConfig)); +} + +void sherpa_shim_vad_config_free(void *h) { + if (!h) return; + SherpaOnnxVadModelConfig *c = (SherpaOnnxVadModelConfig *)h; + free((char *)c->silero_vad.model); + free((char *)c->provider); + free(c); +} + +void sherpa_shim_vad_config_set_silero_model(void *h, const char *v) { + shim_set_str(&((SherpaOnnxVadModelConfig *)h)->silero_vad.model, v); +} +void sherpa_shim_vad_config_set_silero_threshold(void *h, float v) { + ((SherpaOnnxVadModelConfig *)h)->silero_vad.threshold = v; +} +void sherpa_shim_vad_config_set_silero_min_silence_duration(void *h, float v) { + ((SherpaOnnxVadModelConfig *)h)->silero_vad.min_silence_duration = v; +} +void sherpa_shim_vad_config_set_silero_min_speech_duration(void *h, float v) { + ((SherpaOnnxVadModelConfig *)h)->silero_vad.min_speech_duration = v; +} +void sherpa_shim_vad_config_set_silero_window_size(void *h, int32_t v) { + ((SherpaOnnxVadModelConfig *)h)->silero_vad.window_size = v; +} +void sherpa_shim_vad_config_set_silero_max_speech_duration(void *h, float v) { + ((SherpaOnnxVadModelConfig *)h)->silero_vad.max_speech_duration = v; +} +void sherpa_shim_vad_config_set_sample_rate(void *h, int32_t v) { + ((SherpaOnnxVadModelConfig *)h)->sample_rate = v; +} +void sherpa_shim_vad_config_set_num_threads(void *h, int32_t v) { + ((SherpaOnnxVadModelConfig *)h)->num_threads = v; +} +void sherpa_shim_vad_config_set_provider(void *h, const char *v) { + shim_set_str(&((SherpaOnnxVadModelConfig *)h)->provider, v); +} +void sherpa_shim_vad_config_set_debug(void *h, int32_t v) { + ((SherpaOnnxVadModelConfig *)h)->debug = v; +} + +void *sherpa_shim_create_vad(void *h, float buffer_size_seconds) { + return (void *)SherpaOnnxCreateVoiceActivityDetector( + (const 
SherpaOnnxVadModelConfig *)h, buffer_size_seconds); +} + +// ================================================================== +// Offline TTS config (VITS) +// ================================================================== + +void *sherpa_shim_tts_config_new(void) { + return calloc(1, sizeof(SherpaOnnxOfflineTtsConfig)); +} + +void sherpa_shim_tts_config_free(void *h) { + if (!h) return; + SherpaOnnxOfflineTtsConfig *c = (SherpaOnnxOfflineTtsConfig *)h; + free((char *)c->model.vits.model); + free((char *)c->model.vits.tokens); + free((char *)c->model.vits.lexicon); + free((char *)c->model.vits.data_dir); + free((char *)c->model.provider); + free(c); +} + +void sherpa_shim_tts_config_set_vits_model(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.vits.model, v); +} +void sherpa_shim_tts_config_set_vits_tokens(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.vits.tokens, v); +} +void sherpa_shim_tts_config_set_vits_lexicon(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.vits.lexicon, v); +} +void sherpa_shim_tts_config_set_vits_data_dir(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.vits.data_dir, v); +} +void sherpa_shim_tts_config_set_vits_noise_scale(void *h, float v) { + ((SherpaOnnxOfflineTtsConfig *)h)->model.vits.noise_scale = v; +} +void sherpa_shim_tts_config_set_vits_noise_scale_w(void *h, float v) { + ((SherpaOnnxOfflineTtsConfig *)h)->model.vits.noise_scale_w = v; +} +void sherpa_shim_tts_config_set_vits_length_scale(void *h, float v) { + ((SherpaOnnxOfflineTtsConfig *)h)->model.vits.length_scale = v; +} +void sherpa_shim_tts_config_set_num_threads(void *h, int32_t v) { + ((SherpaOnnxOfflineTtsConfig *)h)->model.num_threads = v; +} +void sherpa_shim_tts_config_set_debug(void *h, int32_t v) { + ((SherpaOnnxOfflineTtsConfig *)h)->model.debug = v; +} +void sherpa_shim_tts_config_set_provider(void *h, const char *v) 
{ + shim_set_str(&((SherpaOnnxOfflineTtsConfig *)h)->model.provider, v); +} +void sherpa_shim_tts_config_set_max_num_sentences(void *h, int32_t v) { + ((SherpaOnnxOfflineTtsConfig *)h)->max_num_sentences = v; +} + +void *sherpa_shim_create_offline_tts(void *h) { + return (void *)SherpaOnnxCreateOfflineTts( + (const SherpaOnnxOfflineTtsConfig *)h); +} + +// ================================================================== +// Offline recognizer config +// ================================================================== + +void *sherpa_shim_offline_recog_config_new(void) { + return calloc(1, sizeof(SherpaOnnxOfflineRecognizerConfig)); +} + +void sherpa_shim_offline_recog_config_free(void *h) { + if (!h) return; + SherpaOnnxOfflineRecognizerConfig *c = (SherpaOnnxOfflineRecognizerConfig *)h; + free((char *)c->model_config.provider); + free((char *)c->model_config.tokens); + free((char *)c->model_config.whisper.encoder); + free((char *)c->model_config.whisper.decoder); + free((char *)c->model_config.whisper.language); + free((char *)c->model_config.whisper.task); + free((char *)c->model_config.paraformer.model); + free((char *)c->model_config.sense_voice.model); + free((char *)c->model_config.sense_voice.language); + free((char *)c->model_config.omnilingual.model); + free((char *)c->decoding_method); + free(c); +} + +void sherpa_shim_offline_recog_config_set_num_threads(void *h, int32_t v) { + ((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.num_threads = v; +} +void sherpa_shim_offline_recog_config_set_debug(void *h, int32_t v) { + ((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.debug = v; +} +void sherpa_shim_offline_recog_config_set_provider(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.provider, v); +} +void sherpa_shim_offline_recog_config_set_tokens(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.tokens, v); +} +void 
sherpa_shim_offline_recog_config_set_feat_sample_rate(void *h, int32_t v) { + ((SherpaOnnxOfflineRecognizerConfig *)h)->feat_config.sample_rate = v; +} +void sherpa_shim_offline_recog_config_set_feat_feature_dim(void *h, int32_t v) { + ((SherpaOnnxOfflineRecognizerConfig *)h)->feat_config.feature_dim = v; +} +void sherpa_shim_offline_recog_config_set_decoding_method(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->decoding_method, v); +} +void sherpa_shim_offline_recog_config_set_whisper_encoder(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.encoder, v); +} +void sherpa_shim_offline_recog_config_set_whisper_decoder(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.decoder, v); +} +void sherpa_shim_offline_recog_config_set_whisper_language(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.language, v); +} +void sherpa_shim_offline_recog_config_set_whisper_task(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.task, v); +} +void sherpa_shim_offline_recog_config_set_whisper_tail_paddings(void *h, int32_t v) { + ((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.whisper.tail_paddings = v; +} +void sherpa_shim_offline_recog_config_set_paraformer_model(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.paraformer.model, v); +} +void sherpa_shim_offline_recog_config_set_sense_voice_model(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.sense_voice.model, v); +} +void sherpa_shim_offline_recog_config_set_sense_voice_language(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.sense_voice.language, v); +} +void sherpa_shim_offline_recog_config_set_sense_voice_use_itn(void *h, 
int32_t v) { + ((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.sense_voice.use_itn = v; +} +void sherpa_shim_offline_recog_config_set_omnilingual_model(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOfflineRecognizerConfig *)h)->model_config.omnilingual.model, v); +} + +void *sherpa_shim_create_offline_recognizer(void *h) { + return (void *)SherpaOnnxCreateOfflineRecognizer( + (const SherpaOnnxOfflineRecognizerConfig *)h); +} + +// ================================================================== +// Online recognizer config +// ================================================================== + +void *sherpa_shim_online_recog_config_new(void) { + return calloc(1, sizeof(SherpaOnnxOnlineRecognizerConfig)); +} + +void sherpa_shim_online_recog_config_free(void *h) { + if (!h) return; + SherpaOnnxOnlineRecognizerConfig *c = (SherpaOnnxOnlineRecognizerConfig *)h; + free((char *)c->model_config.transducer.encoder); + free((char *)c->model_config.transducer.decoder); + free((char *)c->model_config.transducer.joiner); + free((char *)c->model_config.tokens); + free((char *)c->model_config.provider); + free((char *)c->decoding_method); + free(c); +} + +void sherpa_shim_online_recog_config_set_transducer_encoder(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.transducer.encoder, v); +} +void sherpa_shim_online_recog_config_set_transducer_decoder(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.transducer.decoder, v); +} +void sherpa_shim_online_recog_config_set_transducer_joiner(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.transducer.joiner, v); +} +void sherpa_shim_online_recog_config_set_tokens(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.tokens, v); +} +void sherpa_shim_online_recog_config_set_num_threads(void *h, int32_t v) { + ((SherpaOnnxOnlineRecognizerConfig 
*)h)->model_config.num_threads = v; +} +void sherpa_shim_online_recog_config_set_debug(void *h, int32_t v) { + ((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.debug = v; +} +void sherpa_shim_online_recog_config_set_provider(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->model_config.provider, v); +} +void sherpa_shim_online_recog_config_set_feat_sample_rate(void *h, int32_t v) { + ((SherpaOnnxOnlineRecognizerConfig *)h)->feat_config.sample_rate = v; +} +void sherpa_shim_online_recog_config_set_feat_feature_dim(void *h, int32_t v) { + ((SherpaOnnxOnlineRecognizerConfig *)h)->feat_config.feature_dim = v; +} +void sherpa_shim_online_recog_config_set_decoding_method(void *h, const char *v) { + shim_set_str(&((SherpaOnnxOnlineRecognizerConfig *)h)->decoding_method, v); +} +void sherpa_shim_online_recog_config_set_enable_endpoint(void *h, int32_t v) { + ((SherpaOnnxOnlineRecognizerConfig *)h)->enable_endpoint = v; +} +void sherpa_shim_online_recog_config_set_rule1_min_trailing_silence(void *h, float v) { + ((SherpaOnnxOnlineRecognizerConfig *)h)->rule1_min_trailing_silence = v; +} +void sherpa_shim_online_recog_config_set_rule2_min_trailing_silence(void *h, float v) { + ((SherpaOnnxOnlineRecognizerConfig *)h)->rule2_min_trailing_silence = v; +} +void sherpa_shim_online_recog_config_set_rule3_min_utterance_length(void *h, float v) { + ((SherpaOnnxOnlineRecognizerConfig *)h)->rule3_min_utterance_length = v; +} + +void *sherpa_shim_create_online_recognizer(void *h) { + return (void *)SherpaOnnxCreateOnlineRecognizer( + (const SherpaOnnxOnlineRecognizerConfig *)h); +} + +// ================================================================== +// Result-struct accessors +// ================================================================== + +int32_t sherpa_shim_wave_sample_rate(const void *h) { + return ((const SherpaOnnxWave *)h)->sample_rate; +} +int32_t sherpa_shim_wave_num_samples(const void *h) { + return ((const SherpaOnnxWave 
*)h)->num_samples;
+}
+const float *sherpa_shim_wave_samples(const void *h) {
+  return ((const SherpaOnnxWave *)h)->samples;
+}
+
+const char *sherpa_shim_offline_result_text(const void *h) {
+  return ((const SherpaOnnxOfflineRecognizerResult *)h)->text;
+}
+const char *sherpa_shim_online_result_text(const void *h) {
+  return ((const SherpaOnnxOnlineRecognizerResult *)h)->text;
+}
+
+int32_t sherpa_shim_generated_audio_sample_rate(const void *h) {
+  return ((const SherpaOnnxGeneratedAudio *)h)->sample_rate;
+}
+int32_t sherpa_shim_generated_audio_n(const void *h) {
+  return ((const SherpaOnnxGeneratedAudio *)h)->n;
+}
+const float *sherpa_shim_generated_audio_samples(const void *h) {
+  return ((const SherpaOnnxGeneratedAudio *)h)->samples;
+}
+
+int32_t sherpa_shim_speech_segment_start(const void *h) {
+  return ((const SherpaOnnxSpeechSegment *)h)->start;
+}
+int32_t sherpa_shim_speech_segment_n(const void *h) {
+  return ((const SherpaOnnxSpeechSegment *)h)->n;
+}
+
+// ==================================================================
+// TTS streaming callback trampoline
+// ==================================================================
+
+void *sherpa_shim_tts_generate_with_callback(
+    void *tts, const char *text, int32_t sid, float speed,
+    uintptr_t callback_ptr, uintptr_t user_data) {
+  SherpaOnnxGeneratedAudioCallbackWithArg cb =
+      (SherpaOnnxGeneratedAudioCallbackWithArg)callback_ptr;
+  return (void *)SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
+      (const SherpaOnnxOfflineTts *)tts, text, sid, speed, cb,
+      (void *)user_data);
+}
diff --git a/backend/go/sherpa-onnx/csrc/shim.h b/backend/go/sherpa-onnx/csrc/shim.h
new file mode 100644
index 000000000000..d479a33a308b
--- /dev/null
+++ b/backend/go/sherpa-onnx/csrc/shim.h
@@ -0,0 +1,129 @@
+#ifndef LOCALAI_SHERPA_ONNX_SHIM_H
+#define LOCALAI_SHERPA_ONNX_SHIM_H
+
+#include <stdint.h>
+
+// libsherpa-shim: purego-friendly wrapper around sherpa-onnx's C API.
+// Purego can't access C struct fields and can't route C callbacks to Go +// funcs directly. Every function here is a fixed-signature trampoline +// that replaces one field read/write or callback handoff that the Go +// backend would otherwise have to do through cgo. +// +// String lifetime: setters strdup; _free walks every owned string and +// frees it. Callers may discard their input buffers the moment a setter +// returns. +// +// Opaque handles are `void *` in both directions. Nothing here holds a +// reference across calls except config handles (freed via _free) and +// sherpa-allocated results (freed via sherpa's own Destroy* entry +// points, which Go calls through purego pass-through). + +#ifdef __cplusplus +extern "C" { +#endif + +// --- VAD config ----------------------------------------------------- +void *sherpa_shim_vad_config_new(void); +void sherpa_shim_vad_config_free(void *cfg); +void sherpa_shim_vad_config_set_silero_model(void *cfg, const char *path); +void sherpa_shim_vad_config_set_silero_threshold(void *cfg, float v); +void sherpa_shim_vad_config_set_silero_min_silence_duration(void *cfg, float v); +void sherpa_shim_vad_config_set_silero_min_speech_duration(void *cfg, float v); +void sherpa_shim_vad_config_set_silero_window_size(void *cfg, int32_t v); +void sherpa_shim_vad_config_set_silero_max_speech_duration(void *cfg, float v); +void sherpa_shim_vad_config_set_sample_rate(void *cfg, int32_t v); +void sherpa_shim_vad_config_set_num_threads(void *cfg, int32_t v); +void sherpa_shim_vad_config_set_provider(void *cfg, const char *v); +void sherpa_shim_vad_config_set_debug(void *cfg, int32_t v); +void *sherpa_shim_create_vad(void *cfg, float buffer_size_seconds); + +// --- Offline TTS config (VITS path — the only TTS family the backend uses) --- +void *sherpa_shim_tts_config_new(void); +void sherpa_shim_tts_config_free(void *cfg); +void sherpa_shim_tts_config_set_vits_model(void *cfg, const char *v); +void 
sherpa_shim_tts_config_set_vits_tokens(void *cfg, const char *v); +void sherpa_shim_tts_config_set_vits_lexicon(void *cfg, const char *v); +void sherpa_shim_tts_config_set_vits_data_dir(void *cfg, const char *v); +void sherpa_shim_tts_config_set_vits_noise_scale(void *cfg, float v); +void sherpa_shim_tts_config_set_vits_noise_scale_w(void *cfg, float v); +void sherpa_shim_tts_config_set_vits_length_scale(void *cfg, float v); +void sherpa_shim_tts_config_set_num_threads(void *cfg, int32_t v); +void sherpa_shim_tts_config_set_debug(void *cfg, int32_t v); +void sherpa_shim_tts_config_set_provider(void *cfg, const char *v); +void sherpa_shim_tts_config_set_max_num_sentences(void *cfg, int32_t v); +void *sherpa_shim_create_offline_tts(void *cfg); + +// --- Offline recognizer config (Whisper / Paraformer / SenseVoice / Omnilingual) --- +void *sherpa_shim_offline_recog_config_new(void); +void sherpa_shim_offline_recog_config_free(void *cfg); +void sherpa_shim_offline_recog_config_set_num_threads(void *cfg, int32_t v); +void sherpa_shim_offline_recog_config_set_debug(void *cfg, int32_t v); +void sherpa_shim_offline_recog_config_set_provider(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_tokens(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_feat_sample_rate(void *cfg, int32_t v); +void sherpa_shim_offline_recog_config_set_feat_feature_dim(void *cfg, int32_t v); +void sherpa_shim_offline_recog_config_set_decoding_method(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_whisper_encoder(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_whisper_decoder(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_whisper_language(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_whisper_task(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_whisper_tail_paddings(void *cfg, int32_t v); +void 
sherpa_shim_offline_recog_config_set_paraformer_model(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_sense_voice_model(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_sense_voice_language(void *cfg, const char *v); +void sherpa_shim_offline_recog_config_set_sense_voice_use_itn(void *cfg, int32_t v); +void sherpa_shim_offline_recog_config_set_omnilingual_model(void *cfg, const char *v); +void *sherpa_shim_create_offline_recognizer(void *cfg); + +// --- Online recognizer config (streaming zipformer transducer) --- +void *sherpa_shim_online_recog_config_new(void); +void sherpa_shim_online_recog_config_free(void *cfg); +void sherpa_shim_online_recog_config_set_transducer_encoder(void *cfg, const char *v); +void sherpa_shim_online_recog_config_set_transducer_decoder(void *cfg, const char *v); +void sherpa_shim_online_recog_config_set_transducer_joiner(void *cfg, const char *v); +void sherpa_shim_online_recog_config_set_tokens(void *cfg, const char *v); +void sherpa_shim_online_recog_config_set_num_threads(void *cfg, int32_t v); +void sherpa_shim_online_recog_config_set_debug(void *cfg, int32_t v); +void sherpa_shim_online_recog_config_set_provider(void *cfg, const char *v); +void sherpa_shim_online_recog_config_set_feat_sample_rate(void *cfg, int32_t v); +void sherpa_shim_online_recog_config_set_feat_feature_dim(void *cfg, int32_t v); +void sherpa_shim_online_recog_config_set_decoding_method(void *cfg, const char *v); +void sherpa_shim_online_recog_config_set_enable_endpoint(void *cfg, int32_t v); +void sherpa_shim_online_recog_config_set_rule1_min_trailing_silence(void *cfg, float v); +void sherpa_shim_online_recog_config_set_rule2_min_trailing_silence(void *cfg, float v); +void sherpa_shim_online_recog_config_set_rule3_min_utterance_length(void *cfg, float v); +void *sherpa_shim_create_online_recognizer(void *cfg); + +// --- Result accessors (sherpa-allocated; caller destroys via sherpa's own Destroy*) --- +int32_t 
sherpa_shim_wave_sample_rate(const void *wave); +int32_t sherpa_shim_wave_num_samples(const void *wave); +const float *sherpa_shim_wave_samples(const void *wave); + +const char *sherpa_shim_offline_result_text(const void *result); +const char *sherpa_shim_online_result_text(const void *result); + +int32_t sherpa_shim_generated_audio_sample_rate(const void *audio); +int32_t sherpa_shim_generated_audio_n(const void *audio); +const float *sherpa_shim_generated_audio_samples(const void *audio); + +int32_t sherpa_shim_speech_segment_start(const void *seg); +int32_t sherpa_shim_speech_segment_n(const void *seg); + +// --- TTS streaming callback trampoline ----------------------------- +// Replaces the //export sherpaTtsGoCallback + callbacks.c bridge pattern. +// `callback_ptr` is the C-callable function pointer returned by +// purego.NewCallback. `user_data` is an integer the Go side uses to +// look up its state (sync.Map keyed by uint64). +// +// Returns the sherpa-allocated SherpaOnnxGeneratedAudio. Destroy with +// SherpaOnnxDestroyOfflineTtsGeneratedAudio (callable directly from +// Go via purego). 
+void *sherpa_shim_tts_generate_with_callback( + void *tts, const char *text, int32_t sid, float speed, + uintptr_t callback_ptr, uintptr_t user_data); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/backend/go/sherpa-onnx/main.go b/backend/go/sherpa-onnx/main.go new file mode 100644 index 000000000000..ef4dbe6e306f --- /dev/null +++ b/backend/go/sherpa-onnx/main.go @@ -0,0 +1,23 @@ +package main + +import ( + "flag" + + grpc "github.com/mudler/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := loadSherpaLibs(); err != nil { + panic(err) + } + + if err := grpc.StartServer(*addr, &SherpaBackend{}); err != nil { + panic(err) + } +} diff --git a/backend/go/sherpa-onnx/package.sh b/backend/go/sherpa-onnx/package.sh new file mode 100755 index 000000000000..5a596e4906cb --- /dev/null +++ b/backend/go/sherpa-onnx/package.sh @@ -0,0 +1,51 @@ +#!/bin/bash +set -e + +CURDIR=$(dirname "$(realpath $0)") +REPO_ROOT="${CURDIR}/../../.." + +mkdir -p $CURDIR/package/lib + +cp -avf $CURDIR/sherpa-onnx $CURDIR/package/ +cp -avf $CURDIR/run.sh $CURDIR/package/ +cp -rfLv $CURDIR/backend-assets/lib/* $CURDIR/package/lib/ + +if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then + echo "Detected x86_64 architecture, copying x86_64 libraries..." 
+ cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so + cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 + cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 + cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 + cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 + cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 + cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 + cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 + cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 +elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then + echo "Detected ARM64 architecture, copying ARM64 libraries..." + cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so + cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 + cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 + cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 + cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 + cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 + cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 + cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 + cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 +elif [ $(uname -s) = "Darwin" ]; then + echo "Detected Darwin" +else + echo "Error: Could not detect architecture" + exit 1 +fi + +GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh" +if [ -f "$GPU_LIB_SCRIPT" ]; then + echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..." 
+ source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib" + package_gpu_libs +fi + +echo "Packaging completed successfully" +ls -liah $CURDIR/package/ +ls -liah $CURDIR/package/lib/ diff --git a/backend/go/sherpa-onnx/run.sh b/backend/go/sherpa-onnx/run.sh new file mode 100755 index 000000000000..b703e51551c2 --- /dev/null +++ b/backend/go/sherpa-onnx/run.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -ex + +CURDIR=$(dirname "$(realpath $0)") + +export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH + +if [ -f $CURDIR/lib/ld.so ]; then + echo "Using lib/ld.so" + exec $CURDIR/lib/ld.so $CURDIR/sherpa-onnx "$@" +fi + +exec $CURDIR/sherpa-onnx "$@" diff --git a/backend/go/sherpa-onnx/test.sh b/backend/go/sherpa-onnx/test.sh new file mode 100755 index 000000000000..e0a69d213074 --- /dev/null +++ b/backend/go/sherpa-onnx/test.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Unit tests for the sherpa-onnx backend. Exercises error-path and +# dispatch logic via SherpaBackend directly (no gRPC). Integration +# coverage (gRPC TTS / streaming ASR / realtime pipeline) lives in +# tests/e2e-backends and tests/e2e and runs against the Docker image. +set -e + +CURDIR=$(dirname "$(realpath $0)") +cd "$CURDIR" + +PACKAGES=$(go list ./... | grep -v /sources/) +go test -v -timeout 60s $PACKAGES diff --git a/backend/index.yaml b/backend/index.yaml index d97059769be1..63726c5b620e 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -1006,6 +1006,23 @@ nvidia: "cuda12-neutts" amd: "rocm-neutts" nvidia-cuda-12: "cuda12-neutts" +- &sherpa-onnx + name: "sherpa-onnx" + alias: "sherpa-onnx" + urls: + - https://k2-fsa.github.io/sherpa/onnx/ + description: | + Sherpa-ONNX backend for text-to-speech (VITS, Matcha, Kokoro), speech-to-text (Whisper, Paraformer, SenseVoice, Omnilingual ASR CTC), and voice activity detection via ONNX Runtime. + Supports multi-speaker voices, 1600+ language ASR, and GPU acceleration. 
+ tags: + - text-to-speech + - TTS + - speech-to-text + - ASR + capabilities: + default: "cpu-sherpa-onnx" + nvidia: "cuda12-sherpa-onnx" + nvidia-cuda-12: "cuda12-sherpa-onnx" - !!merge <<: *neutts name: "neutts-development" capabilities: @@ -3834,3 +3851,30 @@ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-speaker-recognition" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-speaker-recognition +## sherpa-onnx +- !!merge <<: *sherpa-onnx + name: "sherpa-onnx-development" + capabilities: + default: "cpu-sherpa-onnx-development" + nvidia: "cuda12-sherpa-onnx-development" + nvidia-cuda-12: "cuda12-sherpa-onnx-development" +- !!merge <<: *sherpa-onnx + name: "cpu-sherpa-onnx" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-sherpa-onnx" + mirrors: + - localai/localai-backends:latest-cpu-sherpa-onnx +- !!merge <<: *sherpa-onnx + name: "cpu-sherpa-onnx-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cpu-sherpa-onnx" + mirrors: + - localai/localai-backends:master-cpu-sherpa-onnx +- !!merge <<: *sherpa-onnx + name: "cuda12-sherpa-onnx" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-sherpa-onnx" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-12-sherpa-onnx +- !!merge <<: *sherpa-onnx + name: "cuda12-sherpa-onnx-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-sherpa-onnx" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-12-sherpa-onnx diff --git a/core/config/model_config.go b/core/config/model_config.go index b839ae491d25..1184d8452a71 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -767,7 +767,7 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool { } if (u & FLAG_VAD) == FLAG_VAD { - if c.Backend != "silero-vad" && !(c.Backend == "whisper" && slices.Contains(c.Options, "vad_only")) { + if c.Backend != "silero-vad" && c.Backend != "sherpa-onnx" && !(c.Backend == "whisper" && 
slices.Contains(c.Options, "vad_only")) { return false } } diff --git a/gallery/index.yaml b/gallery/index.yaml index 227e0e082285..82114bf471a9 100644 --- a/gallery/index.yaml +++ b/gallery/index.yaml @@ -1178,6 +1178,134 @@ - transcript parameters: model: tiny +- name: omnilingual-0.3b-ctc-q8-sherpa + license: apache-2.0 + url: "github:mudler/LocalAI/gallery/sherpa-onnx-asr.yaml@master" + description: | + Omnilingual ASR CTC 300M (int8) is a multilingual automatic speech recognition model supporting 1,600+ languages. Based on Meta's omniASR_CTC_300M architecture (Wav2Vec2 with CTC head), quantized to int8 for efficient inference. Uses the sherpa-onnx backend with ONNX Runtime. + urls: + - https://huggingface.co/csukuangfj/sherpa-onnx-omnilingual-asr-1600-languages-300M-ctc-int8-2025-11-12 + - https://k2-fsa.github.io/sherpa/onnx/omnilingual-asr/models.html + icon: https://avatars.githubusercontent.com/u/75781706 + tags: + - stt + - speech-to-text + - asr + - audio-transcription + - multilingual + - omnilingual + - sherpa-onnx + - cpu + - gpu + overrides: + known_usecases: + - transcript + parameters: + model: omnilingual-asr/model.int8.onnx + files: + - filename: omnilingual-asr/model.int8.onnx + sha256: e7c4e54ee4c4c47829cc6667d5d00ed8ea7bef1dcfeef0fce766f77752a2726c + uri: https://huggingface.co/csukuangfj/sherpa-onnx-omnilingual-asr-1600-languages-300M-ctc-int8-2025-11-12/resolve/main/model.int8.onnx + - filename: omnilingual-asr/tokens.txt + sha256: a7a044c52cb29cbe8b0dc1953e92cefd4ca16b0ed968177b6beab21f9a7d0b31 + uri: https://huggingface.co/csukuangfj/sherpa-onnx-omnilingual-asr-1600-languages-300M-ctc-int8-2025-11-12/resolve/main/tokens.txt +- name: streaming-zipformer-en-sherpa + license: apache-2.0 + url: "github:mudler/LocalAI/gallery/sherpa-onnx-asr.yaml@master" + description: | + Streaming English ASR: sherpa-onnx zipformer transducer (int8, chunk-16 left-128). 
Low-latency real-time transcription with endpoint detection via sherpa-onnx's online recognizer. English-only; for multilingual offline ASR see omnilingual-0.3b-ctc-q8-sherpa. + urls: + - https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 + - https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html + icon: https://avatars.githubusercontent.com/u/75781706 + tags: + - stt + - speech-to-text + - asr + - audio-transcription + - streaming + - real-time + - english + - zipformer + - sherpa-onnx + - cpu + - gpu + overrides: + known_usecases: + - transcript + parameters: + model: streaming-zipformer-en/encoder.int8.onnx + options: + - subtype=online + files: + - filename: streaming-zipformer-en/encoder.int8.onnx + sha256: 563fde436d16cf7607cf408cd6b30909819d03162652ef389c2450ced3f45ac1 + uri: https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/encoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx + - filename: streaming-zipformer-en/decoder.int8.onnx + sha256: 98da299f471e38bb4e1a8df579b8cc9122d6039576a77e357b3c60f17dd83b02 + uri: https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/decoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx + - filename: streaming-zipformer-en/joiner.int8.onnx + sha256: d944208d660d67c8d72cd2acaeac971fa5ceb8c80e76c1968148846fedd6e297 + uri: https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/joiner-epoch-99-avg-1-chunk-16-left-128.int8.onnx + - filename: streaming-zipformer-en/tokens.txt + sha256: 49e3c2646595fd907228b3c6787069658f67b17377c60aeb8619c4551b2316fb + uri: https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/tokens.txt +- name: silero-vad-sherpa + license: mit + url: "github:mudler/LocalAI/gallery/sherpa-onnx-vad.yaml@master" + description: | + Silero VAD served through the sherpa-onnx backend. 
Uses the same ONNX weights as the dedicated silero-vad backend, loaded through sherpa-onnx's C VAD API. Pairs with the sherpa-onnx ASR entries for round-trip audio pipelines. + urls: + - https://github.com/snakers4/silero-vad + - https://huggingface.co/onnx-community/silero-vad + icon: https://github.com/snakers4/silero-models/raw/master/files/silero_logo.jpg + tags: + - vad + - voice-activity-detection + - sherpa-onnx + - cpu + - gpu + overrides: + known_usecases: + - vad + parameters: + model: silero-vad/silero-vad.onnx + files: + - filename: silero-vad/silero-vad.onnx + sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808 + uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx +- name: vits-ljs-sherpa + license: mit + url: "github:mudler/LocalAI/gallery/sherpa-onnx-tts.yaml@master" + description: | + VITS-LJS English single-speaker TTS served through the sherpa-onnx backend. Trained on the LJSpeech corpus at 22.05 kHz. Pairs with the sherpa-onnx ASR entries for round-trip audio pipelines. 
+ urls: + - https://github.com/k2-fsa/sherpa-onnx + - https://huggingface.co/csukuangfj/vits-ljs + icon: https://avatars.githubusercontent.com/u/75781706 + tags: + - tts + - text-to-speech + - english + - vits + - sherpa-onnx + - cpu + - gpu + overrides: + known_usecases: + - tts + parameters: + model: vits-ljs/vits-ljs.onnx + files: + - filename: vits-ljs/vits-ljs.onnx + sha256: 5bbd273797a9ecf8d94bd6ec02ad16cb41cbb85f055ad98d528ced3e44c9b31a + uri: https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx + - filename: vits-ljs/tokens.txt + sha256: 5fee2c6b238d712287f2ecb08f34a8a8b413bcb7390862ef6fb6fd6f0f8d3a17 + uri: https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt + - filename: vits-ljs/lexicon.txt + sha256: bdccfc6da71c45c48e2e0056fcf0aab760577c5f959f6c1b5eb3e3e916fd5a0e + uri: https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt - name: voxcpm-1.5 license: apache-2.0 url: "github:mudler/LocalAI/gallery/virtual.yaml@master" diff --git a/gallery/sherpa-onnx-asr.yaml b/gallery/sherpa-onnx-asr.yaml new file mode 100644 index 000000000000..f0bdb219f17b --- /dev/null +++ b/gallery/sherpa-onnx-asr.yaml @@ -0,0 +1,27 @@ +--- +name: "sherpa-onnx-asr" + +config_file: | + backend: sherpa-onnx + type: asr + options: + # Feature extraction. Most shipped sherpa-onnx ASR models expect + # 16 kHz / 80-dim log-mel; derivatives trained at other rates + # should override these. + - asr.sample_rate=16000 + - asr.feature_dim=80 + - asr.decoding_method=greedy_search + # Whisper-family defaults (ignored by non-whisper models). + - asr.whisper.task=transcribe + - asr.whisper.tail_paddings=-1 + # SenseVoice-family: inverse text normalization is off in upstream + # sherpa but on here — we want formatted transcription output + # ("100" not "one hundred"). Set to 0 for raw tokens. + - asr.sense_voice.use_itn=1 + # Online (streaming zipformer) ASR. 
Endpoint detection is upstream- + # off but on here — streaming consumers need segment boundaries. + - online.enable_endpoint=1 + - online.rule1_min_trailing_silence=2.4 + - online.rule2_min_trailing_silence=1.2 + - online.rule3_min_utterance_length=20.0 + - online.chunk_samples=1600 diff --git a/gallery/sherpa-onnx-tts.yaml b/gallery/sherpa-onnx-tts.yaml new file mode 100644 index 000000000000..e6bdb1f4b4f9 --- /dev/null +++ b/gallery/sherpa-onnx-tts.yaml @@ -0,0 +1,14 @@ +--- +name: "sherpa-onnx-tts" + +config_file: | + backend: sherpa-onnx + options: + # VITS inference knobs. Matches upstream sherpa-onnx defaults. + - tts.noise_scale=0.667 + - tts.noise_scale_w=0.8 + - tts.length_scale=1.0 + - tts.max_num_sentences=1 + # Speech rate multiplier. Applied at every TTS / TTSStream call + # since the TTSRequest proto has no speed field. + - tts.speed=1.0 diff --git a/gallery/sherpa-onnx-vad.yaml b/gallery/sherpa-onnx-vad.yaml new file mode 100644 index 000000000000..72c226e02e00 --- /dev/null +++ b/gallery/sherpa-onnx-vad.yaml @@ -0,0 +1,17 @@ +--- +name: "sherpa-onnx-vad" + +config_file: | + backend: sherpa-onnx + type: vad + options: + # Silero VAD. Defaults mirror upstream sherpa-onnx. Override for + # faster turn-taking (lower min_silence) or different sample rate + # derivatives (8 kHz Silero variants). + - vad.threshold=0.5 + - vad.min_silence=0.5 + - vad.min_speech=0.25 + - vad.window_size=512 + - vad.max_speech=20.0 + - vad.sample_rate=16000 + - vad.buffer_size=60.0 diff --git a/pkg/utils/ffmpeg.go b/pkg/utils/ffmpeg.go index c2783dbceadd..1ebcc11a8ba5 100644 --- a/pkg/utils/ffmpeg.go +++ b/pkg/utils/ffmpeg.go @@ -2,10 +2,13 @@ package utils import ( "fmt" + "io" "os" "os/exec" "strings" + laudio "github.com/mudler/LocalAI/pkg/audio" + "github.com/go-audio/wav" ) @@ -16,24 +19,61 @@ func ffmpegCommand(args []string) (string, error) { return string(out), err } -// AudioToWav converts audio to wav for transcribe. -// TODO: use https://github.com/mccoyst/ogg? 
+// AudioToWav converts audio to wav for transcribe (16 kHz mono s16le). +// WAV files already in the target format are passed through directly; +// everything else is converted via ffmpeg. +// +// The pass-through uses a hardlink (with a Copy fallback for cross-fs +// src/dst) rather than Rename — callers may invoke this twice against +// the same fixture (e.g. once for AudioTranscription and once for +// AudioTranscriptionStream) and expect the original file to remain. func AudioToWav(src, dst string) error { - if strings.HasSuffix(src, ".wav") { - f, err := os.Open(src) - if err != nil { - return fmt.Errorf("open: %w", err) - } + if strings.HasSuffix(src, ".wav") && isTargetWav(src) { + return passthroughWAV(src, dst) + } + return convertWithFFmpeg(src, dst) +} - dec := wav.NewDecoder(f) - dec.ReadInfo() - f.Close() +func passthroughWAV(src, dst string) error { + if err := os.Link(src, dst); err == nil { + return nil + } + // Fallback: copy. Hardlink fails across filesystems (e.g. src on a + // read-only mount, dst in /tmp) or when the destination already + // exists — both are fine; just copy bytes. + in, err := os.Open(src) + if err != nil { + return fmt.Errorf("open src: %w", err) + } + defer in.Close() + out, err := os.Create(dst) + if err != nil { + return fmt.Errorf("create dst: %w", err) + } + defer out.Close() + if _, err := io.Copy(out, in); err != nil { + return fmt.Errorf("copy: %w", err) + } + return nil +} - if dec.BitDepth == 16 && dec.NumChans == 1 && dec.SampleRate == 16000 { - os.Rename(src, dst) - return nil - } +// isTargetWav returns true when src is a valid WAV already in the +// target format (16 kHz, mono, 16-bit PCM). 
+func isTargetWav(src string) bool { + f, err := os.Open(src) + if err != nil { + return false + } + defer f.Close() + + dec := wav.NewDecoder(f) + if !dec.IsValidFile() { + return false } + return dec.BitDepth == 16 && dec.NumChans == 1 && dec.SampleRate == 16000 +} + +func convertWithFFmpeg(src, dst string) error { commandArgs := []string{"-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} out, err := ffmpegCommand(commandArgs) if err != nil { @@ -85,3 +125,18 @@ func AudioConvert(src string, format string) (string, error) { } return dst, nil } + +// WriteWav16kFromReader reads all PCM data from r and writes a 16 kHz mono +// 16-bit WAV to w. Useful when the PCM length is not known in advance. +func WriteWav16kFromReader(w io.Writer, r io.Reader) error { + pcm, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("read pcm: %w", err) + } + hdr := laudio.NewWAVHeader(uint32(len(pcm))) + if err := hdr.Write(w); err != nil { + return fmt.Errorf("write wav header: %w", err) + } + _, err = w.Write(pcm) + return err +} diff --git a/pkg/utils/ffmpeg_test.go b/pkg/utils/ffmpeg_test.go new file mode 100644 index 000000000000..f11ef7fd4d0f --- /dev/null +++ b/pkg/utils/ffmpeg_test.go @@ -0,0 +1,150 @@ +package utils + +import ( + "encoding/binary" + "os" + "path/filepath" + "testing" + + laudio "github.com/mudler/LocalAI/pkg/audio" +) + +// generateTestWav creates a WAV file with a sine-ish tone at the given sample rate, +// channels, and bit depth (only 16-bit supported for simplicity). 
+func generateTestWav(t *testing.T, path string, sampleRate uint32, numChannels uint16, numSamples int) { + t.Helper() + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + bitsPerSample := uint16(16) + blockAlign := numChannels * (bitsPerSample / 8) + byteRate := sampleRate * uint32(blockAlign) + totalSamples := numSamples * int(numChannels) + dataSize := uint32(totalSamples) * uint32(bitsPerSample/8) + + hdr := laudio.WAVHeader{ + ChunkID: [4]byte{'R', 'I', 'F', 'F'}, + ChunkSize: 36 + dataSize, + Format: [4]byte{'W', 'A', 'V', 'E'}, + Subchunk1ID: [4]byte{'f', 'm', 't', ' '}, + Subchunk1Size: 16, + AudioFormat: 1, + NumChannels: numChannels, + SampleRate: sampleRate, + ByteRate: byteRate, + BlockAlign: blockAlign, + BitsPerSample: bitsPerSample, + Subchunk2ID: [4]byte{'d', 'a', 't', 'a'}, + Subchunk2Size: dataSize, + } + if err := binary.Write(f, binary.LittleEndian, &hdr); err != nil { + t.Fatal(err) + } + + for i := 0; i < totalSamples; i++ { + sample := int16(1000 * (i % 100)) + if err := binary.Write(f, binary.LittleEndian, sample); err != nil { + t.Fatal(err) + } + } +} + +func TestAudioToWav_AlreadyCorrectFormat(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "input.wav") + dst := filepath.Join(dir, "output.wav") + + generateTestWav(t, src, 16000, 1, 1600) + + if err := AudioToWav(src, dst); err != nil { + t.Fatalf("AudioToWav failed: %v", err) + } + + info, err := os.Stat(dst) + if err != nil { + t.Fatalf("output not found: %v", err) + } + if info.Size() == 0 { + t.Fatal("output file is empty") + } +} + +func TestAudioToWav_ResampleFrom22050(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "input.wav") + dst := filepath.Join(dir, "output.wav") + + generateTestWav(t, src, 22050, 1, 22050) + + if err := AudioToWav(src, dst); err != nil { + t.Fatalf("AudioToWav failed: %v", err) + } + + info, err := os.Stat(dst) + if err != nil { + t.Fatalf("output not found: %v", err) + } + if info.Size() == 
0 { + t.Fatal("output file is empty") + } + + verifyWavFormat(t, dst, 16000, 1) +} + +func TestAudioToWav_StereoDownmix(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "input.wav") + dst := filepath.Join(dir, "output.wav") + + generateTestWav(t, src, 16000, 2, 1600) + + if err := AudioToWav(src, dst); err != nil { + t.Fatalf("AudioToWav failed: %v", err) + } + + verifyWavFormat(t, dst, 16000, 1) +} + +func TestAudioToWav_StereoAndResample(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "input.wav") + dst := filepath.Join(dir, "output.wav") + + generateTestWav(t, src, 44100, 2, 44100) + + if err := AudioToWav(src, dst); err != nil { + t.Fatalf("AudioToWav failed: %v", err) + } + + verifyWavFormat(t, dst, 16000, 1) +} + +func verifyWavFormat(t *testing.T, path string, expectedRate uint32, expectedChannels uint16) { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatalf("open: %v", err) + } + defer f.Close() + + var hdr laudio.WAVHeader + if err := binary.Read(f, binary.LittleEndian, &hdr); err != nil { + t.Fatalf("read header: %v", err) + } + + if hdr.SampleRate != expectedRate { + t.Errorf("sample rate = %d, want %d", hdr.SampleRate, expectedRate) + } + if hdr.NumChannels != expectedChannels { + t.Errorf("channels = %d, want %d", hdr.NumChannels, expectedChannels) + } + if hdr.BitsPerSample != 16 { + t.Errorf("bit depth = %d, want 16", hdr.BitsPerSample) + } + if hdr.AudioFormat != 1 { + t.Errorf("audio format = %d, want 1 (PCM)", hdr.AudioFormat) + } +} diff --git a/tests/e2e-backends/backend_test.go b/tests/e2e-backends/backend_test.go index 29af3fc31e7e..9ce7576502b0 100644 --- a/tests/e2e-backends/backend_test.go +++ b/tests/e2e-backends/backend_test.go @@ -40,6 +40,12 @@ import ( // to download alongside the main model — required for // multimodal models like Qwen3-ASR-0.6B-GGUF. // BACKEND_TEST_MMPROJ_FILE Path to an already-available mmproj file. 
+// BACKEND_TEST_EXTRA_FILES Pipe-separated list of companion files to download +// next to the main model. Each entry is "<url>" or +// "<url>#<name>" (the optional <name> suffix renames +// the file on disk — useful for sherpa-onnx models +// whose loader expects specific names like +// encoder.int8.onnx). // BACKEND_TEST_AUDIO_URL HTTP(S) URL of a sample audio file used by the // transcription specs. // BACKEND_TEST_AUDIO_FILE Path to an already-available sample audio file. @@ -71,6 +77,9 @@ import ( // (default: "What's the weather like in Paris, France?"). // BACKEND_TEST_TOOL_NAME Override the function name expected in the tool call // (default: "get_weather"). +// BACKEND_TEST_TTS_TEXT Override the text synthesized by the tts/ttsstream +// specs (default: "The quick brown fox jumps over the +// lazy dog."). // // The suite is intentionally model-format-agnostic: it only ever passes the // file path to LoadModel, so GGUF, ONNX, safetensors, .bin etc. all work so @@ -83,6 +92,7 @@ const ( capEmbeddings = "embeddings" capTools = "tools" capTranscription = "transcription" + capTTS = "tts" capImage = "image" capFaceDetect = "face_detect" capFaceEmbed = "face_embed" @@ -99,6 +109,7 @@ const ( defaultImagePrompt = "a photograph of an astronaut riding a horse" defaultImageSteps = 4 defaultVerifyDistanceCeil = float32(0.6) // upper bound for same-person; SFace runs closer to 0.5 ArcFace to 0.35. + defaultTTSText = "The quick brown fox jumps over the lazy dog." ) func defaultCaps() map[string]bool { @@ -110,6 +121,17 @@ } } +// splitURLAndName parses a "<url>#<name>" entry. The #name suffix is +// optional — if absent, defaultName is returned. Used by the main-model +// and extras download paths so a test can rename downloaded files to the +// shape the backend's loader expects. 
+func splitURLAndName(entry, defaultName string) (url, name string) { + if hash := strings.Index(entry, "#"); hash >= 0 { + return entry[:hash], entry[hash+1:] + } + return entry, defaultName +} + // parseCaps reads BACKEND_TEST_CAPS and returns the enabled capability set. // An empty/unset value falls back to defaultCaps(). func parseCaps() map[string]bool { @@ -199,9 +221,33 @@ var _ = Describe("Backend container", Ordered, func() { Expect(filepath.Join(binaryDir, "run.sh")).To(BeAnExistingFile()) // Download the model once if not provided and no HF name given. + // BACKEND_TEST_MODEL_URL accepts an optional "#<name>" suffix + // for cases where the backend expects the model file to have a + // specific name (e.g. sherpa-onnx's online recognizer finds + // encoder/decoder/joiner by filename substring). if modelFile == "" && modelName == "" { - modelFile = filepath.Join(workDir, "model.bin") - downloadFile(modelURL, modelFile) + url, name := splitURLAndName(modelURL, "model.bin") + modelFile = filepath.Join(workDir, name) + downloadFile(url, modelFile) + } + + // Multi-file models (sherpa-onnx streaming zipformer, sherpa-onnx + // Omnilingual, any split encoder/decoder/joiner bundle) need + // companion files next to the main model. BACKEND_TEST_EXTRA_FILES + // is a pipe-separated list of "<url>[#<name>]" entries; each <url> + // is downloaded into the same directory as modelFile. The optional + // <name> renames the saved file (useful when upstream URLs + // have stamp/version suffixes the loader doesn't recognise). 
+ if extraSpec := strings.TrimSpace(os.Getenv("BACKEND_TEST_EXTRA_FILES")); extraSpec != "" && modelFile != "" { + modelDir := filepath.Dir(modelFile) + for _, entry := range strings.Split(extraSpec, "|") { + entry = strings.TrimSpace(entry) + if entry == "" { + continue + } + url, name := splitURLAndName(entry, filepath.Base(entry)) + downloadFile(url, filepath.Join(modelDir, name)) + } } // Multimodal projector (mmproj): required by audio/vision-capable @@ -787,6 +833,62 @@ var _ = Describe("Backend container", Ordered, func() { } GinkgoWriter.Printf("voice_analyze: %d segments\n", len(res.GetSegments())) }) + + It("synthesizes speech via TTS", func() { + if !caps[capTTS] { + Skip("tts capability not enabled") + } + text := os.Getenv("BACKEND_TEST_TTS_TEXT") + if text == "" { + text = defaultTTSText + } + dst := filepath.Join(workDir, "tts.wav") + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + _, err := client.TTS(ctx, &pb.TTSRequest{Text: text, Dst: dst}) + Expect(err).NotTo(HaveOccurred()) + + info, err := os.Stat(dst) + Expect(err).NotTo(HaveOccurred(), "TTS did not write a file at %s", dst) + Expect(info.Size()).To(BeNumerically(">", int64(1024)), + "TTS output too small: %d bytes", info.Size()) + GinkgoWriter.Printf("TTS: wrote %s (%d bytes)\n", dst, info.Size()) + }) + + It("streams PCM via TTSStream", func() { + if !caps[capTTS] { + Skip("tts capability not enabled") + } + text := os.Getenv("BACKEND_TEST_TTS_TEXT") + if text == "" { + text = defaultTTSText + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + stream, err := client.TTSStream(ctx, &pb.TTSRequest{Text: text}) + Expect(err).NotTo(HaveOccurred()) + + var chunks, totalBytes int + for { + reply, err := stream.Recv() + if err == io.EOF { + break + } + Expect(err).NotTo(HaveOccurred()) + if audio := reply.GetAudio(); len(audio) > 0 { + chunks++ + totalBytes += len(audio) + } + } + // Header + at least one PCM 
chunk proves real streaming (not emit-once). + Expect(chunks).To(BeNumerically(">=", 2), + "expected >=2 chunks (header + PCM), got %d (bytes=%d)", chunks, totalBytes) + Expect(totalBytes).To(BeNumerically(">", 1024), + "streamed audio too short: %d bytes", totalBytes) + GinkgoWriter.Printf("TTSStream: %d chunks, %d bytes\n", chunks, totalBytes) + }) }) // extractImage runs `docker create` + `docker export` to materialise the image @@ -819,9 +921,17 @@ func extractImage(image, dest string) { // downloadFile fetches url into dest using curl -L. Used for CI convenience; // local runs can use BACKEND_TEST_MODEL_FILE to skip downloading. +// Retry flags guard against transient CI network hiccups (github.com in +// particular has been flaky from GHA runners, timing out TCP connects). func downloadFile(url, dest string) { GinkgoHelper() - cmd := exec.Command("curl", "-sSfL", "-o", dest, url) + cmd := exec.Command("curl", "-sSfL", + "--connect-timeout", "30", + "--max-time", "600", + "--retry", "5", + "--retry-delay", "5", + "--retry-all-errors", + "-o", dest, url) cmd.Stdout = GinkgoWriter cmd.Stderr = GinkgoWriter Expect(cmd.Run()).To(Succeed(), "failed to download %s", url) diff --git a/tests/e2e/e2e_suite_test.go b/tests/e2e/e2e_suite_test.go index 65af629e0477..f6cef3fdfb36 100644 --- a/tests/e2e/e2e_suite_test.go +++ b/tests/e2e/e2e_suite_test.go @@ -212,6 +212,9 @@ var _ = BeforeSuite(func() { // Import model configs from an external directory (e.g. real model YAMLs // and weights mounted into a container). Symlinks avoid copying large files. + // Both files and directories are symlinked — multi-file backends like + // sherpa-onnx TTS expect their tokens.txt / lexicon.txt sidecars in the + // same directory as the .onnx, so we need whole-directory imports. 
if rtModels := os.Getenv("REALTIME_MODELS_PATH"); rtModels != "" { entries, err := os.ReadDir(rtModels) Expect(err).ToNot(HaveOccurred()) @@ -221,9 +224,6 @@ var _ = BeforeSuite(func() { if _, err := os.Stat(dst); err == nil { continue // don't overwrite mock configs } - if entry.IsDir() { - continue - } Expect(os.Symlink(src, dst)).To(Succeed()) } } diff --git a/tests/e2e/realtime_ws_test.go b/tests/e2e/realtime_ws_test.go index c69186f3345e..6aeb3fcd4ab6 100644 --- a/tests/e2e/realtime_ws_test.go +++ b/tests/e2e/realtime_ws_test.go @@ -1,15 +1,21 @@ package e2e_test import ( + "bytes" "encoding/base64" "encoding/json" "fmt" + "io" "math" + "net/http" "net/url" "os" + "strings" "time" "github.com/gorilla/websocket" + laudio "github.com/mudler/LocalAI/pkg/audio" + "github.com/mudler/LocalAI/pkg/sound" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -72,6 +78,66 @@ func generatePCMBase64(freq float64, sampleRate, durationMs int) string { return base64.StdEncoding.EncodeToString(pcm) } +// padPCMBase64 prepends and appends the given milliseconds of silence to a +// base64-encoded 16-bit LE PCM buffer. Used to give VAD a clear lead-in / +// lead-out so Silero can reliably detect utterance boundaries. +func padPCMBase64(pcmB64 string, sampleRate, leadingMs, trailingMs int) string { + raw, err := base64.StdEncoding.DecodeString(pcmB64) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + lead := make([]byte, sampleRate*leadingMs/1000*2) + trail := make([]byte, sampleRate*trailingMs/1000*2) + padded := make([]byte, 0, len(lead)+len(raw)+len(trail)) + padded = append(padded, lead...) + padded = append(padded, raw...) + padded = append(padded, trail...) + return base64.StdEncoding.EncodeToString(padded) +} + +// ttsPCMBase64 drives the /v1/audio/speech endpoint to render `text` through +// the given TTS model, strips the returned WAV header, resamples to the +// requested sample rate if needed, and returns base64-encoded 16-bit LE PCM. 
+// Fails the test on any transport / format error — there's no useful fallback. +func ttsPCMBase64(model, text string, targetSampleRate int) string { + body, err := json.Marshal(map[string]any{ + "model": model, + "input": text, + "format": "wav", + }) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + + req, err := http.NewRequest(http.MethodPost, apiURL+"/audio/speech", bytes.NewReader(body)) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + defer resp.Body.Close() + + wav, err := io.ReadAll(resp.Body) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, resp.StatusCode).To(Equal(http.StatusOK), + "tts returned %d: %s", resp.StatusCode, string(wav)) + + pcm, srcRate := laudio.ParseWAV(wav) + ExpectWithOffset(1, srcRate).To(BeNumerically(">", 0), + "tts response is not a valid WAV (body=%d bytes)", len(wav)) + + if srcRate != targetSampleRate { + samples := sound.BytesToInt16sLE(pcm) + pcm = sound.Int16toBytesLE(sound.ResampleInt16(samples, srcRate, targetSampleRate)) + } + return base64.StdEncoding.EncodeToString(pcm) +} + +// isRealTTS returns true when REALTIME_TTS names a real backend-backed model, +// as opposed to the default mock-tts. Used to gate test behavior that only +// makes sense with a real TTS — e.g. driving the session with a real +// utterance and asserting the transcription contains recognisable words. +func isRealTTS() bool { + m := os.Getenv("REALTIME_TTS") + return m != "" && m != "mock-tts" +} + // pipelineModel returns the model name to use for realtime tests. 
func pipelineModel() string { if m := os.Getenv("REALTIME_TEST_MODEL"); m != "" { @@ -139,8 +205,19 @@ var _ = Describe("Realtime WebSocket API", Label("Realtime"), func() { sendClientEvent(conn, disableVADEvent()) drainUntil(conn, "session.updated", 10*time.Second) - // Append 1 second of 440Hz sine wave at 24kHz (the default remote sample rate) - audio := generatePCMBase64(440, 24000, 1000) + // Real TTS: synthesise an utterance the ASR should be able to + // recognise, and pad with silence so Silero-VAD has a clear + // lead-in/out. Fallback: 1s of 440Hz sine wave — the mock + // transcriber returns a static string anyway, so this only + // needs to exercise the pipeline plumbing. + const inputText = "The quick brown fox jumps over the lazy dog." + var audio string + if isRealTTS() { + audio = ttsPCMBase64(os.Getenv("REALTIME_TTS"), inputText, 24000) + audio = padPCMBase64(audio, 24000, 500, 500) + } else { + audio = generatePCMBase64(440, 24000, 1000) + } sendClientEvent(conn, map[string]any{ "type": "input_audio_buffer.append", "audio": audio, @@ -161,9 +238,30 @@ var _ = Describe("Realtime WebSocket API", Label("Realtime"), func() { committed := drainUntil(conn, "input_audio_buffer.committed", 30*time.Second) Expect(committed).ToNot(BeNil()) - // Wait for the full response cycle to complete - done := drainUntil(conn, "response.done", 60*time.Second) - Expect(done).ToNot(BeNil()) + // Drain the response cycle, capturing the input transcription + // event as it arrives so we can sanity-check it alongside the + // real-TTS path. 
+ var transcript string + deadline := time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { + evt := readServerEvent(conn, time.Until(deadline)) + if evt["type"] == "conversation.item.input_audio_transcription.completed" { + if t, ok := evt["transcript"].(string); ok { + transcript = t + } + } + if evt["type"] == "response.done" { + Expect(evt).ToNot(BeNil()) + break + } + } + + if isRealTTS() { + lower := strings.ToLower(transcript) + matched := strings.Contains(lower, "fox") || strings.Contains(lower, "dog") + Expect(matched).To(BeTrue(), + "expected real-TTS transcript to contain 'fox' or 'dog' (got %q)", transcript) + } }) }) diff --git a/tests/e2e/run-realtime-sherpa.sh b/tests/e2e/run-realtime-sherpa.sh new file mode 100755 index 000000000000..1c1649c193a9 --- /dev/null +++ b/tests/e2e/run-realtime-sherpa.sh @@ -0,0 +1,136 @@ +#!/bin/bash +# Drives tests/e2e/realtime_ws_test.go against a realtime pipeline where +# VAD, STT and TTS are served by the sherpa-onnx Docker backend, and the +# LLM slot stays mocked by the in-repo mock-backend. Pre-requisites: +# - `make build-mock-backend` has produced tests/e2e/mock-backend/mock-backend +# - `make docker-build-sherpa-onnx` has produced local-ai-backend:sherpa-onnx +# - `make protogen-go` is up-to-date +# Environment overrides: +# WORK_DIR Where to stage the extracted backend + model files. +# Defaults to a mktemp'd directory (cleaned on exit). +# KEEP_WORK Non-empty to preserve WORK_DIR after the test exits (useful for +# debugging repeated runs — skips redownloads if files already present). +set -euo pipefail + +ROOT=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." 
&& pwd) +IMAGE=${BACKEND_IMAGE:-local-ai-backend:sherpa-onnx} +MODEL=${REALTIME_STT_MODEL:-omnilingual-0.3b-ctc-q8-sherpa} +VAD_MODEL=${REALTIME_VAD_MODEL:-silero-vad-sherpa} +TTS_MODEL=${REALTIME_TTS_MODEL:-vits-ljs-sherpa} + +WORK_DIR=${WORK_DIR:-$(mktemp -d -t localai-sherpa-realtime.XXXXXX)} +if [[ -z "${KEEP_WORK:-}" ]]; then + trap 'rm -rf "$WORK_DIR"' EXIT +fi +echo "WORK_DIR=$WORK_DIR" + +BACKENDS_DIR="$WORK_DIR/backends" +MODELS_DIR="$WORK_DIR/models" +mkdir -p "$BACKENDS_DIR/sherpa-onnx" "$MODELS_DIR" + +# 1. Extract the sherpa-onnx backend image rootfs. Mirrors the pattern in +# tests/e2e-backends/backend_test.go:extractImage — docker create + export +# avoids having to pull and parse layer tarballs. +if [[ ! -x "$BACKENDS_DIR/sherpa-onnx/run.sh" ]]; then + echo "Extracting $IMAGE rootfs into $BACKENDS_DIR/sherpa-onnx ..." + CID=$(docker create --entrypoint=/run.sh "$IMAGE") + trap 'docker rm -f "$CID" >/dev/null 2>&1 || true; [[ -z "${KEEP_WORK:-}" ]] && rm -rf "$WORK_DIR"' EXIT + docker export "$CID" | tar -xC "$BACKENDS_DIR/sherpa-onnx" \ + --exclude='dev/*' --exclude='proc/*' --exclude='sys/*' + docker rm -f "$CID" >/dev/null +fi + +# Make sure run.sh is executable (tar usually preserves this, but belt + braces). +chmod +x "$BACKENDS_DIR/sherpa-onnx/run.sh" \ + "$BACKENDS_DIR/sherpa-onnx/sherpa-onnx" 2>/dev/null || true + +# 2. Download model files. URLs + sha256s match gallery/index.yaml entries. 
+download() { + local dst="$1" url="$2" sha="$3" + if [[ -f "$dst" ]]; then + actual=$(sha256sum "$dst" | awk '{print $1}') + if [[ "$actual" == "$sha" ]]; then + echo "cached: $dst" + return + fi + fi + mkdir -p "$(dirname "$dst")" + echo "downloading: $url -> $dst" + curl -sSfL "$url" -o "$dst" + actual=$(sha256sum "$dst" | awk '{print $1}') + if [[ "$actual" != "$sha" ]]; then + echo "sha256 mismatch for $dst: got $actual, expected $sha" >&2 + exit 1 + fi +} + +# Silero VAD (single file) +download "$MODELS_DIR/silero-vad/silero-vad.onnx" \ + "https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx" \ + "a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808" + +# Omnilingual ASR (model + tokens) +download "$MODELS_DIR/omnilingual-asr/model.int8.onnx" \ + "https://huggingface.co/csukuangfj/sherpa-onnx-omnilingual-asr-1600-languages-300M-ctc-int8-2025-11-12/resolve/main/model.int8.onnx" \ + "e7c4e54ee4c4c47829cc6667d5d00ed8ea7bef1dcfeef0fce766f77752a2726c" +download "$MODELS_DIR/omnilingual-asr/tokens.txt" \ + "https://huggingface.co/csukuangfj/sherpa-onnx-omnilingual-asr-1600-languages-300M-ctc-int8-2025-11-12/resolve/main/tokens.txt" \ + "a7a044c52cb29cbe8b0dc1953e92cefd4ca16b0ed968177b6beab21f9a7d0b31" + +# VITS-LJS TTS (model + tokens + lexicon) +download "$MODELS_DIR/vits-ljs/vits-ljs.onnx" \ + "https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx" \ + "5bbd273797a9ecf8d94bd6ec02ad16cb41cbb85f055ad98d528ced3e44c9b31a" +download "$MODELS_DIR/vits-ljs/tokens.txt" \ + "https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt" \ + "5fee2c6b238d712287f2ecb08f34a8a8b413bcb7390862ef6fb6fd6f0f8d3a17" +download "$MODELS_DIR/vits-ljs/lexicon.txt" \ + "https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt" \ + "bdccfc6da71c45c48e2e0056fcf0aab760577c5f959f6c1b5eb3e3e916fd5a0e" + +# 3. Write model config YAMLs matching the gallery entries' shape. 
These are +# what the realtime pipeline resolves via LoadModelConfigFileByName. +cat > "$MODELS_DIR/$VAD_MODEL.yaml" < "$MODELS_DIR/$MODEL.yaml" < "$MODELS_DIR/$TTS_MODEL.yaml" <