Commit cbd01874759dad39718e1496d241c4ae5eceaff1

Sam Lantinga 2022-06-27T16:59:50

Removed the limit on the size of the SDL error message. Also added SDL_GetOriginalMemoryFunctions(). Fixes https://github.com/libsdl-org/SDL/issues/5795

diff --git a/WhatsNew.txt b/WhatsNew.txt
index 8afb78f..517f0a7 100644
--- a/WhatsNew.txt
+++ b/WhatsNew.txt
@@ -30,6 +30,7 @@ General:
 * Added joystick event SDL_JOYBATTERYUPDATED for when battery status changes.
 * Added SDL_GUIDToString() and SDL_GUIDFromString() to convert between SDL GUID and string
 * Added SDL_HasLSX() and SDL_HasLASX() to detect LoongArch SIMD support
+* Added SDL_GetOriginalMemoryFunctions() to get SDL's built-in memory functions, regardless of SDL_SetMemoryFunctions()
 
 Windows:
 * Added a D3D12 renderer implementation and SDL_RenderGetD3D12Device() to retrieve the D3D12 device associated with it
diff --git a/include/SDL_stdinc.h b/include/SDL_stdinc.h
index 0da2256..06f6fee 100644
--- a/include/SDL_stdinc.h
+++ b/include/SDL_stdinc.h
@@ -439,6 +439,25 @@ typedef void *(SDLCALL *SDL_realloc_func)(void *mem, size_t size);
 typedef void (SDLCALL *SDL_free_func)(void *mem);
 
 /**
+ * Get the original set of SDL memory functions
+ *
+ * These are the functions SDL uses before any call to
+ * SDL_SetMemoryFunctions(); unlike SDL_GetMemoryFunctions(), the result is
+ * not affected by replacing the memory functions.
+ *
+ * \param malloc_func filled with the original malloc function, may be NULL
+ * \param calloc_func filled with the original calloc function, may be NULL
+ * \param realloc_func filled with the original realloc function, may be NULL
+ * \param free_func filled with the original free function, may be NULL
+ *
+ * \since This function is available since SDL 2.24.0.
+ */
+extern DECLSPEC void SDLCALL SDL_GetOriginalMemoryFunctions(SDL_malloc_func *malloc_func,
+                                                            SDL_calloc_func *calloc_func,
+                                                            SDL_realloc_func *realloc_func,
+                                                            SDL_free_func *free_func);
+
+/**
  * Get the current set of SDL memory functions
  *
  * \since This function is available since SDL 2.0.7.
diff --git a/src/SDL_error.c b/src/SDL_error.c
index 68b518d..e3501e6 100644
--- a/src/SDL_error.c
+++ b/src/SDL_error.c
@@ -31,12 +31,27 @@ SDL_SetError(SDL_PRINTF_FORMAT_STRING const char *fmt, ...)
     /* Ignore call if invalid format pointer was passed */
     if (fmt != NULL) {
         va_list ap;
+        int result;
         SDL_error *error = SDL_GetErrBuf();
 
         error->error = 1;  /* mark error as valid */
 
         va_start(ap, fmt);
-        SDL_vsnprintf(error->str, ERR_MAX_STRLEN, fmt, ap);
+        result = SDL_vsnprintf(error->str, error->len, fmt, ap);
         va_end(ap);
+
+        /* Grow the buffer and format again if the message did not fit */
+        if (result >= 0 && (size_t)result >= error->len && error->realloc_func) {
+            size_t len = (size_t)result + 1;
+            char *str = (char *)error->realloc_func(error->str, len);
+            if (str) {
+                error->str = str;
+                error->len = len;
+                /* ap was consumed by the first SDL_vsnprintf(); it must be restarted */
+                va_start(ap, fmt);
+                SDL_vsnprintf(error->str, error->len, fmt, ap);
+                va_end(ap);
+            }
+        }
 
         if (SDL_LogGetPriority(SDL_LOG_CATEGORY_ERROR) <= SDL_LOG_PRIORITY_DEBUG) {
diff --git a/src/SDL_error_c.h b/src/SDL_error_c.h
index af6a15e..1ac1f42 100644
--- a/src/SDL_error_c.h
+++ b/src/SDL_error_c.h
@@ -27,12 +27,13 @@
 #ifndef SDL_error_c_h_
 #define SDL_error_c_h_
 
-#define ERR_MAX_STRLEN  128
-
 typedef struct SDL_error
 {
     int error; /* This is a numeric value corresponding to the current error */
-    char str[ERR_MAX_STRLEN];
+    char *str;                      /* message text; NULL (len 0) in a freshly zeroed per-thread buffer */
+    size_t len;                     /* size of str in bytes, including room for the terminating NUL */
+    SDL_realloc_func realloc_func;  /* allocates/grows str; NULL means the buffer cannot grow */
+    SDL_free_func free_func;        /* used by SDL_FreeErrBuf() to release str and this struct */
 } SDL_error;
 
 /* Defined in SDL_thread.c */
diff --git a/src/dynapi/SDL2.exports b/src/dynapi/SDL2.exports
index 74098be..746b5fc 100644
--- a/src/dynapi/SDL2.exports
+++ b/src/dynapi/SDL2.exports
@@ -851,3 +851,4 @@
 ++'_SDL_utf8strnlen'.'SDL2.dll'.'SDL_utf8strnlen'
 # ++'_SDL_GDKGetTaskQueue'.'SDL2.dll'.'SDL_GDKGetTaskQueue'
 # ++'_SDL_GDKRunApp'.'SDL2.dll'.'SDL_GDKRunApp'
+++'_SDL_GetOriginalMemoryFunctions'.'SDL2.dll'.'SDL_GetOriginalMemoryFunctions'
diff --git a/src/dynapi/SDL_dynapi.c b/src/dynapi/SDL_dynapi.c
index 38df73c..994a2c2 100644
--- a/src/dynapi/SDL_dynapi.c
+++ b/src/dynapi/SDL_dynapi.c
@@ -71,11 +71,30 @@ static void SDL_InitDynamicAPI(void);
 
 #define SDL_DYNAPI_VARARGS(_static, name, initcall) \
     _static int SDLCALL SDL_SetError##name(SDL_PRINTF_FORMAT_STRING const char *fmt, ...) { \
-        char buf[512]; /* !!! FIXME: dynamic allocation */ \
+        char buf[128], *str = buf; \
+        int result; \
         va_list ap; initcall; va_start(ap, fmt); \
-        jump_table.SDL_vsnprintf(buf, sizeof (buf), fmt, ap); \
+        result = jump_table.SDL_vsnprintf(buf, sizeof(buf), fmt, ap); \
         va_end(ap); \
-        return jump_table.SDL_SetError("%s", buf); \
+        if (result >= 0 && (size_t)result >= sizeof(buf)) { \
+            /* Truncated: format again into a heap buffer that fits */ \
+            size_t len = (size_t)result + 1; \
+            char *bigbuf = (char *)jump_table.SDL_malloc(len); \
+            if (bigbuf) { \
+                /* ap was consumed above and must be restarted before reuse */ \
+                va_start(ap, fmt); \
+                result = jump_table.SDL_vsnprintf(bigbuf, len, fmt, ap); \
+                va_end(ap); \
+                str = bigbuf; \
+            } /* else keep the truncated message in buf rather than passing NULL */ \
+        } \
+        if (result >= 0) { \
+            result = jump_table.SDL_SetError("%s", str); \
+        } \
+        if (str != buf) { \
+            jump_table.SDL_free(str); \
+        } \
+        return result; \
     } \
     _static int SDLCALL SDL_sscanf##name(const char *buf, SDL_SCANF_FORMAT_STRING const char *fmt, ...) { \
         int retval; va_list ap; initcall; va_start(ap, fmt); \
diff --git a/src/dynapi/SDL_dynapi_overrides.h b/src/dynapi/SDL_dynapi_overrides.h
index 12d8554..babe19e 100644
--- a/src/dynapi/SDL_dynapi_overrides.h
+++ b/src/dynapi/SDL_dynapi_overrides.h
@@ -877,3 +877,4 @@
 #define SDL_utf8strnlen SDL_utf8strnlen_REAL
 #define SDL_GDKGetTaskQueue SDL_GDKGetTaskQueue_REAL
 #define SDL_GDKRunApp SDL_GDKRunApp_REAL
+#define SDL_GetOriginalMemoryFunctions SDL_GetOriginalMemoryFunctions_REAL
diff --git a/src/dynapi/SDL_dynapi_procs.h b/src/dynapi/SDL_dynapi_procs.h
index f0a2d0d..4c61f19 100644
--- a/src/dynapi/SDL_dynapi_procs.h
+++ b/src/dynapi/SDL_dynapi_procs.h
@@ -960,3 +960,4 @@ SDL_DYNAPI_PROC(size_t,SDL_utf8strnlen,(const char *a, size_t b),(a,b),return)
 SDL_DYNAPI_PROC(int,SDL_GDKGetTaskQueue,(XTaskQueueHandle *a),(a),return)
 SDL_DYNAPI_PROC(int,SDL_GDKRunApp,(SDL_main_func a, void *b),(a,b),return)
 #endif
+SDL_DYNAPI_PROC(void,SDL_GetOriginalMemoryFunctions,(SDL_malloc_func *a, SDL_calloc_func *b, SDL_realloc_func *c, SDL_free_func *d),(a,b,c,d),)
diff --git a/src/stdlib/SDL_malloc.c b/src/stdlib/SDL_malloc.c
index 0a51ddd..72410b4 100644
--- a/src/stdlib/SDL_malloc.c
+++ b/src/stdlib/SDL_malloc.c
@@ -5328,6 +5328,29 @@ static struct
     real_malloc, real_calloc, real_realloc, real_free, { 0 }
 };
 
+/* Report the built-in allocators (the real_* functions s_mem starts with),
+ * independent of any later SDL_SetMemoryFunctions() override.
+ * Every output pointer is optional; NULL outputs are skipped.
+ */
+void SDL_GetOriginalMemoryFunctions(SDL_malloc_func *malloc_func,
+                                    SDL_calloc_func *calloc_func,
+                                    SDL_realloc_func *realloc_func,
+                                    SDL_free_func *free_func)
+{
+    if (malloc_func != NULL) {
+        *malloc_func = real_malloc;
+    }
+    if (calloc_func != NULL) {
+        *calloc_func = real_calloc;
+    }
+    if (realloc_func != NULL) {
+        *realloc_func = real_realloc;
+    }
+    if (free_func != NULL) {
+        *free_func = real_free;
+    }
+}
+
 void SDL_GetMemoryFunctions(SDL_malloc_func *malloc_func,
                             SDL_calloc_func *calloc_func,
                             SDL_realloc_func *realloc_func,
diff --git a/src/thread/SDL_thread.c b/src/thread/SDL_thread.c
index 61af43f..1a60cb0 100644
--- a/src/thread/SDL_thread.c
+++ b/src/thread/SDL_thread.c
@@ -200,19 +200,38 @@ SDL_Generic_SetTLSData(SDL_TLSData *storage)
     return 0;
 }
 
+/* Non-thread-safe global error variable, used when threads are disabled or no per-thread buffer is available */
+static SDL_error *
+SDL_GetStaticErrBuf(void)
+{
+    static SDL_error SDL_global_error;      /* realloc_func stays NULL: long messages are truncated */
+    static char SDL_global_error_str[128];
+    SDL_global_error.str = SDL_global_error_str;
+    SDL_global_error.len = sizeof(SDL_global_error_str);
+    return &SDL_global_error;
+}
+
+static void
+SDL_FreeErrBuf(void *data)
+{
+    SDL_error *errbuf = (SDL_error *)data;
+
+    if (errbuf->str) {
+        errbuf->free_func(errbuf->str);
+    }
+    errbuf->free_func(errbuf);
+}
+
 /* Routine to get the thread-specific error variable */
 SDL_error *
 SDL_GetErrBuf(void)
 {
 #if SDL_THREADS_DISABLED
-    /* Non-thread-safe global error variable */
-    static SDL_error SDL_global_error;
-    return &SDL_global_error;
+    return SDL_GetStaticErrBuf();
 #else
     static SDL_SpinLock tls_lock;
     static SDL_bool tls_being_created;
     static SDL_TLSID tls_errbuf;
-    static SDL_error SDL_global_errbuf;
     const SDL_error *ALLOCATION_IN_PROGRESS = (SDL_error *)-1;
     SDL_error *errbuf;
 
@@ -233,24 +252,33 @@ SDL_GetErrBuf(void)
         SDL_AtomicUnlock(&tls_lock);
     }
     if (!tls_errbuf) {
-        return &SDL_global_errbuf;
+        return SDL_GetStaticErrBuf();
     }
 
     SDL_MemoryBarrierAcquire();
     errbuf = (SDL_error *)SDL_TLSGet(tls_errbuf);
     if (errbuf == ALLOCATION_IN_PROGRESS) {
-        return &SDL_global_errbuf;
+        return SDL_GetStaticErrBuf();
     }
     if (!errbuf) {
+        /* Get the original memory functions for this allocation because the lifetime
+         * of the error buffer may span calls to SDL_SetMemoryFunctions() by the app
+         */
+        SDL_realloc_func realloc_func;
+        SDL_free_func free_func;
+        SDL_GetOriginalMemoryFunctions(NULL, NULL, &realloc_func, &free_func);
+
         /* Mark that we're in the middle of allocating our buffer */
         SDL_TLSSet(tls_errbuf, ALLOCATION_IN_PROGRESS, NULL);
-        errbuf = (SDL_error *)SDL_malloc(sizeof(*errbuf));
+        errbuf = (SDL_error *)realloc_func(NULL, sizeof(*errbuf));  /* relies on realloc(NULL, n) acting like malloc(n) */
         if (!errbuf) {
             SDL_TLSSet(tls_errbuf, NULL, NULL);
-            return &SDL_global_errbuf;
+            return SDL_GetStaticErrBuf();
         }
         SDL_zerop(errbuf);
-        SDL_TLSSet(tls_errbuf, errbuf, SDL_free);
+        errbuf->realloc_func = realloc_func;  /* str starts NULL/len 0 and is grown on demand by SDL_SetError() */
+        errbuf->free_func = free_func;
+        SDL_TLSSet(tls_errbuf, errbuf, SDL_FreeErrBuf);  /* destructor frees str and errbuf with free_func */
     }
     return errbuf;
 #endif /* SDL_THREADS_DISABLED */
diff --git a/test/testautomation_main.c b/test/testautomation_main.c
index 229f2bc..f267789 100644
--- a/test/testautomation_main.c
+++ b/test/testautomation_main.c
@@ -125,6 +125,32 @@ static int main_testImpliedJoystickQuit (void *arg)
 #endif
 }
 
+/* Exercise lengths on both sides of the 128-byte fixed buffers SDL uses
+ * internally for error messages, plus one long enough to force growth. */
+static int
+main_testSetError(void *arg)
+{
+    static const size_t lengths[] = { 127, 128, 129, 1023 };
+    size_t i, j;
+    char error[1024];
+
+    error[0] = '\0';
+    SDL_SetError("");
+    SDLTest_AssertCheck(SDL_strcmp(error, SDL_GetError()) == 0, "SDL_SetError(\"\")");
+
+    for (j = 0; j < SDL_arraysize(lengths); ++j) {
+        for (i = 0; i < lengths[j]; ++i) {
+            error[i] = (char)('a' + (i % 26));
+        }
+        error[i] = '\0';
+        SDL_SetError("%s", error);
+        SDLTest_AssertCheck(SDL_strcmp(error, SDL_GetError()) == 0,
+                            "SDL_SetError(\"abc...%d\")", (int)lengths[j]);
+    }
+
+    return TEST_COMPLETED;
+}
+
 static const SDLTest_TestCaseReference mainTest1 =
         { (SDLTest_TestCaseFp)main_testInitQuitJoystickHaptic, "main_testInitQuitJoystickHaptic", "Tests SDL_Init/Quit of Joystick and Haptic subsystem", TEST_ENABLED};
 
@@ -137,12 +163,16 @@ static const SDLTest_TestCaseReference mainTest3 =
 static const SDLTest_TestCaseReference mainTest4 =
         { (SDLTest_TestCaseFp)main_testImpliedJoystickQuit, "main_testImpliedJoystickQuit", "Tests that quit for gamecontroller doesn't quit joystick if you inited it explicitly", TEST_ENABLED};
 
+static const SDLTest_TestCaseReference mainTest5 =
+        { (SDLTest_TestCaseFp)main_testSetError, "main_testSetError", "Tests that SDL_SetError() handles arbitrarily large strings", TEST_ENABLED};
+
 /* Sequence of Main test cases */
 static const SDLTest_TestCaseReference *mainTests[] =  {
     &mainTest1,
     &mainTest2,
     &mainTest3,
     &mainTest4,
+    &mainTest5,
     NULL
 };