Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion api/extensions/ext_socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@

from configs import dify_config

SOCKETIO_COLLABORATION_CHANNEL = "socketio:collaboration"


def create_client_manager():
if not dify_config.ENABLE_COLLABORATION_MODE:
return None

return socketio.RedisManager(
dify_config.normalized_pubsub_redis_url,
channel=SOCKETIO_COLLABORATION_CHANNEL,
)


def create_socketio_server():
return socketio.Server(
async_mode="gevent",
client_manager=create_client_manager(),
cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS,
)


# TODO: FIXME(chariri) - Casting to any because app_factory attaches the
# current app as the `app` attribute on this - Bad.
sio = cast(Any, socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS))
sio = cast(Any, create_socketio_server())
58 changes: 58 additions & 0 deletions api/tests/unit_tests/extensions/test_socketio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from types import SimpleNamespace
from unittest.mock import Mock

import pytest

from extensions import ext_socketio


def test_create_client_manager_uses_in_memory_transport_when_collaboration_disabled(
monkeypatch: pytest.MonkeyPatch,
):
config = SimpleNamespace(ENABLE_COLLABORATION_MODE=False)
redis_manager = Mock()
monkeypatch.setattr(ext_socketio, "dify_config", config)
monkeypatch.setattr(ext_socketio.socketio, "RedisManager", redis_manager)

manager = ext_socketio.create_client_manager()

assert manager is None
redis_manager.assert_not_called()


def test_create_client_manager_uses_redis_when_collaboration_enabled(
monkeypatch: pytest.MonkeyPatch,
):
config = SimpleNamespace(
ENABLE_COLLABORATION_MODE=True,
normalized_pubsub_redis_url="redis://redis:6379/1",
)
redis_manager = Mock(return_value=object())
monkeypatch.setattr(ext_socketio, "dify_config", config)
monkeypatch.setattr(ext_socketio.socketio, "RedisManager", redis_manager)

manager = ext_socketio.create_client_manager()

assert manager is redis_manager.return_value
redis_manager.assert_called_once_with(
"redis://redis:6379/1",
channel=ext_socketio.SOCKETIO_COLLABORATION_CHANNEL,
)


def test_create_socketio_server_passes_client_manager(monkeypatch: pytest.MonkeyPatch):
config = SimpleNamespace(CONSOLE_CORS_ALLOW_ORIGINS=["https://example.com"])
client_manager = object()
server = Mock(return_value=object())
monkeypatch.setattr(ext_socketio, "dify_config", config)
monkeypatch.setattr(ext_socketio, "create_client_manager", Mock(return_value=client_manager))
monkeypatch.setattr(ext_socketio.socketio, "Server", server)

result = ext_socketio.create_socketio_server()

assert result is server.return_value
server.assert_called_once_with(
async_mode="gevent",
client_manager=client_manager,
cors_allowed_origins=["https://example.com"],
)