From 33bcfb178e15b261408ef4ece76ebbace99e7888 Mon Sep 17 00:00:00 2001 From: Nikos Mavrogiannopoulos Date: Wed, 13 May 2015 14:04:20 +0200 Subject: [PATCH] main: use two sockets to communicate with sec-mod That allows to have a reliable synchronous socket, and a socket where messages are sent and received asynchronously. --- src/main-misc.c | 19 ++++++++++++++++--- src/main-sec-mod-cmd.c | 8 ++++---- src/main.c | 2 +- src/main.h | 5 +++-- src/sec-mod-auth.c | 36 ++++++++++++++++++------------------ src/sec-mod.c | 30 ++++++++++++++++++++---------- src/sec-mod.h | 5 +++-- 7 files changed, 65 insertions(+), 40 deletions(-) diff --git a/src/main-misc.c b/src/main-misc.c index 5c960459..ae94d025 100644 --- a/src/main-misc.c +++ b/src/main-misc.c @@ -595,11 +595,14 @@ int handle_commands(main_server_st * s, struct proc_st *proc) return ret; } -/* Returns a file descriptor to be used for communication with sec-mod +/* Returns two file descriptors to be used for communication with sec-mod. + * The sync_fd is used by main to send synchronous commands- commands which + * expect a reply immediately. */ -int run_sec_mod(main_server_st * s) +int run_sec_mod(main_server_st * s, int *sync_fd) { int e, fd[2], ret; + int sfd[2]; pid_t pid; const char *p; @@ -621,6 +624,12 @@ int run_sec_mod(main_server_st * s) exit(1); } + ret = socketpair(AF_UNIX, SOCK_STREAM, 0, sfd); + if (ret < 0) { + mslog(s, NULL, LOG_ERR, "error creating sec-mod sync command socket"); + exit(1); + } + pid = fork(); if (pid == 0) { /* child */ clear_lists(s); @@ -633,13 +642,17 @@ int run_sec_mod(main_server_st * s) #endif setproctitle(PACKAGE_NAME "-secmod"); close(fd[1]); + close(sfd[1]); set_cloexec_flag (fd[0], 1); - sec_mod_server(s->main_pool, s->perm_config, p, s->cookie_key, fd[0]); + set_cloexec_flag (sfd[0], 1); + sec_mod_server(s->main_pool, s->perm_config, p, s->cookie_key, fd[0], sfd[0]); exit(0); } else if (pid > 0) { /* parent */ close(fd[0]); s->sec_mod_pid = pid; set_cloexec_flag (fd[1], 1); + set_cloexec_flag (sfd[1], 1); + *sync_fd = sfd[1]; return fd[1]; } else { e = errno; diff --git a/src/main-sec-mod-cmd.c b/src/main-sec-mod-cmd.c index d4a21e08..3c1930b4 100644 --- a/src/main-sec-mod-cmd.c +++ b/src/main-sec-mod-cmd.c @@ -209,7 +209,7 @@ int session_open(main_server_st * s, struct proc_st *proc, const uint8_t *cookie mslog(s, proc, LOG_DEBUG, "sending msg %s to sec-mod", cmd_request_to_str(SM_CMD_AUTH_SESSION_OPEN)); - ret = send_msg(proc, s->sec_mod_fd, SM_CMD_AUTH_SESSION_OPEN, + ret = send_msg(proc, s->sec_mod_fd_sync, SM_CMD_AUTH_SESSION_OPEN, &ireq, (pack_size_func)sec_auth_session_msg__get_packed_size, (pack_func)sec_auth_session_msg__pack); if (ret < 0) { @@ -218,7 +218,7 @@ int session_open(main_server_st * s, struct proc_st *proc, const uint8_t *cookie return -1; } - ret = recv_msg(proc, s->sec_mod_fd, SM_CMD_AUTH_SESSION_REPLY, + ret = recv_msg(proc, s->sec_mod_fd_sync, SM_CMD_AUTH_SESSION_REPLY, (void *)&msg, (unpack_func) sec_auth_session_reply_msg__unpack, MAIN_SEC_MOD_TIMEOUT); if (ret < 0) { e = errno; @@ -340,7 +340,7 @@ int session_close(main_server_st * s, struct proc_st *proc) mslog(s, proc, LOG_DEBUG, "sending msg %s to sec-mod", cmd_request_to_str(SM_CMD_AUTH_SESSION_CLOSE)); - ret = send_msg(proc, s->sec_mod_fd, SM_CMD_AUTH_SESSION_CLOSE, + ret = send_msg(proc, s->sec_mod_fd_sync, SM_CMD_AUTH_SESSION_CLOSE, &ireq, (pack_size_func)sec_auth_session_msg__get_packed_size, (pack_func)sec_auth_session_msg__pack); if (ret < 0) { @@ -349,7 +349,7 @@ int session_close(main_server_st * s, struct proc_st *proc) return -1; } - ret = recv_msg(proc, s->sec_mod_fd, SM_CMD_AUTH_CLI_STATS, + ret = recv_msg(proc, s->sec_mod_fd_sync, SM_CMD_AUTH_CLI_STATS, (void *)&msg, (unpack_func) cli_stats_msg__unpack, MAIN_SEC_MOD_TIMEOUT); if (ret < 0) { e = errno; diff --git a/src/main.c b/src/main.c index 7d08a709..10a29069 100644 --- a/src/main.c +++ b/src/main.c @@ -1030,7 +1030,7 @@ int main(int argc, char** argv) write_pid_file(); - s->sec_mod_fd = run_sec_mod(s); + s->sec_mod_fd = run_sec_mod(s, &s->sec_mod_fd_sync); ret = ctl_handler_init(s); if (ret < 0) { diff --git a/src/main.h b/src/main.h index 65051c28..84bbcd21 100644 --- a/src/main.h +++ b/src/main.h @@ -204,7 +204,8 @@ typedef struct main_server_st { #else int ctl_fd; #endif - int sec_mod_fd; + int sec_mod_fd; /* messages are sent and received async */ + int sec_mod_fd_sync; /* messages are received in a sync order (ping-pong) */ void *main_pool; /* talloc main pool */ } main_server_st; @@ -263,7 +264,7 @@ int handle_auth_cookie_req(main_server_st* s, struct proc_st* proc, int check_multiple_users(main_server_st *s, struct proc_st* proc); int handle_script_exit(main_server_st *s, struct proc_st* proc, int code); -int run_sec_mod(main_server_st * s); +int run_sec_mod(main_server_st * s, int *sync_fd); struct proc_st *new_proc(main_server_st * s, pid_t pid, int cmd_fd, struct sockaddr_storage *remote_addr, socklen_t remote_addr_len, diff --git a/src/sec-mod-auth.c b/src/sec-mod-auth.c index 668d4416..a7f2fdb5 100644 --- a/src/sec-mod-auth.c +++ b/src/sec-mod-auth.c @@ -367,7 +367,7 @@ static void stats_add_to(stats_st *dst, stats_st *src1, stats_st *src2) } static -int send_failed_session_open_reply(sec_mod_st *sec) +int send_failed_session_open_reply(sec_mod_st *sec, int fd) { SecAuthSessionReplyMsg rep = SEC_AUTH_SESSION_REPLY_MSG__INIT; void *lpool; @@ -380,7 +380,7 @@ int send_failed_session_open_reply(sec_mod_st *sec) return ERR_BAD_COMMAND; } - ret = send_msg(lpool, sec->cmd_fd, SM_CMD_AUTH_SESSION_REPLY, &rep, + ret = send_msg(lpool, fd, SM_CMD_AUTH_SESSION_REPLY, &rep, (pack_size_func) sec_auth_session_reply_msg__get_packed_size, (pack_func) sec_auth_session_reply_msg__pack); if (ret < 0) { @@ -393,7 +393,7 @@ int send_failed_session_open_reply(sec_mod_st *sec) } static -int handle_sec_auth_session_open(sec_mod_st *sec, const SecAuthSessionMsg *req) +int handle_sec_auth_session_open(sec_mod_st *sec, int fd, const SecAuthSessionMsg *req) { client_entry_st *e; void *lpool; @@ -403,7 +403,7 @@ int handle_sec_auth_session_open(sec_mod_st *sec, const SecAuthSessionMsg *req) if (req->sid.len != SID_SIZE) { seclog(sec, LOG_ERR, "auth session open but with illegal sid size (%d)!", (int)req->sid.len); - return send_failed_session_open_reply(sec); + return send_failed_session_open_reply(sec, fd); } e = find_client_entry(sec, req->sid.data); @@ -411,25 +411,25 @@ int handle_sec_auth_session_open(sec_mod_st *sec, const SecAuthSessionMsg *req) char tmp[BASE64_LENGTH(SID_SIZE) + 1]; base64_encode((char *)req->sid.data, req->sid.len, (char *)tmp, sizeof(tmp)); seclog(sec, LOG_INFO, "session open but with non-existing SID: %s!", tmp); - return send_failed_session_open_reply(sec); + return send_failed_session_open_reply(sec, fd); } if (e->status != PS_AUTH_COMPLETED) { seclog(sec, LOG_ERR, "session open received in unauthenticated client %s "SESSION_STR"!", e->auth_info.username, e->auth_info.psid); - return send_failed_session_open_reply(sec); + return send_failed_session_open_reply(sec, fd); } if (e->time != -1 && time(0) > e->time + sec->config->cookie_timeout) { seclog(sec, LOG_ERR, "session expired; denied session for user '%s' "SESSION_STR, e->auth_info.username, e->auth_info.psid); e->status = PS_AUTH_FAILED; - return send_failed_session_open_reply(sec); + return send_failed_session_open_reply(sec, fd); } if (req->has_cookie == 0 || (req->cookie.len != e->cookie_size) || memcmp(req->cookie.data, e->cookie, e->cookie_size) != 0) { seclog(sec, LOG_ERR, "cookie error; denied session for user '%s' "SESSION_STR, e->auth_info.username, e->auth_info.psid); e->status = PS_AUTH_FAILED; - return send_failed_session_open_reply(sec); + return send_failed_session_open_reply(sec, fd); } if (req->ipv4) @@ -442,7 +442,7 @@ int handle_sec_auth_session_open(sec_mod_st *sec, const SecAuthSessionMsg *req) if (ret < 0) { e->status = PS_AUTH_FAILED; seclog(sec, LOG_INFO, "denied session for user '%s' "SESSION_STR, e->auth_info.username, e->auth_info.psid); - return send_failed_session_open_reply(sec); + return send_failed_session_open_reply(sec, fd); } else { e->session_is_open = 1; } @@ -460,7 +460,7 @@ int handle_sec_auth_session_open(sec_mod_st *sec, const SecAuthSessionMsg *req) if (ret < 0) { seclog(sec, LOG_ERR, "error reading additional configuration for '%s' "SESSION_STR, e->auth_info.username, e->auth_info.psid); talloc_free(lpool); - return send_failed_session_open_reply(sec); + return send_failed_session_open_reply(sec, fd); } } @@ -470,7 +470,7 @@ int handle_sec_auth_session_open(sec_mod_st *sec, const SecAuthSessionMsg *req) rep.has_interim_update_secs = 1; } - ret = send_msg(lpool, sec->cmd_fd, SM_CMD_AUTH_SESSION_REPLY, &rep, + ret = send_msg(lpool, fd, SM_CMD_AUTH_SESSION_REPLY, &rep, (pack_size_func) sec_auth_session_reply_msg__get_packed_size, (pack_func) sec_auth_session_reply_msg__pack); if (ret < 0) { @@ -487,7 +487,7 @@ int handle_sec_auth_session_open(sec_mod_st *sec, const SecAuthSessionMsg *req) } static -int handle_sec_auth_session_close(sec_mod_st *sec, const SecAuthSessionMsg *req) +int handle_sec_auth_session_close(sec_mod_st *sec, int fd, const SecAuthSessionMsg *req) { client_entry_st *e; int ret; @@ -504,14 +504,14 @@ int handle_sec_auth_session_close(sec_mod_st *sec, const SecAuthSessionMsg *req) char tmp[BASE64_LENGTH(SID_SIZE) + 1]; base64_encode((char *)req->sid.data, req->sid.len, (char *)tmp, sizeof(tmp)); seclog(sec, LOG_INFO, "session close but with non-existing SID: %s", tmp); - return send_msg(e, sec->cmd_fd, SM_CMD_AUTH_CLI_STATS, &rep, + return send_msg(e, fd, SM_CMD_AUTH_CLI_STATS, &rep, (pack_size_func) cli_stats_msg__get_packed_size, (pack_func) cli_stats_msg__pack); } if (e->status < PS_AUTH_COMPLETED) { seclog(sec, LOG_DEBUG, "session close received in unauthenticated client %s "SESSION_STR"!", e->auth_info.username, e->auth_info.psid); - return send_msg(e, sec->cmd_fd, SM_CMD_AUTH_CLI_STATS, &rep, + return send_msg(e, fd, SM_CMD_AUTH_CLI_STATS, &rep, (pack_size_func) cli_stats_msg__get_packed_size, (pack_func) cli_stats_msg__pack); } @@ -533,7 +533,7 @@ int handle_sec_auth_session_close(sec_mod_st *sec, const SecAuthSessionMsg *req) rep.has_secmod_client_entries = 1; rep.secmod_client_entries = sec_mod_client_db_elems(sec); - ret = send_msg(e, sec->cmd_fd, SM_CMD_AUTH_CLI_STATS, &rep, + ret = send_msg(e, fd, SM_CMD_AUTH_CLI_STATS, &rep, (pack_size_func) cli_stats_msg__get_packed_size, (pack_func) cli_stats_msg__pack); if (ret < 0) { @@ -550,13 +550,13 @@ int handle_sec_auth_session_close(sec_mod_st *sec, const SecAuthSessionMsg *req) } -int handle_sec_auth_session_cmd(sec_mod_st *sec, const SecAuthSessionMsg *req, +int handle_sec_auth_session_cmd(sec_mod_st *sec, int fd, const SecAuthSessionMsg *req, unsigned cmd) { if (cmd == SM_CMD_AUTH_SESSION_OPEN) - return handle_sec_auth_session_open(sec, req); + return handle_sec_auth_session_open(sec, fd, req); else - return handle_sec_auth_session_close(sec, req); + return handle_sec_auth_session_close(sec, fd, req); } void handle_sec_auth_ban_ip_reply(sec_mod_st *sec, const BanIpReplyMsg *msg) diff --git a/src/sec-mod.c b/src/sec-mod.c index 80269ed4..df800cb3 100644 --- a/src/sec-mod.c +++ b/src/sec-mod.c @@ -304,7 +304,7 @@ int process_packet(void *pool, int cfd, pid_t pid, sec_mod_st * sec, cmd_request } static -int process_packet_from_main(void *pool, sec_mod_st * sec, cmd_request_t cmd, +int process_packet_from_main(void *pool, int fd, sec_mod_st * sec, cmd_request_t cmd, uint8_t * buffer, size_t buffer_size) { gnutls_datum_t data; @@ -345,7 +345,7 @@ int process_packet_from_main(void *pool, sec_mod_st * sec, cmd_request_t cmd, return ERR_BAD_COMMAND; } - ret = handle_sec_auth_session_cmd(sec, msg, cmd); + ret = handle_sec_auth_session_cmd(sec, fd, msg, cmd); sec_auth_session_msg__free_unpacked(msg, &pa); return ret; @@ -404,7 +404,7 @@ static void check_other_work(sec_mod_st *sec) } static -int serve_request_main(sec_mod_st *sec, uint8_t *buffer, unsigned buffer_size) +int serve_request_main(sec_mod_st *sec, int fd, uint8_t *buffer, unsigned buffer_size) { int ret, e; unsigned cmd, length; @@ -412,7 +412,7 @@ int serve_request_main(sec_mod_st *sec, uint8_t *buffer, unsigned buffer_size) void *pool = buffer; /* read request */ - ret = force_read_timeout(sec->cmd_fd, buffer, 3, MAIN_SEC_MOD_TIMEOUT); + ret = force_read_timeout(fd, buffer, 3, MAIN_SEC_MOD_TIMEOUT); if (ret == 0) goto leave; else if (ret < 3) { @@ -441,7 +441,7 @@ int serve_request_main(sec_mod_st *sec, uint8_t *buffer, unsigned buffer_size) } /* read the body */ - ret = force_read_timeout(sec->cmd_fd, buffer, length, MAIN_SEC_MOD_TIMEOUT); + ret = force_read_timeout(fd, buffer, length, MAIN_SEC_MOD_TIMEOUT); if (ret < 0) { e = errno; seclog(sec, LOG_ERR, "error receiving msg body of cmd %u with length %u: %s", @@ -450,7 +450,7 @@ int serve_request_main(sec_mod_st *sec, uint8_t *buffer, unsigned buffer_size) goto leave; } - ret = process_packet_from_main(pool, sec, cmd, buffer, ret); + ret = process_packet_from_main(pool, fd, sec, cmd, buffer, ret); if (ret < 0) { seclog(sec, LOG_ERR, "error processing data for '%s' command (%d)", cmd_request_to_str(cmd), ret); } @@ -512,6 +512,7 @@ int serve_request(sec_mod_st *sec, int cfd, pid_t pid, uint8_t *buffer, unsigned * @config: server configuration * @socket_file: the name of the socket * @cmd_fd: socket to exchange commands with main + * @cmd_fd_sync: socket to received sync commands from main * * This is the main part of the security module. * It creates the unix domain socket identified by @socket_file @@ -537,11 +538,11 @@ int serve_request(sec_mod_st *sec, int cfd, pid_t pid, uint8_t *buffer, unsigned * key operations. */ void sec_mod_server(void *main_pool, struct perm_cfg_st *perm_config, const char *socket_file, - uint8_t cookie_key[COOKIE_KEY_SIZE], int cmd_fd) + uint8_t cookie_key[COOKIE_KEY_SIZE], int cmd_fd, int cmd_fd_sync) { struct sockaddr_un sa; socklen_t sa_len; - int cfd, ret, e, n; + int cfd, ret, e, n, tfd; unsigned i, buffer_size; uid_t uid; uint8_t *buffer; @@ -600,6 +601,7 @@ void sec_mod_server(void *main_pool, struct perm_cfg_st *perm_config, const char sec_auth_init(sec, perm_config); sec->cmd_fd = cmd_fd; + sec->cmd_fd_sync = cmd_fd_sync; #ifdef HAVE_PKCS11 ret = gnutls_pkcs11_reinit(); @@ -706,6 +708,9 @@ void sec_mod_server(void *main_pool, struct perm_cfg_st *perm_config, const char FD_SET(cmd_fd, &rd_set); n = MAX(n, cmd_fd); + FD_SET(cmd_fd_sync, &rd_set); + n = MAX(n, cmd_fd_sync); + FD_SET(sd, &rd_set); n = MAX(n, sd); @@ -726,13 +731,18 @@ void sec_mod_server(void *main_pool, struct perm_cfg_st *perm_config, const char exit(1); } - if (FD_ISSET(cmd_fd, &rd_set)) { + if (FD_ISSET(cmd_fd, &rd_set) || FD_ISSET(cmd_fd_sync, &rd_set)) { + if (FD_ISSET(cmd_fd_sync, &rd_set)) + tfd = cmd_fd_sync; + else + tfd = cmd_fd; + buffer_size = MAX_MSG_SIZE; buffer = talloc_size(sec, buffer_size); if (buffer == NULL) { seclog(sec, LOG_ERR, "error in memory allocation"); } else { - ret = serve_request_main(sec, buffer, buffer_size); + ret = serve_request_main(sec, tfd, buffer, buffer_size); if (ret < 0 && ret == ERR_BAD_COMMAND) { seclog(sec, LOG_ERR, "error processing command from main"); exit(1); diff --git a/src/sec-mod.h b/src/sec-mod.h index 9f71a0f9..28e0febf 100644 --- a/src/sec-mod.h +++ b/src/sec-mod.h @@ -38,6 +38,7 @@ typedef struct sec_mod_st { unsigned key_size; struct htable *client_db; int cmd_fd; + int cmd_fd_sync; struct config_mod_st *config_module; } sec_mod_st; @@ -127,11 +128,11 @@ void sec_auth_init(sec_mod_st *sec, struct perm_cfg_st *config); void handle_sec_auth_ban_ip_reply(sec_mod_st *sec, const BanIpReplyMsg *msg); int handle_sec_auth_init(int cfd, sec_mod_st *sec, const SecAuthInitMsg * req, pid_t pid); int handle_sec_auth_cont(int cfd, sec_mod_st *sec, const SecAuthContMsg * req); -int handle_sec_auth_session_cmd(sec_mod_st *sec, const SecAuthSessionMsg *req, unsigned cmd); +int handle_sec_auth_session_cmd(sec_mod_st *sec, int fd, const SecAuthSessionMsg *req, unsigned cmd); int handle_sec_auth_stats_cmd(sec_mod_st * sec, const CliStatsMsg * req); void sec_auth_user_deinit(sec_mod_st * sec, client_entry_st * e); void sec_mod_server(void *main_pool, struct perm_cfg_st *config, const char *socket_file, - uint8_t cookie_key[COOKIE_KEY_SIZE], int cmd_fd); + uint8_t cookie_key[COOKIE_KEY_SIZE], int cmd_fd, int cmd_fd_sync); #endif