diff --git a/src/main-proc.c b/src/main-proc.c index e6f7cf10..f64cb28b 100644 --- a/src/main-proc.c +++ b/src/main-proc.c @@ -94,6 +94,8 @@ struct proc_st *ctmp; */ void remove_proc(main_server_st * s, struct proc_st *proc, unsigned flags) { + pid_t pid; + ev_io_stop(EV_A_ &proc->io); ev_child_stop(EV_A_ &proc->ev_child); @@ -114,9 +116,22 @@ void remove_proc(main_server_st * s, struct proc_st *proc, unsigned flags) mslog(s, proc, LOG_INFO, "user disconnected (reason: %s, rx: %"PRIu64", tx: %"PRIu64")", discon_reason_to_str(proc->discon_reason), proc->bytes_in, proc->bytes_out); - remove_from_script_list(s, proc); - if (proc->status == PS_AUTH_COMPLETED) { - user_disconnected(s, proc); + pid = remove_from_script_list(s, proc); + if (proc->status == PS_AUTH_COMPLETED || pid > 0) { + if (pid > 0) { + int wstatus; + /* we were called during the connect script being run. + * wait for it to finish and if it returns zero run the + * disconnect script */ + if (waitpid(pid, &wstatus, 0) > 0) { + if (WEXITSTATUS(wstatus) == 0) + user_disconnected(s, proc); + } + } else { /* pid > 0 or status == PS_AUTH_COMPLETED are mutually exclusive + * since PS_AUTH_COMPLETED is set only after a successful script run. + */ + user_disconnected(s, proc); + } } /* close the intercomm fd */ diff --git a/src/script-list.h b/src/script-list.h index f573fc9f..2dab7e05 100644 --- a/src/script-list.h +++ b/src/script-list.h @@ -22,6 +22,8 @@ # define SCRIPT_LIST_H #include +#include +#include #include void script_child_watcher_cb(struct ev_loop *loop, ev_child *w, int revents); @@ -44,17 +46,27 @@ struct script_wait_st *stmp; list_add(&s->script_list.head, &(stmp->list)); } -inline static void remove_from_script_list(main_server_st* s, struct proc_st* proc) +/* Removes the tracked connect script, and kills it. It returns the pid + * of the removed script or -1. + */ +inline static pid_t remove_from_script_list(main_server_st* s, struct proc_st* proc) { struct script_wait_st *stmp = NULL, *spos; + pid_t ret = -1; list_for_each_safe(&s->script_list.head, stmp, spos, list) { if (stmp->proc == proc) { list_del(&stmp->list); + if (stmp->pid > 0) { + kill(stmp->pid, SIGTERM); + ret = stmp->pid; + } talloc_free(stmp); break; } } + + return ret; } #endif