diff --git a/include/mgclient.h b/include/mgclient.h index 9815631..b3c40ee 100644 --- a/include/mgclient.h +++ b/include/mgclient.h @@ -1255,6 +1255,8 @@ MGCLIENT_EXPORT void mg_session_params_set_host(mg_session_params *, const char *host); MGCLIENT_EXPORT void mg_session_params_set_port(mg_session_params *, uint16_t port); +MGCLIENT_EXPORT void mg_session_params_set_scheme(mg_session_params *, + const char *scheme); MGCLIENT_EXPORT void mg_session_params_set_username(mg_session_params *, const char *username); MGCLIENT_EXPORT void mg_session_params_set_password(mg_session_params *, diff --git a/mgclient_cpp/include/mgclient.hpp b/mgclient_cpp/include/mgclient.hpp index dd9d6fe..0724d9e 100644 --- a/mgclient_cpp/include/mgclient.hpp +++ b/mgclient_cpp/include/mgclient.hpp @@ -57,6 +57,7 @@ class Client { struct Params { std::string host = "127.0.0.1"; uint16_t port = 7687; + std::string scheme = "none"; std::string username = ""; std::string password = ""; bool use_ssl = false; @@ -148,13 +149,24 @@ inline std::unique_ptr Client::Connect(const Client::Params ¶ms) { if (!mg_params) { return nullptr; } - mg_session_params_set_host(mg_params, params.host.c_str()); - mg_session_params_set_port(mg_params, params.port); + if (!params.host.empty()) { + mg_session_params_set_host(mg_params, params.host.c_str()); + } + if (params.port != 0) { + mg_session_params_set_port(mg_params, params.port); + } + if (!params.scheme.empty()) { + mg_session_params_set_scheme(mg_params, params.scheme.c_str()); + } if (!params.username.empty()) { mg_session_params_set_username(mg_params, params.username.c_str()); + } + if (!params.password.empty()) { mg_session_params_set_password(mg_params, params.password.c_str()); } - mg_session_params_set_user_agent(mg_params, params.user_agent.c_str()); + if (!params.user_agent.empty()) { + mg_session_params_set_user_agent(mg_params, params.user_agent.c_str()); + } mg_session_params_set_sslmode( mg_params, params.use_ssl ? MG_SSLMODE_REQUIRE : MG_SSLMODE_DISABLE); diff --git a/src/mgclient.c b/src/mgclient.c index 58d8284..f655d20 100644 --- a/src/mgclient.c +++ b/src/mgclient.c @@ -67,6 +67,7 @@ typedef struct mg_session_params { const char *address; const char *host; uint16_t port; + const char *scheme; const char *username; const char *password; const char *user_agent; @@ -87,6 +88,7 @@ mg_session_params *mg_session_params_make(void) { params->address = NULL; params->host = NULL; params->port = 0; + params->scheme = NULL; params->username = NULL; params->password = NULL; params->user_agent = MG_USER_AGENT; @@ -118,6 +120,11 @@ void mg_session_params_set_port(mg_session_params *params, uint16_t port) { params->port = port; } +void mg_session_params_set_scheme(mg_session_params *params, + const char *scheme) { + params->scheme = scheme; +} + void mg_session_params_set_username(mg_session_params *params, const char *username) { params->username = username; @@ -364,8 +371,8 @@ int mg_bolt_init_v1(mg_session *session, const mg_session_params *params) { return status; } -static mg_map *build_hello_extra(const char *user_agent, const char *username, - const char *password) { +static mg_map *build_hello_extra(const char *user_agent, const char *scheme, + const char *username, const char *password) { mg_map *extra = mg_map_make_empty(4); if (!extra) { return NULL; @@ -379,40 +386,53 @@ static mg_map *build_hello_extra(const char *user_agent, const char *username, } } - assert((username && password) || (!username && !password)); - if (username) { - mg_value *scheme = mg_value_make_string("basic"); - if (!scheme || mg_map_insert_unsafe(extra, "scheme", scheme) != 0) { + // The "basic" scheme requires a username and a password/credential within the + // HELLO message. Other schemes (save for "kerberos", which is not supported + // by Memgraph) do not have such requirements: + // https://neo4j.com/docs/bolt/current/bolt/message/#messages-hello + // https://neo4j.com/docs/bolt/current/bolt/message/#messages-logon + // NOTE: HELLO message does NOT contain schema after Bolt 5.0. + if (scheme && strcmp(scheme, "basic") == 0) { + assert(username && password); + } + + if (!username && !password) { + mg_value *scheme_ = mg_value_make_string("none"); + if (!scheme_ || mg_map_insert_unsafe(extra, "scheme", scheme_) != 0) { goto cleanup; } + return extra; + } + + mg_value *scheme_ = mg_value_make_string(scheme ? scheme : "none"); // NOTE: Makes none default. + if (!scheme_ || mg_map_insert_unsafe(extra, "scheme", scheme_) != 0) { + goto cleanup; + } + if (username) { mg_value *principal = mg_value_make_string(username); if (!principal || mg_map_insert_unsafe(extra, "principal", principal)) { goto cleanup; } + } + if (password) { mg_value *credentials = mg_value_make_string(password); if (!credentials || mg_map_insert_unsafe(extra, "credentials", credentials)) { goto cleanup; } - } else { - mg_value *scheme = mg_value_make_string("none"); - if (!scheme || mg_map_insert_unsafe(extra, "scheme", scheme) != 0) { - goto cleanup; - } } return extra; - cleanup: mg_map_destroy(extra); return NULL; } int mg_bolt_init_v4(mg_session *session, const mg_session_params *params) { - mg_map *extra = - build_hello_extra(params->user_agent, params->username, params->password); + mg_map *extra = build_hello_extra(params->user_agent, params->scheme, + params->username, params->password); if (!extra) { return MG_ERROR_OOM; } diff --git a/tests/client.cpp b/tests/client.cpp index 59c3839..dad92e9 100644 --- a/tests/client.cpp +++ b/tests/client.cpp @@ -19,7 +19,6 @@ #include #include "mgclient.h" -#include "mgcommon.h" #include "mgsession.h" #include "mgsocket.h" @@ -508,9 +507,80 @@ TEST_F(ConnectTest, Success) { ASSERT_MEMORY_OK(); } -TEST_F(ConnectTest, Success_v4) { - RunServer([](int sockfd) { - // Perform handshake. +// Bolt v1 server that completes the handshake by selecting version 1, +// reads the client's INIT message verifying basic-auth credentials, and +// replies with SUCCESS. +auto make_v1_server_success_basic = [](int sockfd) { + { + char handshake[20]; + ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); + ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); + ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); + ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); + ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); + ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); + + uint32_t version = htobe32(1); + ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); + } + + mg_session *session = mg_session_init(&mg_system_allocator); + ASSERT_TRUE(session); + session->version = 1; + mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport, + &mg_system_allocator); + + { + mg_message *message; + ASSERT_EQ(mg_session_receive_message(session), 0); + ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0); + ASSERT_EQ(message->type, MG_MESSAGE_TYPE_INIT); + + mg_message_init *msg_init = message->init_v; + EXPECT_EQ( + std::string(msg_init->client_name->data, msg_init->client_name->size), + MG_USER_AGENT); + ASSERT_EQ(mg_map_size(msg_init->auth_token), 3u); + + const mg_value *scheme_val = mg_map_at(msg_init->auth_token, "scheme"); + ASSERT_TRUE(scheme_val); + ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING); + const mg_string *scheme = mg_value_string(scheme_val); + ASSERT_EQ(std::string(scheme->data, scheme->size), "basic"); + + const mg_value *principal_val = + mg_map_at(msg_init->auth_token, "principal"); + ASSERT_TRUE(principal_val); + ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING); + const mg_string *principal = mg_value_string(principal_val); + ASSERT_EQ(std::string(principal->data, principal->size), "user"); + + const mg_value *credentials_val = + mg_map_at(msg_init->auth_token, "credentials"); + ASSERT_TRUE(credentials_val); + ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING); + const mg_string *credentials = mg_value_string(credentials_val); + ASSERT_EQ(std::string(credentials->data, credentials->size), "pass"); + + mg_message_destroy_ca(message, session->decoder_allocator); + } + + ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0); + + mg_session_destroy(session); +}; + +// Bolt v4 server that completes the handshake by selecting version 4.1, +// reads the client's HELLO message, verifies the auth fields against the +// supplied expectations, and replies with SUCCESS. When expected_principal / +// expected_credentials are empty, the server expects an extra map of just +// {user_agent, scheme}; otherwise it expects all four fields. +auto make_v4_server_success(std::string expected_scheme, + std::string expected_principal = "", + std::string expected_credentials = "") { + return [expected_scheme = std::move(expected_scheme), + expected_principal = std::move(expected_principal), + expected_credentials = std::move(expected_credentials)](int sockfd) { { char handshake[20]; ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); @@ -530,7 +600,6 @@ TEST_F(ConnectTest, Success_v4) { mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport, &mg_system_allocator); - // Read HELLO message. { mg_message *message; ASSERT_EQ(mg_session_receive_message(session), 0); @@ -538,49 +607,55 @@ TEST_F(ConnectTest, Success_v4) { ASSERT_EQ(message->type, MG_MESSAGE_TYPE_HELLO); mg_message_hello *msg_hello = message->hello_v; - { - ASSERT_EQ(mg_map_size(msg_hello->extra), 4u); - - const mg_value *user_agent_val = - mg_map_at(msg_hello->extra, "user_agent"); - ASSERT_TRUE(user_agent_val); - ASSERT_EQ(mg_value_get_type(user_agent_val), MG_VALUE_TYPE_STRING); - const mg_string *user_agent = mg_value_string(user_agent_val); - ASSERT_EQ(std::string(user_agent->data, user_agent->size), - MG_USER_AGENT); - - const mg_value *scheme_val = mg_map_at(msg_hello->extra, "scheme"); - ASSERT_TRUE(scheme_val); - ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING); - const mg_string *scheme = mg_value_string(scheme_val); - ASSERT_EQ(std::string(scheme->data, scheme->size), "basic"); - + const bool with_creds = !expected_principal.empty(); + ASSERT_EQ(mg_map_size(msg_hello->extra), with_creds ? 4u : 2u); + + const mg_value *user_agent_val = + mg_map_at(msg_hello->extra, "user_agent"); + ASSERT_TRUE(user_agent_val); + ASSERT_EQ(mg_value_get_type(user_agent_val), MG_VALUE_TYPE_STRING); + const mg_string *user_agent = mg_value_string(user_agent_val); + ASSERT_EQ(std::string(user_agent->data, user_agent->size), MG_USER_AGENT); + + const mg_value *scheme_val = mg_map_at(msg_hello->extra, "scheme"); + ASSERT_TRUE(scheme_val); + ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING); + const mg_string *scheme = mg_value_string(scheme_val); + ASSERT_EQ(std::string(scheme->data, scheme->size), expected_scheme); + + if (with_creds) { const mg_value *principal_val = mg_map_at(msg_hello->extra, "principal"); ASSERT_TRUE(principal_val); ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING); const mg_string *principal = mg_value_string(principal_val); - ASSERT_EQ(std::string(principal->data, principal->size), "user"); + ASSERT_EQ(std::string(principal->data, principal->size), + expected_principal); const mg_value *credentials_val = mg_map_at(msg_hello->extra, "credentials"); ASSERT_TRUE(credentials_val); ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING); const mg_string *credentials = mg_value_string(credentials_val); - ASSERT_EQ(std::string(credentials->data, credentials->size), "pass"); + ASSERT_EQ(std::string(credentials->data, credentials->size), + expected_credentials); } mg_message_destroy_ca(message, session->decoder_allocator); } - // Send SUCCESS message. ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0); mg_session_destroy(session); - }); + }; +} + +TEST_F(ConnectTest, Success_v4) { + RunServer(make_v4_server_success("basic", "user", "pass")); mg_session_params *params = mg_session_params_make(); mg_session_params_set_host(params, "127.0.0.1"); mg_session_params_set_port(params, port); + mg_session_params_set_scheme(params, "basic"); mg_session_params_set_username(params, "user"); mg_session_params_set_password(params, "pass"); mg_session *session; @@ -592,70 +667,34 @@ TEST_F(ConnectTest, Success_v4) { } TEST_F(ConnectTest, SuccessWithSSL) { - RunServer([](int sockfd) { - // Perform handshake. - { - char handshake[20]; - ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); - ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); + RunServer(make_v1_server_success_basic); - uint32_t version = htobe32(1); - ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); - } - - mg_session *session = mg_session_init(&mg_system_allocator); - ASSERT_TRUE(session); - session->version = 1; - mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport, - &mg_system_allocator); - - // Read INIT message. - { - mg_message *message; - ASSERT_EQ(mg_session_receive_message(session), 0); - ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0); - ASSERT_EQ(message->type, MG_MESSAGE_TYPE_INIT); - - mg_message_init *msg_init = message->init_v; - EXPECT_EQ( - std::string(msg_init->client_name->data, msg_init->client_name->size), - MG_USER_AGENT); - { - ASSERT_EQ(mg_map_size(msg_init->auth_token), 3u); - - const mg_value *scheme_val = mg_map_at(msg_init->auth_token, "scheme"); - ASSERT_TRUE(scheme_val); - ASSERT_EQ(mg_value_get_type(scheme_val), MG_VALUE_TYPE_STRING); - const mg_string *scheme = mg_value_string(scheme_val); - ASSERT_EQ(std::string(scheme->data, scheme->size), "basic"); - - const mg_value *principal_val = - mg_map_at(msg_init->auth_token, "principal"); - ASSERT_TRUE(principal_val); - ASSERT_EQ(mg_value_get_type(principal_val), MG_VALUE_TYPE_STRING); - const mg_string *principal = mg_value_string(principal_val); - ASSERT_EQ(std::string(principal->data, principal->size), "user"); - - const mg_value *credentials_val = - mg_map_at(msg_init->auth_token, "credentials"); - ASSERT_TRUE(credentials_val); - ASSERT_EQ(mg_value_get_type(credentials_val), MG_VALUE_TYPE_STRING); - const mg_string *credentials = mg_value_string(credentials_val); - ASSERT_EQ(std::string(credentials->data, credentials->size), "pass"); - } - - mg_message_destroy_ca(message, session->decoder_allocator); - } + mg_secure_transport_init_called = 0; + trust_callback_ok = 0; - // Send SUCCESS message. - ASSERT_EQ(mg_session_send_success_message(session, &mg_empty_map), 0); + mg_session_params *params = mg_session_params_make(); + mg_session_params_set_host(params, "localhost"); + mg_session_params_set_port(params, port); + mg_session_params_set_username(params, "user"); + mg_session_params_set_password(params, "pass"); + mg_session_params_set_sslmode(params, MG_SSLMODE_REQUIRE); + mg_session_params_set_sslcert(params, "/path/to/cert"); + mg_session_params_set_sslkey(params, "/path/to/key"); + mg_session_params_set_trust_callback(params, trust_callback); + int trust_data = 42; + mg_session_params_set_trust_data(params, (void *)&trust_data); + mg_session *session; + ASSERT_EQ(mg_connect_ca(params, &session, (mg_allocator *)&allocator), 0); + ASSERT_EQ(mg_secure_transport_init_called, 1); + ASSERT_EQ(trust_callback_ok, 1); + EXPECT_EQ(mg_session_status(session), MG_SESSION_READY); + mg_session_params_destroy(params); + mg_session_destroy(session); + ASSERT_MEMORY_OK(); +} - mg_session_destroy(session); - }); +TEST_F(ConnectTest, SuccessWithSSL_v4) { + RunServer(make_v4_server_success("basic", "user", "pass")); mg_secure_transport_init_called = 0; trust_callback_ok = 0; @@ -663,6 +702,7 @@ TEST_F(ConnectTest, SuccessWithSSL) { mg_session_params *params = mg_session_params_make(); mg_session_params_set_host(params, "localhost"); mg_session_params_set_port(params, port); + mg_session_params_set_scheme(params, "basic"); mg_session_params_set_username(params, "user"); mg_session_params_set_password(params, "pass"); mg_session_params_set_sslmode(params, MG_SSLMODE_REQUIRE); @@ -681,6 +721,35 @@ TEST_F(ConnectTest, SuccessWithSSL) { ASSERT_MEMORY_OK(); } +TEST_F(ConnectTest, CustomScheme) { + RunServer(make_v4_server_success("custom_scheme", "user", "pass")); + mg_session_params *params = mg_session_params_make(); + mg_session_params_set_host(params, "127.0.0.1"); + mg_session_params_set_port(params, port); + mg_session_params_set_scheme(params, "custom_scheme"); + mg_session_params_set_username(params, "user"); + mg_session_params_set_password(params, "pass"); + mg_session *session; + ASSERT_EQ(mg_connect_ca(params, &session, (mg_allocator *)&allocator), 0); + EXPECT_EQ(mg_session_status(session), MG_SESSION_READY); + mg_session_params_destroy(params); + mg_session_destroy(session); + ASSERT_MEMORY_OK(); +} + +TEST_F(ConnectTest, SuccessNoAuth_v4) { + RunServer(make_v4_server_success("none")); + mg_session_params *params = mg_session_params_make(); + mg_session_params_set_host(params, "127.0.0.1"); + mg_session_params_set_port(params, port); + mg_session *session; + ASSERT_EQ(mg_connect_ca(params, &session, (mg_allocator *)&allocator), 0); + EXPECT_EQ(mg_session_status(session), MG_SESSION_READY); + mg_session_params_destroy(params); + mg_session_destroy(session); + ASSERT_MEMORY_OK(); +} + class RunTest : public ::testing::Test { protected: virtual void SetUp() override { diff --git a/tests/integration/basic_cpp.cpp b/tests/integration/basic_cpp.cpp index b5ff5a8..adc9264 100644 --- a/tests/integration/basic_cpp.cpp +++ b/tests/integration/basic_cpp.cpp @@ -35,7 +35,7 @@ class MemgraphConnection : public ::testing::Test { client = mg::Client::Connect( {GetEnvOrDefault("MEMGRAPH_HOST", "127.0.0.1"), - GetEnvOrDefault("MEMGRAPH_PORT", 7687), "", "", + GetEnvOrDefault("MEMGRAPH_PORT", 7687), "basic", "", "", GetEnvOrDefault("MEMGRAPH_SSLMODE", false), ""}); ASSERT_TRUE(client);