Hash :
16771259
        
        Author :
  
        
        Date :
2025-09-12T11:10:18
        
      
Update main() prototype in shader if main() has been replaced.

If the main() function definition has been wrapped/replaced in a shader and a main() function prototype is also present, the prototype needs to be replaced. Otherwise it will continue to reference the replaced main(), which can cause issues with shader translation steps.

Bug: angleproject:444653099
Change-Id: Ie6ce85cac89e026876a1b6e25cd294f1d8a536c4
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/6944807
Reviewed-by: Shahbaz Youssefi <syoussefi@chromium.org>
Reviewed-by: Kimmo Kinnunen <kkinnunen@apple.com>
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
//
// Copyright 2017 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a
// return statement, this is done by replacing the main() function with another function that calls
// the old main, like this:
//
// void main() { body }
// =>
// void main0() { body }
// void main()
// {
//     main0();
//     codeToRun
// }
//
// This way the code will get run even if the return statement inside main is executed.
//
// This is done if main ends in an unconditional |discard| as well, to help with SPIR-V generation
// that expects no dead-code to be present after branches in a block.  To avoid bugs when |discard|
// is wrapped in unconditional blocks, any |discard| in main() is used as a signal to wrap it.
//
#include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
#include "compiler/translator/Compiler.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/FindMain.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
namespace sh
{
namespace
{
// Name given to the user-defined main() that WrapMainAndAppend() creates.
constexpr ImmutableString kMainString("main");
// Walks a subtree and remembers whether any branch node in it is a |return| (EOpReturn) or a
// |discard| (EOpKill).
class ContainsReturnOrDiscardTraverser : public TIntermTraverser
{
  public:
    // Only pre-visits are requested (preVisit = true, inVisit = false, postVisit = false).
    ContainsReturnOrDiscardTraverser() : TIntermTraverser(true, false, false) {}

    bool visitBranch(Visit visit, TIntermBranch *node) override
    {
        const auto flowOp            = node->getFlowOp();
        const bool isReturnOrDiscard = (flowOp == EOpReturn) || (flowOp == EOpKill);
        if (isReturnOrDiscard)
        {
            // Sticky: once set, later branches never clear it.
            mContainsReturnOrDiscard = true;
        }
        // Do not traverse into the branch node's children.
        return false;
    }

    bool containsReturnOrDiscard() { return mContainsReturnOrDiscard; }

  private:
    bool mContainsReturnOrDiscard = false;
};
// Returns true if the subtree rooted at |node| contains a |return| or |discard| statement.
bool ContainsReturnOrDiscard(TIntermNode *node)
{
    ContainsReturnOrDiscardTraverser finder;
    node->traverse(&finder);
    const bool found = finder.containsReturnOrDiscard();
    return found;
}
// Moves the body of |main| into a new function and appends a fresh main() that calls it and then
// runs |codeToRun|.
//
// The old body is placed in a nameless AngleInternal function (main0() in the file header
// comment), whose definition replaces |main| in |root|. A new user-defined "void main()" is
// appended to |root|. Its body is a call to that function followed by |codeToRun|.
//
// Every top-level prototype of main() in |root| is then re-pointed at the new main()'s TFunction.
// Otherwise those prototypes would keep referencing the replaced main(). A shader may declare the
// prototype more than once, so all of them are replaced, not only the first.
void WrapMainAndAppend(TIntermBlock *root,
                       TIntermFunctionDefinition *main,
                       TIntermNode *codeToRun,
                       TSymbolTable *symbolTable)
{
    // Replace main() with main0() with the same body.
    TFunction *oldMain =
        new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
                      StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
    TIntermFunctionDefinition *oldMainDefinition =
        CreateInternalFunctionDefinitionNode(*oldMain, main->getBody());
    bool replaced = root->replaceChildNode(main, oldMainDefinition);
    ASSERT(replaced);

    // void main()
    TFunction *newMain = new TFunction(symbolTable, kMainString, SymbolType::UserDefined,
                                       StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
    TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(newMain);

    // {
    //     main0();
    //     codeToRun
    // }
    TIntermBlock *newMainBody = new TIntermBlock();
    TIntermSequence emptySequence;
    TIntermAggregate *oldMainCall = TIntermAggregate::CreateFunctionCall(*oldMain, &emptySequence);
    newMainBody->appendStatement(oldMainCall);
    newMainBody->appendStatement(codeToRun);

    // Add the new main() to the root node.
    TIntermFunctionDefinition *newMainDefinition =
        new TIntermFunctionDefinition(newMainProto, newMainBody);
    root->appendStatement(newMainDefinition);

    // Replace every top-level main() prototype so that none of them keeps referencing the
    // replaced main() TFunction. Each replacement is a fresh node, so no node is shared in the
    // tree. The slot is assigned in place, so this never revisits a replacement. Only prototype
    // nodes are considered; the newly appended main() is a definition node and is skipped.
    for (TIntermNode *&node : *root->getSequence())
    {
        TIntermFunctionPrototype *proto = node->getAsFunctionPrototypeNode();
        if (proto != nullptr && proto->getFunction()->isMain())
        {
            node = new TIntermFunctionPrototype(newMain);
        }
    }
}
}  // anonymous namespace
// Arranges for |codeToRun| to be placed at the end of main().
//
// If main()'s body contains no |return| or |discard|, |codeToRun| is appended directly to that
// body. Otherwise main() is wrapped by WrapMainAndAppend(), and |codeToRun| is placed in the new
// main() after the call to the original body.
//
// Returns the result of compiler->validateAST(root).
bool RunAtTheEndOfShader(TCompiler *compiler,
                         TIntermBlock *root,
                         TIntermNode *codeToRun,
                         TSymbolTable *symbolTable)
{
    TIntermFunctionDefinition *mainDefinition = FindMain(root);

    if (!ContainsReturnOrDiscard(mainDefinition))
    {
        // No return/discard found: appending to the end of the body is sufficient.
        mainDefinition->getBody()->appendStatement(codeToRun);
    }
    else
    {
        WrapMainAndAppend(root, mainDefinition, codeToRun, symbolTable);
    }

    return compiler->validateAST(root);
}
}  // namespace sh