Commit bb64b79860632effc06224a66e840f56e41f6ac6

Stefan Sperling 2020-03-18T16:11:26

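have got_fetch() require an open got_repository to be passed in

got_fetch() no longer creates, initializes, or opens the destination
repository itself; cmd_clone() now does mkdir, got_repo_init() and
got_repo_open() before calling got_fetch(), and closes the repository
afterwards.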

diff --git a/got/got.c b/got/got.c
index 2aa285a..b5bed4a 100644
--- a/got/got.c
+++ b/got/got.c
@@ -974,6 +974,9 @@ cmd_clone(int argc, char *argv[])
 	const struct got_error *err = NULL;
 	const char *uri, *branch_filter, *dirname;
 	char *proto, *host, *port, *repo_name, *server_path;
+	char *default_destdir = NULL;
+	const char *repo_path;
+	struct got_repository *repo = NULL;
 	int ch;
 
 	while ((ch = getopt(argc, argv, "b:")) != -1) {
@@ -1001,14 +1004,38 @@ cmd_clone(int argc, char *argv[])
 	if (err)
 		goto done;
 
+	if (dirname == NULL) {
+		if (asprintf(&default_destdir, "%s.git", repo_name) == -1) {
+			err = got_error_from_errno("asprintf");
+			goto done;
+		}
+		repo_path = default_destdir;
+	} else
+		repo_path = dirname;
+
+	err = got_path_mkdir(repo_path);
+	if (err)
+		goto done;
+
+	err = got_repo_init(repo_path);
+	if (err != NULL)
+		goto done;
+
+	err = got_repo_open(&repo, repo_path, NULL);
+	if (err)
+		goto done;
+
 	err = got_fetch(proto, host, port, server_path, repo_name,
-	    branch_filter, dirname);
+	    branch_filter, repo);
 done:
+	if (repo)
+		got_repo_close(repo);
 	free(proto);
 	free(host);
 	free(port);
 	free(server_path);
 	free(repo_name);
+	free(default_destdir);
 	return err;
 }
 
diff --git a/include/got_fetch.h b/include/got_fetch.h
index 3af0abb..76bf04d 100644
--- a/include/got_fetch.h
+++ b/include/got_fetch.h
@@ -22,4 +22,4 @@ const struct got_error *got_fetch_parse_uri(char **, char **, char **,
     char **, char **, const char *);
 const struct got_error *got_fetch(const char *, const char *,
     const char *, const char *, const char *, const char *,
-    const char *);
+    struct got_repository *);
diff --git a/lib/fetch.c b/lib/fetch.c
index ff27135..4b57d05 100644
--- a/lib/fetch.c
+++ b/lib/fetch.c
@@ -281,7 +281,7 @@ done:
 const struct got_error*
 got_fetch(const char *proto, const char *host, const char *port,
     const char *server_path, const char *repo_name,
-    const char *branch_filter, const char *destdir)
+    const char *branch_filter, struct got_repository *repo)
 {
 	int imsg_fetchfds[2], imsg_idxfds[2], fetchfd = -1;
 	int packfd = -1, npackfd = -1, idxfd = -1, nidxfd = -1;
@@ -290,26 +290,17 @@ got_fetch(const char *proto, const char *host, const char *port,
 	const struct got_error *err;
 	struct imsgbuf ibuf;
 	pid_t pid;
-	char *tmppackpath = NULL, *tmpidxpath = NULL, *default_destdir = NULL;
+	char *tmppackpath = NULL, *tmpidxpath = NULL;
 	char *packpath = NULL, *idxpath = NULL, *id_str = NULL;
-	const char *repo_path;
+	const char *repo_path = got_repo_get_path(repo);
 	struct got_pathlist_head symrefs;
 	struct got_pathlist_entry *pe;
-	struct got_repository *repo = NULL;
 	char *path;
 
 	TAILQ_INIT(&symrefs);
 
 	fetchfd = -1;
-	if (destdir == NULL) {
-		if (asprintf(&default_destdir, "%s.git", repo_name) == -1)
-			return got_error_from_errno("asprintf");
-		repo_path = default_destdir;
-	} else
-		repo_path = destdir;
-	err = got_repo_init(repo_path);
-	if (err != NULL)
-		goto done;
+
 	if (asprintf(&path, "%s/objects/path", repo_path) == -1) {
 		err = got_error_from_errno("asprintf");
 		goto done;
@@ -388,10 +379,6 @@ got_fetch(const char *proto, const char *host, const char *port,
 		goto done;
 	}
 
-	err = got_repo_open(&repo, repo_path, NULL);
-	if (err)
-		goto done;
-
 	while (!done) {
 		struct got_object_id *id;
 		char *refname;
@@ -536,13 +523,10 @@ done:
 		err = got_error_from_errno("close");
 	if (idxfd != -1 && close(idxfd) == -1 && err == NULL)
 		err = got_error_from_errno("close");
-	if (repo)
-		got_repo_close(repo);
 	free(tmppackpath);
 	free(tmpidxpath);
 	free(idxpath);
 	free(packpath);
-	free(default_destdir);
 	free(packhash);
 	TAILQ_FOREACH(pe, &symrefs, entry) {
 		free((void *)pe->path);