Commit 9c52365fc3a72b4e0add781c388032869c081473

Stefan Sperling 2020-03-21T10:37:19

properly terminate the ssh process after fetching via SSH

diff --git a/got/got.c b/got/got.c
index a1aea8a..983a4a7 100644
--- a/got/got.c
+++ b/got/got.c
@@ -940,7 +940,8 @@ cmd_clone(int argc, char *argv[])
 	struct got_pathlist_head refs, symrefs, wanted_branches;
 	struct got_pathlist_entry *pe;
 	struct got_object_id *pack_hash = NULL;
-	int ch, fetchfd = -1;
+	int ch, fetchfd = -1, fetchstatus;
+	pid_t fetchpid = -1;
 	struct got_fetch_progress_arg fpa;
 	char *git_url = NULL;
 	char *gitconfig_path = NULL;
@@ -1077,8 +1078,8 @@ cmd_clone(int argc, char *argv[])
 	if (error)
 		goto done;
 
-	error = got_fetch_connect(&fetchfd, proto, host, port, server_path,
-	    verbosity);
+	error = got_fetch_connect(&fetchpid, &fetchfd, proto, host, port,
+	    server_path, verbosity);
 	if (error)
 		goto done;
 
@@ -1266,6 +1267,12 @@ cmd_clone(int argc, char *argv[])
 		printf("Created %s repository '%s'\n",
 		    mirror_references ? "mirrored" : "cloned", repo_path);
 done:
+	if (fetchpid > 0) {
+		if (kill(fetchpid, SIGTERM) == -1)
+			error = got_error_from_errno("kill");
+		if (waitpid(fetchpid, &fetchstatus, 0) == -1 && error == NULL)
+			error = got_error_from_errno("waitpid");
+	}
 	if (fetchfd != -1 && close(fetchfd) == -1 && error == NULL)
 		error = got_error_from_errno("close");
 	if (gitconfig_file && fclose(gitconfig_file) == EOF && error == NULL)
@@ -1387,7 +1394,8 @@ cmd_fetch(int argc, char *argv[])
 	struct got_pathlist_head refs, symrefs, wanted_branches;
 	struct got_pathlist_entry *pe;
 	struct got_object_id *pack_hash = NULL;
-	int i, ch, fetchfd = -1;
+	int i, ch, fetchfd = -1, fetchstatus;
+	pid_t fetchpid = -1;
 	struct got_fetch_progress_arg fpa;
 	int verbosity = 0, fetch_all_branches = 0, list_refs_only = 0;
 
@@ -1532,8 +1540,8 @@ cmd_fetch(int argc, char *argv[])
 	if (error)
 		goto done;
 
-	error = got_fetch_connect(&fetchfd, proto, host, port, server_path,
-	    verbosity);
+	error = got_fetch_connect(&fetchpid, &fetchfd, proto, host, port,
+	    server_path, verbosity);
 	if (error)
 		goto done;
 
@@ -1634,6 +1642,12 @@ cmd_fetch(int argc, char *argv[])
 		id_str = NULL;
 	}
 done:
+	if (fetchpid > 0) {
+		if (kill(fetchpid, SIGTERM) == -1)
+			error = got_error_from_errno("kill");
+		if (waitpid(fetchpid, &fetchstatus, 0) == -1 && error == NULL)
+			error = got_error_from_errno("waitpid");
+	}
 	if (fetchfd != -1 && close(fetchfd) == -1 && error == NULL)
 		error = got_error_from_errno("close");
 	if (repo)
diff --git a/include/got_fetch.h b/include/got_fetch.h
index 9e26b62..a44351f 100644
--- a/include/got_fetch.h
+++ b/include/got_fetch.h
@@ -43,11 +43,16 @@ const struct got_error *got_fetch_parse_uri(char **, char **, char **,
  * A verbosity level can be specified; it currently controls the amount
  * of -v options passed to ssh(1). If the level is -1 ssh(1) will be run
  * with the -q option.
+ *
  * If successful return an open file descriptor for the connection which can
  * be passed to other functions below, and must be disposed of with close(2).
+ *
+ * If an ssh(1) process was started return its PID as well, in which case
+ * the caller should eventually send SIGTERM to the process and wait for
+ * the process to exit with waitpid(2). Otherwise, return PID -1.
  */
-const struct got_error *got_fetch_connect(int *, const char *, const char *,
-    const char *, const char *, int);
+const struct got_error *got_fetch_connect(pid_t *, int *, const char *,
+    const char *, const char *, const char *, int);
 
 /* A callback function which gets invoked with progress information to print. */
 typedef const struct got_error *(*got_fetch_progress_cb)(void *,
diff --git a/lib/fetch.c b/lib/fetch.c
index fcd9bb0..1cdd3ca 100644
--- a/lib/fetch.c
+++ b/lib/fetch.c
@@ -83,8 +83,8 @@ hassuffix(char *base, char *suf)
 }
 
 static const struct got_error *
-dial_ssh(int *fetchfd, const char *host, const char *port, const char *path,
-    const char *direction, int verbosity)
+dial_ssh(pid_t *fetchpid, int *fetchfd, const char *host, const char *port,
+    const char *path, const char *direction, int verbosity)
 {
 	const struct got_error *error = NULL;
 	int pid, pfd[2];
@@ -92,6 +92,7 @@ dial_ssh(int *fetchfd, const char *host, const char *port, const char *path,
 	char *argv[11];
 	int i = 0;
 
+	*fetchpid = -1;
 	*fetchfd = -1;
 
 	if (port == NULL)
@@ -135,6 +136,7 @@ dial_ssh(int *fetchfd, const char *host, const char *port, const char *path,
 		abort(); /* not reached */
 	} else {
 		close(pfd[0]);
+		*fetchpid = pid;
 		*fetchfd = pfd[1];
 		return NULL;
 	}
@@ -212,16 +214,17 @@ done:
 }
 
 const struct got_error *
-got_fetch_connect(int *fetchfd, const char *proto, const char *host,
-    const char *port, const char *server_path, int verbosity)
+got_fetch_connect(pid_t *fetchpid, int *fetchfd, const char *proto,
+    const char *host, const char *port, const char *server_path, int verbosity)
 {
 	const struct got_error *err = NULL;
 
+	*fetchpid = -1;
 	*fetchfd = -1;
 
 	if (strcmp(proto, "ssh") == 0 || strcmp(proto, "git+ssh") == 0)
-		err = dial_ssh(fetchfd, host, port, server_path, "upload",
-		    verbosity);
+		err = dial_ssh(fetchpid, fetchfd, host, port, server_path,
+		    "upload", verbosity);
 	else if (strcmp(proto, "git") == 0)
 		err = dial_git(fetchfd, host, port, server_path, "upload");
 	else if (strcmp(proto, "http") == 0 || strcmp(proto, "git+http") == 0)