Commit 02bb39f448b9ed151a638d22fdcbccc895f4d3cf

Edward Thomson 2018-11-22T08:49:09

stream registration: take an enum type

Accept an enum (`git_stream_t`) during custom stream registration that indicates whether the registration structure should be used for standard (non-TLS) streams, TLS streams, or both.

diff --git a/include/git2/sys/stream.h b/include/git2/sys/stream.h
index edbe66f..9387931 100644
--- a/include/git2/sys/stream.h
+++ b/include/git2/sys/stream.h
@@ -72,19 +72,31 @@ typedef struct {
 } git_stream_registration;
 
 /**
+ * The type of stream to register; the values are bit flags.
+ */
+typedef enum {
+	/** A standard (non-TLS) socket. */
+	GIT_STREAM_STANDARD = 1,
+
+	/** A TLS-encrypted socket. */
+	GIT_STREAM_TLS = 2,
+} git_stream_t;
+
+/**
  * Register stream constructors for the library to use
  *
  * If a registration structure is already set, it will be overwritten.
  * Pass `NULL` in order to deregister the current constructor and return
  * to the system defaults.
  *
+ * The type parameter may be a bitwise OR of types.
+ *
+ * @param type the type or types of stream to register
  * @param registration the registration data
- * @param tls 1 if the registration is for TLS streams, 0 for regular
- *        (insecure) sockets
  * @return 0 or an error code
  */
 GIT_EXTERN(int) git_stream_register(
-	int tls, git_stream_registration *registration);
+	git_stream_t type, git_stream_registration *registration);
 
 /** @name Deprecated TLS Stream Registration Functions
  *
diff --git a/src/streams/registry.c b/src/streams/registry.c
index 210e2d0..02fd349 100644
--- a/src/streams/registry.c
+++ b/src/streams/registry.c
@@ -36,22 +36,44 @@ int git_stream_registry_global_init(void)
 	return 0;
 }
 
-int git_stream_registry_lookup(git_stream_registration *out, int tls)
+/* Copy `src` into `target`, or zero `target` when `src` is NULL. */
+GIT_INLINE(void) stream_registration_cpy(
+	git_stream_registration *target,
+	const git_stream_registration *src)
 {
-	git_stream_registration *target = tls ?
-		&stream_registry.callbacks :
-		&stream_registry.tls_callbacks;
+	if (src)
+		memcpy(target, src, sizeof(git_stream_registration));
+	else
+		memset(target, 0, sizeof(git_stream_registration));
+}
+
+int git_stream_registry_lookup(git_stream_registration *out, git_stream_t type)
+{
+	git_stream_registration *target;
 	int error = GIT_ENOTFOUND;
 
 	assert(out);
 
+	switch (type) {
+	case GIT_STREAM_STANDARD:
+		target = &stream_registry.callbacks;
+		break;
+	case GIT_STREAM_TLS:
+		target = &stream_registry.tls_callbacks;
+		break;
+	default:
+		assert(0);
+		giterr_set(GITERR_INVALID, "invalid stream type");
+		return -1;
+	}
+
 	if (git_rwlock_rdlock(&stream_registry.lock) < 0) {
 		giterr_set(GITERR_OS, "failed to lock stream registry");
 		return -1;
 	}
 
 	if (target->init) {
-		memcpy(out, target, sizeof(git_stream_registration));
+		stream_registration_cpy(out, target);
 		error = 0;
 	}
 
@@ -59,12 +81,14 @@ int git_stream_registry_lookup(git_stream_registration *out, int tls)
 	return error;
 }
 
-int git_stream_register(int tls, git_stream_registration *registration)
+int git_stream_register(git_stream_t type, git_stream_registration *registration)
 {
-	git_stream_registration *target = tls ?
-		&stream_registry.callbacks :
-		&stream_registry.tls_callbacks;
-
+	/* Reject 0 (formerly "regular sockets") and unknown bits instead of silently registering nothing. */
+	if (type == 0 || (type & ~(GIT_STREAM_STANDARD | GIT_STREAM_TLS)) != 0) {
+		giterr_set(GITERR_INVALID, "invalid stream type");
+		return -1;
+	}
+
 	assert(!registration || registration->init);
 
 	GITERR_CHECK_VERSION(registration, GIT_STREAM_VERSION, "stream_registration");
@@ -74,10 +98,11 @@ int git_stream_register(int tls, git_stream_registration *registration)
 		return -1;
 	}
 
-	if (registration)
-		memcpy(target, registration, sizeof(git_stream_registration));
-	else
-		memset(target, 0, sizeof(git_stream_registration));
+	if ((type & GIT_STREAM_STANDARD) == GIT_STREAM_STANDARD)
+		stream_registration_cpy(&stream_registry.callbacks, registration);
+
+	if ((type & GIT_STREAM_TLS) == GIT_STREAM_TLS)
+		stream_registration_cpy(&stream_registry.tls_callbacks, registration);
 
 	git_rwlock_wrunlock(&stream_registry.lock);
 	return 0;
@@ -92,8 +117,8 @@ int git_stream_register_tls(git_stream_cb ctor)
 		registration.init = ctor;
 		registration.wrap = NULL;
 
-		return git_stream_register(1, &registration);
+		return git_stream_register(GIT_STREAM_TLS, &registration);
 	} else {
-		return git_stream_register(1, NULL);
+		return git_stream_register(GIT_STREAM_TLS, NULL);
 	}
 }
diff --git a/src/streams/registry.h b/src/streams/registry.h
index 92f87a7..adc2b8b 100644
--- a/src/streams/registry.h
+++ b/src/streams/registry.h
@@ -14,6 +14,6 @@
 int git_stream_registry_global_init(void);
 
 /** Lookup a stream registration. */
-extern int git_stream_registry_lookup(git_stream_registration *out, int tls);
+extern int git_stream_registry_lookup(git_stream_registration *out, git_stream_t type);
 
 #endif
diff --git a/src/streams/socket.c b/src/streams/socket.c
index 21f7fea..732b459 100644
--- a/src/streams/socket.c
+++ b/src/streams/socket.c
@@ -224,7 +224,7 @@ int git_socket_stream_new(
 
 	assert(out && host && port);
 
-	if ((error = git_stream_registry_lookup(&custom, 0)) == 0)
+	if ((error = git_stream_registry_lookup(&custom, GIT_STREAM_STANDARD)) == 0)
 		init = custom.init;
 	else if (error == GIT_ENOTFOUND)
 		init = default_socket_stream_new;
diff --git a/src/streams/tls.c b/src/streams/tls.c
index 0e10697..0ea47fb 100644
--- a/src/streams/tls.c
+++ b/src/streams/tls.c
@@ -23,7 +23,7 @@ int git_tls_stream_new(git_stream **out, const char *host, const char *port)
 
 	assert(out && host && port);
 
-	if ((error = git_stream_registry_lookup(&custom, 1)) == 0) {
+	if ((error = git_stream_registry_lookup(&custom, GIT_STREAM_TLS)) == 0) {
 		init = custom.init;
 	} else if (error == GIT_ENOTFOUND) {
 #ifdef GIT_SECURE_TRANSPORT
@@ -52,7 +52,7 @@ int git_tls_stream_wrap(git_stream **out, git_stream *in, const char *host)
 
 	assert(out && in);
 
-	if (git_stream_registry_lookup(&custom, 1) == 0) {
+	if (git_stream_registry_lookup(&custom, GIT_STREAM_TLS) == 0) {
 		wrap = custom.wrap;
 	} else {
 #ifdef GIT_SECURE_TRANSPORT
diff --git a/tests/core/stream.c b/tests/core/stream.c
index a76169d..f15dce3 100644
--- a/tests/core/stream.c
+++ b/tests/core/stream.c
@@ -7,6 +7,11 @@
 static git_stream test_stream;
 static int ctor_called;
 
+void test_core_stream__cleanup(void)
+{
+	cl_git_pass(git_stream_register(GIT_STREAM_TLS | GIT_STREAM_STANDARD, NULL));
+}
+
 static int test_stream_init(git_stream **out, const char *host, const char *port)
 {
 	GIT_UNUSED(host);
@@ -39,14 +44,14 @@ void test_core_stream__register_insecure(void)
 	registration.wrap = test_stream_wrap;
 
 	ctor_called = 0;
-	cl_git_pass(git_stream_register(0, &registration));
+	cl_git_pass(git_stream_register(GIT_STREAM_STANDARD, &registration));
 	cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
 	cl_assert_equal_i(1, ctor_called);
 	cl_assert_equal_p(&test_stream, stream);
 
 	ctor_called = 0;
 	stream = NULL;
-	cl_git_pass(git_stream_register(0, NULL));
+	cl_git_pass(git_stream_register(GIT_STREAM_STANDARD, NULL));
 	cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
 
 	cl_assert_equal_i(0, ctor_called);
@@ -66,14 +71,14 @@ void test_core_stream__register_tls(void)
 	registration.wrap = test_stream_wrap;
 
 	ctor_called = 0;
-	cl_git_pass(git_stream_register(1, &registration));
+	cl_git_pass(git_stream_register(GIT_STREAM_TLS, &registration));
 	cl_git_pass(git_tls_stream_new(&stream, "localhost", "443"));
 	cl_assert_equal_i(1, ctor_called);
 	cl_assert_equal_p(&test_stream, stream);
 
 	ctor_called = 0;
 	stream = NULL;
-	cl_git_pass(git_stream_register(1, NULL));
+	cl_git_pass(git_stream_register(GIT_STREAM_TLS, NULL));
 	error = git_tls_stream_new(&stream, "localhost", "443");
 
 	/* We don't have TLS support enabled, or we're on Windows,
@@ -91,6 +96,28 @@ void test_core_stream__register_tls(void)
 	git_stream_free(stream);
 }
 
+void test_core_stream__register_both(void)
+{
+	git_stream *stream;
+	git_stream_registration registration = {0};
+
+	registration.version = 1;
+	registration.init = test_stream_init;
+	registration.wrap = test_stream_wrap;
+
+	cl_git_pass(git_stream_register(GIT_STREAM_STANDARD | GIT_STREAM_TLS, &registration));
+
+	ctor_called = 0;
+	cl_git_pass(git_tls_stream_new(&stream, "localhost", "443"));
+	cl_assert_equal_i(1, ctor_called);
+	cl_assert_equal_p(&test_stream, stream);
+
+	ctor_called = 0;
+	cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
+	cl_assert_equal_i(1, ctor_called);
+	cl_assert_equal_p(&test_stream, stream);
+}
+
 void test_core_stream__register_tls_deprecated(void)
 {
 	git_stream *stream;