Commit 20eb36d0c1c6656066cb9febb9e6988c1eac5c2a

Stefan Sperling 2020-03-18T16:11:27

attempt to connect to a server before creating a local repo

diff --git a/got/got.c b/got/got.c
index c3c32e3..14911d3 100644
--- a/got/got.c
+++ b/got/got.c
@@ -980,7 +980,7 @@ cmd_clone(int argc, char *argv[])
 	struct got_pathlist_head refs, symrefs;
 	struct got_pathlist_entry *pe;
 	struct got_object_id *pack_hash = NULL;
-	int ch;
+	int ch, fetchfd = -1;
 
 	TAILQ_INIT(&refs);
 	TAILQ_INIT(&symrefs);
@@ -1010,6 +1010,10 @@ cmd_clone(int argc, char *argv[])
 	if (err)
 		goto done;
 
+	err = got_fetch_connect(&fetchfd, proto, host, port, server_path);
+	if (err)
+		goto done;
+
 	if (dirname == NULL) {
 		if (asprintf(&default_destdir, "%s.git", repo_name) == -1) {
 			err = got_error_from_errno("asprintf");
@@ -1031,7 +1035,7 @@ cmd_clone(int argc, char *argv[])
 	if (err)
 		goto done;
 
-	err = got_fetch(&pack_hash, &refs, &symrefs,
+	err = got_fetch(&pack_hash, &refs, &symrefs, fetchfd,
 	    proto, host, port, server_path, repo_name, branch_filter, repo);
 	if (err)
 		goto done;
@@ -1096,6 +1100,8 @@ cmd_clone(int argc, char *argv[])
 	}
 
 done:
+	if (fetchfd != -1 && close(fetchfd) == -1 && err == NULL)
+		err = got_error_from_errno("close");
 	if (repo)
 		got_repo_close(repo);
 	TAILQ_FOREACH(pe, &refs, entry) {
diff --git a/include/got_fetch.h b/include/got_fetch.h
index 7b36650..600d838 100644
--- a/include/got_fetch.h
+++ b/include/got_fetch.h
@@ -21,7 +21,10 @@
 const struct got_error *got_fetch_parse_uri(char **, char **, char **,
     char **, char **, const char *);
 
+const struct got_error *got_fetch_connect(int *, const char *, const char *,
+    const char *, const char *); /* fetchfd, proto, host, port, server_path */
+
 const struct got_error *got_fetch(struct got_object_id **,
-	struct got_pathlist_head *, struct got_pathlist_head *,
+	struct got_pathlist_head *, struct got_pathlist_head *, int,
 	const char *, const char *, const char *, const char *,
 	const char *, const char *, struct got_repository *);
diff --git a/lib/fetch.c b/lib/fetch.c
index 93751d1..617c092 100644
--- a/lib/fetch.c
+++ b/lib/fetch.c
@@ -183,6 +183,25 @@ done:
 }
 
 const struct got_error *
+got_fetch_connect(int *fetchfd, const char *proto, const char *host,
+    const char *port, const char *server_path)
+{
+	const struct got_error *err = NULL;
+	/* Start from "no descriptor"; cmd_clone closes only if != -1. */
+	*fetchfd = -1;
+	/* ssh and git dial for "upload"; http is known but not implemented. */
+	if (strcmp(proto, "ssh") == 0 || strcmp(proto, "git+ssh") == 0)
+		err = dial_ssh(fetchfd, host, port, server_path, "upload");
+	else if (strcmp(proto, "git") == 0)
+		err = dial_git(fetchfd, host, port, server_path, "upload");
+	else if (strcmp(proto, "http") == 0 || strcmp(proto, "git+http") == 0)
+		err = got_error_path(proto, GOT_ERR_NOT_IMPL);
+	else
+		err = got_error_path(proto, GOT_ERR_BAD_PROTO);
+	return err;
+}
+
+const struct got_error *
 got_fetch_parse_uri(char **proto, char **host, char **port,
     char **server_path, char **repo_name, const char *uri)
 {
@@ -280,12 +299,12 @@ done:
 
 const struct got_error*
 got_fetch(struct got_object_id **pack_hash, struct got_pathlist_head *refs,
-    struct got_pathlist_head *symrefs, const char *proto, const char *host,
-    const char *port, const char *server_path, const char *repo_name,
-    const char *branch_filter, struct got_repository *repo)
+    struct got_pathlist_head *symrefs, int fetchfd, const char *proto,
+    const char *host, const char *port, const char *server_path,
+    const char *repo_name, const char *branch_filter, struct got_repository *repo)
 {
-	int imsg_fetchfds[2], imsg_idxfds[2], fetchfd = -1;
-	int packfd = -1, npackfd = -1, idxfd = -1, nidxfd = -1;
+	int imsg_fetchfds[2], imsg_idxfds[2];
+	int packfd = -1, npackfd = -1, idxfd = -1, nidxfd = -1, nfetchfd = -1;
 	int status, done = 0;
 	const struct got_error *err;
 	struct imsgbuf ibuf;
@@ -298,8 +317,6 @@ got_fetch(struct got_object_id **pack_hash, struct got_pathlist_head *refs,
 
 	*pack_hash = NULL;
 
-	fetchfd = -1;
-
 	if (asprintf(&path, "%s/%s/fetching.pack",
 	    repo_path, GOT_OBJECTS_PACK_DIR) == -1) {
 		err = got_error_from_errno("asprintf");
@@ -329,17 +346,6 @@ got_fetch(struct got_object_id **pack_hash, struct got_pathlist_head *refs,
 		goto done;
 	}
 
-	if (strcmp(proto, "ssh") == 0 || strcmp(proto, "git+ssh") == 0)
-		err = dial_ssh(&fetchfd, host, port, server_path, "upload");
-	else if (strcmp(proto, "git") == 0)
-		err = dial_git(&fetchfd, host, port, server_path, "upload");
-	else if (strcmp(proto, "http") == 0 || strcmp(proto, "git+http") == 0)
-		err = got_error(GOT_ERR_BAD_PROTO);
-	else
-		err = got_error(GOT_ERR_BAD_PROTO);
-	if (err)
-		goto done;
-
 	if (socketpair(AF_UNIX, SOCK_STREAM, PF_UNSPEC, imsg_fetchfds) == -1) {
 		err = got_error_from_errno("socketpair");
 		goto done;
@@ -359,10 +365,15 @@ got_fetch(struct got_object_id **pack_hash, struct got_pathlist_head *refs,
 		goto done;
 	}
 	imsg_init(&ibuf, imsg_fetchfds[0]);
-	err = got_privsep_send_fetch_req(&ibuf, fetchfd);
+	nfetchfd = dup(fetchfd);
+	if (nfetchfd == -1) {
+		err = got_error_from_errno("dup");
+		goto done;
+	}
+	err = got_privsep_send_fetch_req(&ibuf, nfetchfd);
 	if (err != NULL)
 		goto done;
-	fetchfd = -1;
+	nfetchfd = -1;
 	err = got_privsep_send_tmpfd(&ibuf, npackfd);
 	if (err != NULL)
 		goto done;
@@ -457,7 +468,7 @@ got_fetch(struct got_object_id **pack_hash, struct got_pathlist_head *refs,
 	}
 
 done:
-	if (fetchfd != -1 && close(fetchfd) == -1 && err == NULL)
+	if (nfetchfd != -1 && close(nfetchfd) == -1 && err == NULL)
 		err = got_error_from_errno("close");
 	if (npackfd != -1 && close(npackfd) == -1 && err == NULL)
 		err = got_error_from_errno("close");