Commit bfce7f836563143b27bfa24b645f65198c0298f8

Stefan Sperling 2019-07-27T21:26:27

improve histedit script error checking and fix leaks

diff --git a/got/got.c b/got/got.c
index 45c7c68..1f7e06b 100644
--- a/got/got.c
+++ b/got/got.c
@@ -4407,8 +4407,46 @@ histedit_parse_list(struct got_histedit_list *histedit_cmds,
 }
 
 static const struct got_error *
+histedit_check_script(struct got_histedit_list *histedit_cmds,
+    struct got_object_id_queue *commits, struct got_repository *repo)
+{
+	const struct got_error *err = NULL;
+	struct got_object_qid *qid;
+	struct got_histedit_list_entry *hle;
+	static char msg[80];
+	char *id_str;
+	/* Reject a script which contains no commands at all. */
+	if (TAILQ_EMPTY(histedit_cmds))
+		return got_error_msg(GOT_ERR_EMPTY_HISTEDIT,
+		    "histedit script contains no commands");
+	/* Every commit in the edited range must be listed in the script. */
+	SIMPLEQ_FOREACH(qid, commits, entry) {
+		TAILQ_FOREACH(hle, histedit_cmds, entry) {
+			if (got_object_id_cmp(qid->id, hle->commit_id) == 0)
+				break;
+		}
+		if (hle == NULL) {
+			err = got_object_id_str(&id_str, qid->id);
+			if (err)
+				return err;
+			snprintf(msg, sizeof(msg),
+			    "commit %s missing from histedit script", id_str);
+			free(id_str);
+			return got_error_msg(GOT_ERR_HISTEDIT_CMD, msg);
+		}
+	}
+	/* Inspect the script's final entry; list is non-empty here. */
+	hle = TAILQ_LAST(histedit_cmds, got_histedit_list);
+	if (hle->cmd->code == GOT_HISTEDIT_FOLD)
+		return got_error_msg(GOT_ERR_HISTEDIT_CMD,
+		    "last commit in histedit script cannot be folded");
+	return NULL;
+}
+
+static const struct got_error *
 histedit_run_editor(struct got_histedit_list *histedit_cmds,
-    const char *path, struct got_repository *repo)
+    const char *path, struct got_object_id_queue *commits,
+    struct got_repository *repo)
 {
 	const struct got_error *err = NULL;
 	char *editor;
@@ -4429,6 +4467,10 @@ histedit_run_editor(struct got_histedit_list *histedit_cmds,
 		goto done;
 	}
 	err = histedit_parse_list(histedit_cmds, f, repo);
+	if (err)
+		goto done;
+
+	err = histedit_check_script(histedit_cmds, commits, repo);
 done:
 	if (f && fclose(f) != 0 && err == NULL)
 		err = got_error_from_errno("fclose");
@@ -4437,7 +4479,7 @@ done:
 }
 
 static const struct got_error *
-histedit_edit_list_retry(struct got_histedit_list *, const char *,
+histedit_edit_list_retry(struct got_histedit_list *, const struct got_error *,
     struct got_object_id_queue *, const char *, struct got_repository *);
 
 static const struct got_error *
@@ -4466,12 +4508,12 @@ histedit_edit_script(struct got_histedit_list *histedit_cmds,
 	}
 	f = NULL;
 
-	err = histedit_run_editor(histedit_cmds, path, repo);
+	err = histedit_run_editor(histedit_cmds, path, commits, repo);
 	if (err) {
-		const char *errmsg = err->msg;
-		if (err->code != GOT_ERR_HISTEDIT_SYNTAX)
+		if (err->code != GOT_ERR_HISTEDIT_SYNTAX &&
+		    err->code != GOT_ERR_HISTEDIT_CMD)
 			goto done;
-		err = histedit_edit_list_retry(histedit_cmds, errmsg,
+		err = histedit_edit_list_retry(histedit_cmds, err,
 		    commits, path, repo);
 	}
 done:
@@ -4526,6 +4568,17 @@ done:
 	return err;
 }
 
+static void
+histedit_free_list(struct got_histedit_list *histedit_cmds)
+{
+	struct got_histedit_list_entry *hle;
+	/* Pop entries until the list is empty; callers may refill it. */
+	while ((hle = TAILQ_FIRST(histedit_cmds))) {
+		TAILQ_REMOVE(histedit_cmds, hle, entry);
+		free(hle);
+	}
+}
+
 static const struct got_error *
 histedit_load_list(struct got_histedit_list *histedit_cmds,
     const char *path, struct got_repository *repo)
@@ -4548,33 +4601,42 @@ done:
 
 static const struct got_error *
 histedit_edit_list_retry(struct got_histedit_list *histedit_cmds,
-    const char *errmsg, struct got_object_id_queue *commits,
+    const struct got_error *edit_err, struct got_object_id_queue *commits,
     const char *path, struct got_repository *repo)
 {
-	const struct got_error *err = NULL;
+	const struct got_error *err = NULL, *prev_err = edit_err;
 	int resp = ' ';
 
 	while (resp != 'c' && resp != 'r' && resp != 'a') {
 		printf("%s: %s\n(c)ontinue editing, (r)estart editing, "
-		    "or (a)bort: ", getprogname(), errmsg);
+		    "or (a)bort: ", getprogname(), prev_err->msg);
 		resp = getchar();
 		if (resp == 'c') {
-			err = histedit_run_editor(histedit_cmds, path, repo);
+			histedit_free_list(histedit_cmds);
+			err = histedit_run_editor(histedit_cmds, path, commits,
+			    repo);
 			if (err) {
-				if (err->code != GOT_ERR_HISTEDIT_SYNTAX)
+				if (err->code != GOT_ERR_HISTEDIT_SYNTAX &&
+				    err->code != GOT_ERR_HISTEDIT_CMD)
 					break;
+				prev_err = err;
 				resp = ' ';
 				continue;
 			}
+			break;
 		} else if (resp == 'r') {
+			histedit_free_list(histedit_cmds);
 			err = histedit_edit_script(histedit_cmds,
 			    commits, repo);
 			if (err) {
-				if (err->code != GOT_ERR_HISTEDIT_SYNTAX)
+				if (err->code != GOT_ERR_HISTEDIT_SYNTAX &&
+				    err->code != GOT_ERR_HISTEDIT_CMD)
 					break;
+				prev_err = err;
 				resp = ' ';
 				continue;
 			}
+			break;
 		} else if (resp == 'a') {
 			err = got_error(GOT_ERR_HISTEDIT_CANCEL);
 			break;
@@ -4710,43 +4772,6 @@ histedit_skip_commit(struct got_histedit_list_entry *hle,
 }
 
 static const struct got_error *
-histedit_check_script(struct got_histedit_list *histedit_cmds,
-    struct got_object_id_queue *commits, struct got_repository *repo)
-{
-	const struct got_error *err = NULL;
-	struct got_object_qid *qid;
-	struct got_histedit_list_entry *hle;
-	static char msg[80];
-	char *id_str;
-
-	if (TAILQ_EMPTY(histedit_cmds))
-		return got_error_msg(GOT_ERR_EMPTY_HISTEDIT,
-		    "histedit script contains no commands");
-
-	SIMPLEQ_FOREACH(qid, commits, entry) {
-		TAILQ_FOREACH(hle, histedit_cmds, entry) {
-			if (got_object_id_cmp(qid->id, hle->commit_id) == 0)
-				break;
-		}
-		if (hle == NULL) {
-			err = got_object_id_str(&id_str, qid->id);
-			if (err)
-				return err;
-			snprintf(msg, sizeof(msg),
-			    "commit %s missing from histedit script", id_str);
-			free(id_str);
-			return got_error_msg(GOT_ERR_HISTEDIT_CMD, msg);
-		}
-	}
-
-	if (hle->cmd->code == GOT_HISTEDIT_FOLD)
-		return got_error_msg(GOT_ERR_HISTEDIT_CMD,
-		    "last commit in histedit script cannot be folded");
-
-	return NULL;
-}
-
-static const struct got_error *
 cmd_histedit(int argc, char *argv[])
 {
 	const struct got_error *error = NULL;
@@ -4916,6 +4941,8 @@ cmd_histedit(int argc, char *argv[])
 			goto done;
 
 		error = got_ref_resolve(&head_commit_id, repo, branch);
+		got_ref_close(branch);
+		branch = NULL;
 		if (error)
 			goto done;
 
@@ -4938,6 +4965,11 @@ cmd_histedit(int argc, char *argv[])
 		if (error)
 			goto done;
 
+		error = got_worktree_histedit_prepare(&tmp_branch, &branch,
+		    &base_commit_id, &fileindex, worktree, repo);
+		if (error)
+			goto done;
+
 		if (edit_script_path) {
 			error = histedit_load_list(&histedit_cmds,
 			    edit_script_path, repo);
@@ -4956,11 +4988,6 @@ cmd_histedit(int argc, char *argv[])
 		if (error)
 			goto done;
 
-		error = got_worktree_histedit_prepare(&tmp_branch, &branch,
-		    &base_commit_id, &fileindex, worktree, repo);
-		if (error)
-			goto done;
-
 	}
 
 	 error = histedit_check_script(&histedit_cmds, &commits, repo);
@@ -5053,6 +5080,7 @@ cmd_histedit(int argc, char *argv[])
 		    branch, repo);
 done:
 	got_object_id_queue_free(&commits);
+	histedit_free_list(&histedit_cmds);
 	free(head_commit_id);
 	free(base_commit_id);
 	free(resume_commit_id);