diff --git a/api/extensions/ext_socketio.py b/api/extensions/ext_socketio.py index 2fe2369e9f86f6..62e9eb7b5e328f 100644 --- a/api/extensions/ext_socketio.py +++ b/api/extensions/ext_socketio.py @@ -4,6 +4,27 @@ from configs import dify_config +SOCKETIO_COLLABORATION_CHANNEL = "socketio:collaboration" + + +def create_client_manager(): + if not dify_config.ENABLE_COLLABORATION_MODE: + return None + + return socketio.RedisManager( + dify_config.normalized_pubsub_redis_url, + channel=SOCKETIO_COLLABORATION_CHANNEL, + ) + + +def create_socketio_server(): + return socketio.Server( + async_mode="gevent", + client_manager=create_client_manager(), + cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS, + ) + + # TODO: FIXME(chariri) - Casting to any because app_factory attaches the # current app as the `app` attribute on this - Bad. -sio = cast(Any, socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)) +sio = cast(Any, create_socketio_server()) diff --git a/api/tests/unit_tests/extensions/test_socketio.py b/api/tests/unit_tests/extensions/test_socketio.py new file mode 100644 index 00000000000000..85710fd7ad974a --- /dev/null +++ b/api/tests/unit_tests/extensions/test_socketio.py @@ -0,0 +1,58 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from extensions import ext_socketio + + +def test_create_client_manager_uses_in_memory_transport_when_collaboration_disabled( + monkeypatch: pytest.MonkeyPatch, +): + config = SimpleNamespace(ENABLE_COLLABORATION_MODE=False) + redis_manager = Mock() + monkeypatch.setattr(ext_socketio, "dify_config", config) + monkeypatch.setattr(ext_socketio.socketio, "RedisManager", redis_manager) + + manager = ext_socketio.create_client_manager() + + assert manager is None + redis_manager.assert_not_called() + + +def test_create_client_manager_uses_redis_when_collaboration_enabled( + monkeypatch: pytest.MonkeyPatch, +): + config = SimpleNamespace( + ENABLE_COLLABORATION_MODE=True, + normalized_pubsub_redis_url="redis://redis:6379/1", + ) + redis_manager = Mock(return_value=object()) + monkeypatch.setattr(ext_socketio, "dify_config", config) + monkeypatch.setattr(ext_socketio.socketio, "RedisManager", redis_manager) + + manager = ext_socketio.create_client_manager() + + assert manager is redis_manager.return_value + redis_manager.assert_called_once_with( + "redis://redis:6379/1", + channel=ext_socketio.SOCKETIO_COLLABORATION_CHANNEL, + ) + + +def test_create_socketio_server_passes_client_manager(monkeypatch: pytest.MonkeyPatch): + config = SimpleNamespace(CONSOLE_CORS_ALLOW_ORIGINS=["https://example.com"]) + client_manager = object() + server = Mock(return_value=object()) + monkeypatch.setattr(ext_socketio, "dify_config", config) + monkeypatch.setattr(ext_socketio, "create_client_manager", Mock(return_value=client_manager)) + monkeypatch.setattr(ext_socketio.socketio, "Server", server) + + result = ext_socketio.create_socketio_server() + + assert result is server.return_value + server.assert_called_once_with( + async_mode="gevent", + client_manager=client_manager, + cors_allowed_origins=["https://example.com"], + )