Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ EVA_MODEL__TTS_PARAMS='{"api_key": "your_cartesia_api_key", "model": "sonic"}'
# --- Framework (S2S / AudioLLM) ---
#i Base framework for S2S or AudioLLM pipelines.
#d enum
#e pipecat,openai_realtime,gemini_live,elevenlabs
#e pipecat,openai_realtime,gemini_live,elevenlabs,grok_voice
#v EVA_FRAMEWORK=openai_realtime

# ==============================================
Expand Down
82 changes: 82 additions & 0 deletions src/eva/assistant/grok_voice_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Grok Voice realtime assistant server (xAI).

xAI's voice realtime API is event-compatible with OpenAI's Realtime API
(per https://docs.x.ai/developers/model-capabilities/audio/voice-agent#openai-realtime-api-compatibility),
so this server subclasses `OpenAIRealtimeAssistantServer` and overrides
only the hooks that differ:

* `_create_client` — point AsyncOpenAI at api.x.ai/v1
* `_default_voice` — xAI's built-in voices are `eve`/`ara`/`rex`/`sal`/`leo`
* `_build_session_config` — xAI doesn't accept the `transcription.model` selector
* `_on_transcription_completed` — xAI sends incremental completed events; only flush the final one

The shared audio bridge, event loop, tool round-trip, audit logging,
and latency metrics in `OpenAIRealtimeAssistantServer` are reused as-is.
"""

from typing import Any

from openai import AsyncOpenAI

from eva.assistant.openai_realtime_server import OpenAIRealtimeAssistantServer
from eva.utils.logging import get_logger

logger = get_logger(__name__)

XAI_REALTIME_BASE_URL = "https://api.x.ai/v1"


class GrokVoiceAssistantServer(OpenAIRealtimeAssistantServer):
"""Assistant server backed by xAI's Grok voice realtime API."""

_service_name: str = "Grok Voice"
_metrics_processor_name: str = "grok_voice"

def _create_client(self) -> AsyncOpenAI:
api_key = self.pipeline_config.s2s_params.get("api_key")
if not api_key:
raise ValueError(f"API key required for {self._service_name}")
return AsyncOpenAI(
api_key=api_key, base_url=self.pipeline_config.s2s_params.get("base_url", XAI_REALTIME_BASE_URL)
)

def _default_voice(self) -> str:
return "eve"

# ── Deferred transcription (xAI sends incremental completed events) ──

def _flush_pending_user_transcript(self) -> None:
"""Write the buffered user transcript to the audit log if pending."""
if self._user_turn and self._user_turn.transcript and not self._user_turn.flushed:
timestamp_ms = self._user_turn.speech_started_wall_ms or None
self.audit_log.append_user_input(self._user_turn.transcript, timestamp_ms=timestamp_ms)
self._user_turn.flushed = True
logger.debug(f"Flushed deferred user transcript: {self._user_turn.transcript[:60]}...")

async def _on_transcription_completed(self, event: Any) -> None:
"""Buffer transcription instead of writing immediately.

xAI fires ``conversation.item.input_audio_transcription.completed``
multiple times per turn with progressively longer text. We store
each update but defer the audit-log write until the turn is done
(see ``_on_speech_started`` / ``_on_response_done``).
"""
transcript = getattr(event, "transcript", "") or ""
transcript = transcript.strip()
if not transcript:
return

if self._user_turn:
self._user_turn.transcript = transcript
# Do NOT set flushed or write to audit_log yet
logger.debug(f"Buffered user transcription: {transcript[:60]}...")

async def _on_speech_started(self, event: Any) -> None:
"""Flush any pending transcript before starting a new turn."""
self._flush_pending_user_transcript()
await super()._on_speech_started(event)

async def _on_response_done(self, event: Any) -> None:
"""Flush any pending transcript before recording assistant output."""
self._flush_pending_user_transcript()
await super()._on_response_done(event)
127 changes: 75 additions & 52 deletions src/eva/assistant/openai_realtime_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class OpenAIRealtimeAssistantServer(AbstractAssistantServer):
(24 kHz PCM16 base64).
"""

_service_name: str = "OpenAI Realtime"
_metrics_processor_name: str = "openai_realtime"

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

Expand Down Expand Up @@ -142,7 +145,7 @@ async def websocket_root(websocket: WebSocket):
while not self._server.started:
await asyncio.sleep(0.01)

logger.info(f"OpenAI Realtime server started on ws://localhost:{self.port}")
logger.info(f"{self._service_name} server started on ws://localhost:{self.port}")

async def _shutdown(self) -> None:
"""Stop the OpenAI Realtime server."""
Expand All @@ -167,7 +170,66 @@ async def _shutdown(self) -> None:
self._server = None
self._server_task = None

logger.info(f"OpenAI Realtime server stopped on port {self.port}")
logger.info(f"{self._service_name} server stopped on port {self.port}")

def _default_voice(self) -> str:
"""Default voice ID when `s2s_params.voice` is not set.

Subclasses override when the underlying service uses a different
voice catalogue.
"""
return "marin"

def _build_session_config(self) -> dict[str, Any]:
"""Construct the `session.update` payload for the realtime connection.

Subclasses override to adjust service-specific fields (e.g. drop the
`transcription.model` selector for xAI, which doesn't expose one).
"""
s2s = self.pipeline_config.s2s_params or {}
vad = s2s.get("vad_settings", {}) or {}

session_config: dict[str, Any] = {
"type": "realtime",
"output_modalities": ["audio"],
"instructions": self._system_prompt,
"audio": {
"output": {
"voice": s2s.get("voice", self._default_voice()),
"format": {"type": "audio/pcm", "rate": 24000},
},
"input": {
"format": {"type": "audio/pcm", "rate": 24000},
"turn_detection": {
"type": vad.get("type", "server_vad"),
"threshold": vad.get("threshold", 0.5),
"prefix_padding_ms": vad.get("prefix_padding_ms", 300),
"silence_duration_ms": vad.get("silence_duration_ms", 200),
},
"transcription": {
"model": s2s.get("transcription_model", "whisper-1"),
},
},
},
"tools": self._realtime_tools,
}

reasoning_effort = s2s.get("reasoning_effort")
if reasoning_effort:
session_config["reasoning"] = {"effort": reasoning_effort}

return session_config

def _create_client(self) -> AsyncOpenAI:
"""Construct the AsyncOpenAI client used for the realtime connection.

Subclasses override to point at a different base_url (e.g. xAI's
realtime endpoint, which is OpenAI-Realtime-API-compatible).
"""
api_key = self.pipeline_config.s2s_params.get("api_key")
if not api_key:
raise ValueError(f"API key required for {self._service_name}")
return AsyncOpenAI(api_key=api_key)

async def _handle_session(self, websocket: WebSocket) -> None:
"""Handle a single WebSocket session.
Expand All @@ -181,7 +243,7 @@ async def _handle_session(self, websocket: WebSocket) -> None:
5. On tool call: execute via self.tool_handler, send result back
6. On audio: decode base64 PCM16 -> record -> encode mulaw -> send to Twilio WS
"""
logger.info("Client connected to OpenAI Realtime server")
logger.info(f"Client connected to {self._service_name} server")

# Reset per-session state
self._user_turn = None
Expand All @@ -190,52 +252,13 @@ async def _handle_session(self, websocket: WebSocket) -> None:
self._user_speaking = False
self._bot_speaking = False

api_key = self.pipeline_config.s2s_params.get("api_key")
if not api_key:
raise ValueError("API key required for openai realtime")
client = AsyncOpenAI(api_key=api_key)
client = self._create_client()

try:
logger.info(f"Starting OpenAI Realtime session (model={self._model})")
logger.info(f"Starting {self._service_name} session (model={self._model})")
async with client.realtime.connect(model=self._model) as conn:
# Configure the session
session_config: dict[str, Any] = {
"type": "realtime",
"output_modalities": ["audio"],
"instructions": self._system_prompt,
"audio": {
"output": {
"voice": self.pipeline_config.s2s_params.get("voice", "marin"),
"format": {"type": "audio/pcm", "rate": 24000},
},
"input": {
"format": {"type": "audio/pcm", "rate": 24000},
"turn_detection": {
"type": self.pipeline_config.s2s_params.get("vad_settings", {}).get(
"type", "server_vad"
),
"threshold": self.pipeline_config.s2s_params.get("vad_settings", {}).get(
"threshold", 0.5
),
"prefix_padding_ms": self.pipeline_config.s2s_params.get("vad_settings", {}).get(
"prefix_padding_ms", 300
),
"silence_duration_ms": self.pipeline_config.s2s_params.get("vad_settings", {}).get(
"silence_duration_ms", 200
),
},
"transcription": {
"model": self.pipeline_config.s2s_params.get("transcription_model", "whisper-1")
},
},
},
"tools": self._realtime_tools,
}

reasoning_effort = self.pipeline_config.s2s_params.get("reasoning_effort")
if reasoning_effort:
session_config["reasoning"] = {"effort": reasoning_effort}

session_config = self._build_session_config()
await conn.session.update(session=session_config)

# Trigger the initial greeting
Expand Down Expand Up @@ -276,9 +299,9 @@ async def _handle_session(self, websocket: WebSocket) -> None:
logger.error(f"Session task failed: {task.exception()}")

except Exception as e:
logger.error(f"OpenAI Realtime session error: {e}", exc_info=True)
logger.error(f"{self._service_name} session error: {e}", exc_info=True)
finally:
logger.info("Client disconnected from OpenAI Realtime server")
logger.info(f"Client disconnected from {self._service_name} server")

# ── Audio output pacer (OpenAI -> Twilio WS at real-time rate) ───

Expand Down Expand Up @@ -404,10 +427,10 @@ async def _handle_openai_event(

match event_type:
case "session.created":
logger.info("OpenAI Realtime session created")
logger.info(f"{self._service_name} session created")

case "session.updated":
logger.debug("OpenAI Realtime session updated")
logger.debug(f"{self._service_name} session updated")

case "input_audio_buffer.speech_started":
await self._on_speech_started(event)
Expand Down Expand Up @@ -452,10 +475,10 @@ async def _handle_openai_event(

case "error":
error_data = getattr(event, "error", None)
logger.error(f"OpenAI Realtime error: {error_data}")
logger.error(f"{self._service_name} error: {error_data}")

case _:
logger.debug(f"Unhandled OpenAI event: {event_type}")
logger.debug(f"Unhandled {self._service_name} event: {event_type}")

# ── Event handlers ────────────────────────────────────────────────

Expand Down Expand Up @@ -660,7 +683,7 @@ async def _on_response_done(self, event: Any) -> None:
input_tokens = getattr(usage, "input_tokens", 0) or 0
output_tokens = getattr(usage, "output_tokens", 0) or 0
self._metrics_log.write_token_usage(
processor="openai_realtime",
processor=self._metrics_processor_name,
model=self._model,
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
Expand Down
3 changes: 2 additions & 1 deletion src/eva/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,15 @@ class ModelDeployment(DeploymentTypedDict):
)

# Framework selection
framework: Literal["pipecat", "openai_realtime", "gemini_live", "elevenlabs"] = Field(
framework: Literal["pipecat", "openai_realtime", "gemini_live", "elevenlabs", "grok_voice"] = Field(
"pipecat",
description=(
"Agent framework to use for the assistant server."
"'pipecat' (default): Pipecat pipeline."
"'openai_realtime': OpenAI Realtime API directly."
"'gemini_live': Gemini Live API via google-genai."
"'elevenlabs': ElevenLabs Conversational AI API."
"'grok_voice': xAI Grok voice realtime API."
),
)

Expand Down
7 changes: 6 additions & 1 deletion src/eva/orchestrator/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,14 @@ def _get_server_class(framework: str) -> type[AbstractAssistantServer]:
from eva.assistant.elevenlabs_server import ElevenLabsAssistantServer

return ElevenLabsAssistantServer
elif framework == "grok_voice":
from eva.assistant.grok_voice_server import GrokVoiceAssistantServer

return GrokVoiceAssistantServer
else:
raise ValueError(
f"Unknown framework: {framework!r}. Supported: pipecat, openai_realtime, gemini_live, elevenlabs"
f"Unknown framework: {framework!r}. "
"Supported: pipecat, openai_realtime, gemini_live, elevenlabs, grok_voice"
)


Expand Down
71 changes: 71 additions & 0 deletions tests/unit/assistant/test_grok_voice_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Tests for GrokVoiceAssistantServer hook overrides."""

from unittest.mock import MagicMock

from openai import AsyncOpenAI

from eva.assistant.grok_voice_server import GrokVoiceAssistantServer


def _bare_server() -> GrokVoiceAssistantServer:
srv = object.__new__(GrokVoiceAssistantServer)
srv.pipeline_config = MagicMock()
srv.pipeline_config.s2s_params = {
"api_key": "xai-test-key",
"model": "grok-voice-latest",
}
srv._model = "grok-voice-latest"
srv._system_prompt = "you are a helpful assistant"
srv._realtime_tools = []
return srv


class TestCreateClient:
def test_uses_xai_base_url(self):
srv = _bare_server()
client = srv._create_client()
assert isinstance(client, AsyncOpenAI)
assert client.api_key == "xai-test-key"
assert "api.x.ai" in str(client.base_url)

def test_raises_when_api_key_missing(self):
srv = _bare_server()
srv.pipeline_config.s2s_params = {}
try:
srv._create_client()
except ValueError as e:
assert "API key required" in str(e)
assert "Grok Voice" in str(e)
else:
raise AssertionError("expected ValueError")


class TestDefaultVoice:
def test_default_voice_is_eve(self):
srv = _bare_server()
assert srv._default_voice() == "eve"


class TestBuildSessionConfig:
def test_voice_defaults_to_eve(self):
srv = _bare_server()
cfg = srv._build_session_config()
assert cfg["audio"]["output"]["voice"] == "eve"

def test_explicit_voice_passes_through(self):
srv = _bare_server()
srv.pipeline_config.s2s_params = {
"api_key": "xai-test-key",
"model": "grok-voice-latest",
"voice": "rex",
}
cfg = srv._build_session_config()
assert cfg["audio"]["output"]["voice"] == "rex"


class TestServiceLabels:
def test_service_name(self):
assert GrokVoiceAssistantServer._service_name == "Grok Voice"

def test_metrics_processor_name(self):
assert GrokVoiceAssistantServer._metrics_processor_name == "grok_voice"
Loading
Loading