Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/rime-language-param.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@livekit/agents-plugin-rime': patch
---

Send Rime TTS language selection with the documented `language` API parameter while preserving `lang` as a legacy option.
152 changes: 150 additions & 2 deletions plugins/rime/src/tts.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,62 @@
import { initializeLogger } from '@livekit/agents';
import { STT } from '@livekit/agents-plugin-openai';
import { tts } from '@livekit/agents-plugins-test';
import { afterEach, describe, expect, it, vi } from 'vitest';
import { TTS } from './tts.js';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { TTS, type TTSOptions } from './tts.js';

const { MockWebSocket } = vi.hoisted(() => {
class MockWebSocket {
static OPEN = 1;
static CLOSED = 3;
static instances: MockWebSocket[] = [];

readyState = 0;
readonly sent: unknown[] = [];
readonly #listeners = new Map<string, Set<(...args: unknown[]) => void>>();

constructor(
readonly url: string,
readonly options: unknown,
) {
MockWebSocket.instances.push(this);
}

on(event: string, listener: (...args: unknown[]) => void) {
const listeners = this.#listeners.get(event) ?? new Set<(...args: unknown[]) => void>();
listeners.add(listener);
this.#listeners.set(event, listeners);
return this;
}

off(event: string, listener: (...args: unknown[]) => void) {
this.#listeners.get(event)?.delete(listener);
return this;
}

send(data: unknown): void {
this.sent.push(data);
}

close(): void {
this.readyState = MockWebSocket.CLOSED;
this.emit('close', 1000, Buffer.from(''));
}

terminate(): void {
this.readyState = MockWebSocket.CLOSED;
}

emit(event: string, ...args: unknown[]): void {
for (const listener of this.#listeners.get(event) ?? []) {
listener(...args);
}
}
}

return { MockWebSocket };
});

vi.mock('ws', () => ({ default: MockWebSocket, WebSocket: MockWebSocket }));

initializeLogger({ pretty: false, level: 'silent' });

Expand All @@ -30,6 +84,53 @@ async function withTimeout<T>(promise: Promise<T>, ms: number): Promise<T | 'tim
});
}

const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));

async function waitFor(predicate: () => boolean, timeoutMs = 1000): Promise<void> {
const start = performance.now();
while (!predicate()) {
if (performance.now() - start > timeoutMs) {
throw new Error('condition not met within timeout');
}
await sleep(5);
}
}

async function captureFetchPayload(opts: Partial<TTSOptions>): Promise<Record<string, unknown>> {
let payload: Record<string, unknown> | undefined;
vi.spyOn(globalThis, 'fetch').mockImplementation(async (_url, init) => {
payload = JSON.parse(String(init?.body));
return new Response(
new ReadableStream<Uint8Array>({
start(controller) {
controller.close();
},
}),
{
status: 200,
headers: { 'Content-Type': 'audio/pcm' },
},
);
});

const rimeTTS = new TTS({
apiKey: 'test-rime-key',
baseURL: 'https://rime.test/v1/rime-tts',
modelId: 'arcana',
speaker: 'luna',
...opts,
});

const result = await withTimeout(rimeTTS.synthesize('Hello from Rime.').next(), 1000);
expect(result).not.toBe('timeout');
expect(payload).toBeDefined();
return payload!;
}

beforeEach(() => {
MockWebSocket.instances.length = 0;
});

describe('Rime TTS streaming', () => {
afterEach(() => {
vi.restoreAllMocks();
Expand Down Expand Up @@ -83,6 +184,53 @@ describe('Rime TTS streaming', () => {
});
});

describe('Rime TTS language options', () => {
afterEach(() => {
vi.restoreAllMocks();
});

it('sends language instead of lang in one-shot payloads', async () => {
const payload = await captureFetchPayload({ language: 'eng', lang: 'spa' });

expect(payload.language).toBe('eng');
expect(payload).not.toHaveProperty('lang');
});

it('maps legacy lang to the Rime language API parameter', async () => {
const payload = await captureFetchPayload({ lang: 'spa' });

expect(payload.language).toBe('spa');
expect(payload).not.toHaveProperty('lang');
});

it('sends language instead of lang in WebSocket query parameters', async () => {
const rimeTTS = new TTS({
apiKey: 'test-rime-key',
baseURL: 'wss://rime.test',
useWebsocket: true,
modelId: 'arcana',
speaker: 'luna',
language: 'eng',
lang: 'spa',
});
const stream = rimeTTS.stream({
connOptions: { maxRetry: 0, retryIntervalMs: 0, timeoutMs: 1000 },
});

try {
await waitFor(() => MockWebSocket.instances.length > 0);
const socket = MockWebSocket.instances[0]!;
const url = new URL(socket.url);

expect(url.searchParams.get('language')).toBe('eng');
expect(url.searchParams.has('lang')).toBe(false);
} finally {
stream.close();
await withTimeout(stream.next(), 1000);
}
});
});

if (hasRimeConfig) {
describe('Rime TTS', async () => {
await tts(new TTS(), new STT(), { streaming: false });
Expand Down
7 changes: 6 additions & 1 deletion plugins/rime/src/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ export interface TTSOptions {
useWebsocket?: boolean;
segment?: string;
tokenizer?: tokenize.SentenceTokenizer;
/** Rime TTS language code. Prefer this over the deprecated `lang` option. */
language?: DefaultLanguages | string;
/** @deprecated Use `language` instead. */
lang?: DefaultLanguages | string;
repetition_penalty?: number;
temperature?: number;
Expand Down Expand Up @@ -88,7 +91,8 @@ const defaultTTSOptions: TTSOptions = {

function modelParams(opts: TTSOptions): Record<string, string | number | boolean> {
const params: Record<string, string | number | boolean> = {};
if (opts.lang !== undefined) params.lang = opts.lang;
const language = opts.language ?? opts.lang;
if (language !== undefined) params.language = language;

if (opts.modelId === 'arcana') {
if (opts.repetition_penalty !== undefined) params.repetition_penalty = opts.repetition_penalty;
Expand Down Expand Up @@ -139,6 +143,7 @@ function fetchPayload(opts: TTSOptions, text: string): Record<string, unknown> {
'tokenizer',
'speaker',
'modelId',
'language',
'lang',
'repetition_penalty',
'temperature',
Expand Down
Loading