Commit 2d01573bb9444fb1d684116900d871a5c3e138bf

Sylvain Becker 2021-03-16T15:44:33

Metal shaders: take color as a per-vertex attribute instead of a fragment buffer argument

diff --git a/src/render/metal/SDL_shaders_metal.metal b/src/render/metal/SDL_shaders_metal.metal
index 7975a39..fd9e3a8 100644
--- a/src/render/metal/SDL_shaders_metal.metal
+++ b/src/render/metal/SDL_shaders_metal.metal
@@ -6,11 +6,13 @@ using namespace metal;
 struct SolidVertexInput
 {
     float2 position [[attribute(0)]];
+    float4 color    [[attribute(1)]];
 };
 
 struct SolidVertexOutput
 {
     float4 position [[position]];
+    float4 color;
     float pointSize [[point_size]];
 };
 
@@ -20,24 +22,27 @@ vertex SolidVertexOutput SDL_Solid_vertex(SolidVertexInput in [[stage_in]],
 {
     SolidVertexOutput v;
     v.position = (projection * transform) * float4(in.position, 0.0f, 1.0f);
+    v.color = in.color;
     v.pointSize = 1.0f;
     return v;
 }
 
-fragment float4 SDL_Solid_fragment(const device float4 &col [[buffer(0)]])
+fragment float4 SDL_Solid_fragment(SolidVertexInput in [[stage_in]])
 {
-    return col;
+    return in.color;
 }
 
 struct CopyVertexInput
 {
     float2 position [[attribute(0)]];
-    float2 texcoord [[attribute(1)]];
+    float4 color    [[attribute(1)]];
+    float2 texcoord [[attribute(2)]];
 };
 
 struct CopyVertexOutput
 {
     float4 position [[position]];
+    float4 color;
     float2 texcoord;
 };
 
@@ -47,16 +52,16 @@ vertex CopyVertexOutput SDL_Copy_vertex(CopyVertexInput in [[stage_in]],
 {
     CopyVertexOutput v;
     v.position = (projection * transform) * float4(in.position, 0.0f, 1.0f);
+    v.color = in.color;
     v.texcoord = in.texcoord;
     return v;
 }
 
 fragment float4 SDL_Copy_fragment(CopyVertexOutput vert [[stage_in]],
-                                  const device float4 &col [[buffer(0)]],
                                   texture2d<float> tex [[texture(0)]],
                                   sampler s [[sampler(0)]])
 {
-    return tex.sample(s, vert.texcoord) * col;
+    return tex.sample(s, vert.texcoord) * vert.color;
 }
 
 struct YUVDecode
@@ -68,7 +73,6 @@ struct YUVDecode
 };
 
 fragment float4 SDL_YUV_fragment(CopyVertexOutput vert [[stage_in]],
-                                 const device float4 &col [[buffer(0)]],
                                  constant YUVDecode &decode [[buffer(1)]],
                                  texture2d<float> texY [[texture(0)]],
                                  texture2d_array<float> texUV [[texture(1)]],
@@ -81,11 +85,10 @@ fragment float4 SDL_YUV_fragment(CopyVertexOutput vert [[stage_in]],
 
     yuv += decode.offset;
 
-    return col * float4(dot(yuv, decode.Rcoeff), dot(yuv, decode.Gcoeff), dot(yuv, decode.Bcoeff), 1.0);
+    return vert.color * float4(dot(yuv, decode.Rcoeff), dot(yuv, decode.Gcoeff), dot(yuv, decode.Bcoeff), 1.0);
 }
 
 fragment float4 SDL_NV12_fragment(CopyVertexOutput vert [[stage_in]],
-                                 const device float4 &col [[buffer(0)]],
                                  constant YUVDecode &decode [[buffer(1)]],
                                  texture2d<float> texY [[texture(0)]],
                                  texture2d<float> texUV [[texture(1)]],
@@ -97,11 +100,10 @@ fragment float4 SDL_NV12_fragment(CopyVertexOutput vert [[stage_in]],
 
     yuv += decode.offset;
 
-    return col * float4(dot(yuv, decode.Rcoeff), dot(yuv, decode.Gcoeff), dot(yuv, decode.Bcoeff), 1.0);
+    return vert.color * float4(dot(yuv, decode.Rcoeff), dot(yuv, decode.Gcoeff), dot(yuv, decode.Bcoeff), 1.0);
 }
 
 fragment float4 SDL_NV21_fragment(CopyVertexOutput vert [[stage_in]],
-                                 const device float4 &col [[buffer(0)]],
                                  constant YUVDecode &decode [[buffer(1)]],
                                  texture2d<float> texY [[texture(0)]],
                                  texture2d<float> texUV [[texture(1)]],
@@ -113,5 +115,6 @@ fragment float4 SDL_NV21_fragment(CopyVertexOutput vert [[stage_in]],
 
     yuv += decode.offset;
 
-    return col * float4(dot(yuv, decode.Rcoeff), dot(yuv, decode.Gcoeff), dot(yuv, decode.Bcoeff), 1.0);
+    return vert.color * float4(dot(yuv, decode.Rcoeff), dot(yuv, decode.Gcoeff), dot(yuv, decode.Bcoeff), 1.0);
 }
+