Edit

kc3-lang/angle/src/compiler/translator/ASTMetadataHLSL.cpp

Branch :

  • Show log

    Commit

  • Author : Corentin Wallez
    Date : 2015-04-29 10:15:16
    Hash : a1884f2b
    Message : Fixes for the tagging of discontinuous loops The first fix was for all loops being considered discontinuous in OutputHLSL because of a typo that produced a tautology. The error was not detected by the unit tests because as an optimization we do not generate Lod0 calls when they are not needed for the callee function (which was correctly detected by the analysis in this case). Fixed the unit tests by adding a call to a builtin gradient operation. The second fix was for discard not being taken into account in the analyses of the AST, which caused a WebGL test regression after the first fix for conformance/glsl/bugs/conditional-discard-in-loop BUG=angleproject:982 Change-Id: I1315eac1ad36f726be52d7fda5facf3104341b1f Reviewed-on: https://chromium-review.googlesource.com/267814 Tested-by: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Jamie Madill <jmadill@chromium.org> Reviewed-by: Geoff Lang <geofflang@chromium.org>

  • src/compiler/translator/ASTMetadataHLSL.cpp
  • //
    // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
    // Use of this source code is governed by a BSD-style license that can be
    // found in the LICENSE file.
    //
    
    // Analysis of the AST needed for HLSL generation
    
    #include "compiler/translator/ASTMetadataHLSL.h"
    
    #include "compiler/translator/CallDAG.h"
    #include "compiler/translator/SymbolTable.h"
    
    namespace
    {
    
    // Class used to traverse the AST of a function definition, checking if the
    // function uses a gradient, and writing the set of control flow using gradients.
    // It assumes that the analysis has already been made for the function's
    // callees.
    class PullGradient : public TIntermTraverser
    {
      public:
        PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
            : TIntermTraverser(true, false, true),
              mMetadataList(metadataList),
              mMetadata(&(*metadataList)[index]),
              mIndex(index),
              mDag(dag)
        {
            ASSERT(index < metadataList->size());
        }
    
        void traverse(TIntermAggregate *node)
        {
            node->traverse(this);
            ASSERT(mParents.empty());
        }
    
        // Called when a gradient operation or a call to a function using a gradient is found.
        void onGradient()
        {
            mMetadata->mUsesGradient = true;
            // Mark the latest control flow as using a gradient.
            if (!mParents.empty())
            {
                mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
            }
        }
    
        void visitControlFlow(Visit visit, TIntermNode *node)
        {
            if (visit == PreVisit)
            {
                mParents.push_back(node);
            }
            else if (visit == PostVisit)
            {
                ASSERT(mParents.back() == node);
                mParents.pop_back();
                // A control flow's using a gradient means its parents are too.
                if (mMetadata->mControlFlowsContainingGradient.count(node)> 0 && !mParents.empty())
                {
                    mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
                }
            }
        }
    
        bool visitLoop(Visit visit, TIntermLoop *loop)
        {
            visitControlFlow(visit, loop);
            return true;
        }
    
        bool visitSelection(Visit visit, TIntermSelection *selection)
        {
            visitControlFlow(visit, selection);
            return true;
        }
    
        bool visitUnary(Visit visit, TIntermUnary *node) override
        {
            if (visit == PreVisit)
            {
                switch (node->getOp())
                {
                  case EOpDFdx:
                  case EOpDFdy:
                    onGradient();
                  default:
                    break;
                }
            }
    
            return true;
        }
    
        bool visitAggregate(Visit visit, TIntermAggregate *node) override
        {
            if (visit == PreVisit)
            {
                if (node->getOp() == EOpFunctionCall)
                {
                    if (node->isUserDefined())
                    {
                        size_t calleeIndex = mDag.findIndex(node);
                        ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
    
                        if ((*mMetadataList)[calleeIndex].mUsesGradient) {
                            onGradient();
                        }
                    }
                    else
                    {
                        TString name = TFunction::unmangleName(node->getName());
    
                        if (name == "texture2D" ||
                            name == "texture2DProj" ||
                            name == "textureCube")
                        {
                            onGradient();
                        }
                    }
                }
            }
    
            return true;
        }
    
      private:
        MetadataList *mMetadataList;
        ASTMetadataHLSL *mMetadata;
        size_t mIndex;
        const CallDAG &mDag;
    
        // Contains a stack of the control flow nodes that are parents of the node being
        // currently visited. It is used to mark control flows using a gradient.
        std::vector<TIntermNode*> mParents;
    };
    
    // Traverses the AST of a function definition, assuming it has already been used to
    // traverse the callees of that function; computes the discontinuous loops and the if
    // statements that contain a discontinuous loop in their call graph.
    class PullComputeDiscontinuousLoops : public TIntermTraverser
    {
      public:
        PullComputeDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
            : TIntermTraverser(true, false, true),
              mMetadataList(metadataList),
              mMetadata(&(*metadataList)[index]),
              mIndex(index),
              mDag(dag)
        {
        }
    
        void traverse(TIntermAggregate *node)
        {
            node->traverse(this);
            ASSERT(mLoopsAndSwitches.empty());
            ASSERT(mIfs.empty());
        }
    
        // Called when a discontinuous loop or a call to a function with a discontinuous loop
        // in its call graph is found.
        void onDiscontinuousLoop()
        {
            mMetadata->mHasDiscontinuousLoopInCallGraph = true;
            // Mark the latest if as using a discontinuous loop.
            if (!mIfs.empty())
            {
                mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
            }
        }
    
        bool visitLoop(Visit visit, TIntermLoop *loop) override
        {
            if (visit == PreVisit)
            {
                mLoopsAndSwitches.push_back(loop);
            }
            else if (visit == PostVisit)
            {
                ASSERT(mLoopsAndSwitches.back() == loop);
                mLoopsAndSwitches.pop_back();
            }
    
            return true;
        }
    
        bool visitSelection(Visit visit, TIntermSelection *node) override
        {
            if (visit == PreVisit)
            {
                mIfs.push_back(node);
            }
            else if (visit == PostVisit)
            {
                ASSERT(mIfs.back() == node);
                mIfs.pop_back();
                // An if using a discontinuous loop means its parents ifs are also discontinuous.
                if (mMetadata->mIfsContainingDiscontinuousLoop.count(node) > 0 && !mIfs.empty())
                {
                    mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
                }
            }
    
            return true;
        }
    
        bool visitBranch(Visit visit, TIntermBranch *node) override
        {
            if (visit == PreVisit)
            {
                switch (node->getFlowOp())
                {
                  case EOpBreak:
                    {
                        ASSERT(!mLoopsAndSwitches.empty());
                        TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
                        if (loop != nullptr)
                        {
                            mMetadata->mDiscontinuousLoops.insert(loop);
                            onDiscontinuousLoop();
                        }
                    }
                    break;
                  case EOpContinue:
                    {
                        ASSERT(!mLoopsAndSwitches.empty());
                        TIntermLoop *loop = nullptr;
                        size_t i = mLoopsAndSwitches.size();
                        while (loop == nullptr && i > 0)
                        {
                            --i;
                            loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
                        }
                        ASSERT(loop != nullptr);
                        mMetadata->mDiscontinuousLoops.insert(loop);
                        onDiscontinuousLoop();
                    }
                    break;
                  case EOpKill:
                  case EOpReturn:
                    // A return or discard jumps out of all the enclosing loops
                    if (!mLoopsAndSwitches.empty())
                    {
                        for (TIntermNode* node : mLoopsAndSwitches)
                        {
                            TIntermLoop *loop = node->getAsLoopNode();
                            if (loop)
                            {
                                mMetadata->mDiscontinuousLoops.insert(loop);
                            }
                        }
                        onDiscontinuousLoop();
                    }
                    break;
                  default:
                    UNREACHABLE();
                }
            }
    
            return true;
        }
    
        bool visitAggregate(Visit visit, TIntermAggregate *node) override
        {
            if (visit == PreVisit && node->getOp() == EOpFunctionCall)
            {
                if (node->isUserDefined())
                {
                    size_t calleeIndex = mDag.findIndex(node);
                    ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
    
                    if ((*mMetadataList)[calleeIndex].mHasDiscontinuousLoopInCallGraph)
                    {
                        onDiscontinuousLoop();
                    }
                }
            }
    
            return true;
        }
    
        bool visitSwitch(Visit visit, TIntermSwitch *node) override
        {
            if (visit == PreVisit)
            {
                mLoopsAndSwitches.push_back(node);
            }
            else if (visit == PostVisit)
            {
                ASSERT(mLoopsAndSwitches.back() == node);
                mLoopsAndSwitches.pop_back();
            }
            return true;
        }
    
      private:
        MetadataList *mMetadataList;
        ASTMetadataHLSL *mMetadata;
        size_t mIndex;
        const CallDAG &mDag;
    
        std::vector<TIntermNode*> mLoopsAndSwitches;
        std::vector<TIntermSelection*> mIfs;
    };
    
    // Tags all the functions called in a discontinuous loop
    class PushDiscontinuousLoops : public TIntermTraverser
    {
      public:
        PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
            : TIntermTraverser(true, true, true),
              mMetadataList(metadataList),
              mMetadata(&(*metadataList)[index]),
              mIndex(index),
              mDag(dag),
              mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
        {
        }
    
        void traverse(TIntermAggregate *node)
        {
            node->traverse(this);
            ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
        }
    
        bool visitLoop(Visit visit, TIntermLoop *loop)
        {
            bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
    
            if (visit == PreVisit && isDiscontinuous)
            {
                mNestedDiscont++;
            }
            else if (visit == PostVisit && isDiscontinuous)
            {
                mNestedDiscont--;
            }
    
            return true;
        }
    
        bool visitAggregate(Visit visit, TIntermAggregate *node) override
        {
            switch (node->getOp())
            {
              case EOpFunctionCall:
                if (visit == PreVisit && node->isUserDefined() && mNestedDiscont > 0)
                {
                    size_t calleeIndex = mDag.findIndex(node);
                    ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
    
                    (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
                }
                break;
              default:
                break;
            }
            return true;
        }
    
      private:
        MetadataList *mMetadataList;
        ASTMetadataHLSL *mMetadata;
        size_t mIndex;
        const CallDAG &mDag;
    
        int mNestedDiscont;
    };
    
    }
    
    // Tells whether the selection was recorded among the control flows containing a
    // gradient.
    bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermSelection *node)
    {
        const auto recorded = mControlFlowsContainingGradient.find(node);
        return recorded != mControlFlowsContainingGradient.end();
    }
    
    // Tells whether the loop was recorded among the control flows containing a
    // gradient.
    bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
    {
        const auto recorded = mControlFlowsContainingGradient.find(node);
        return recorded != mControlFlowsContainingGradient.end();
    }
    
    // Tells whether the selection was recorded among the ifs containing a
    // discontinuous loop.
    bool ASTMetadataHLSL::hasDiscontinuousLoop(TIntermSelection *node)
    {
        const auto recorded = mIfsContainingDiscontinuousLoop.find(node);
        return recorded != mIfsContainingDiscontinuousLoop.end();
    }
    
    // Builds the metadata of every function of the call DAG, one entry per CallDAG
    // index. The root parameter is not read by this function.
    MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
    {
        const size_t functionCount = callDag.size();
        MetadataList metadataList(functionCount);

        // Gradient usage. For each function and each control flow statement we want
        // to know whether there is a gradient operation in its call graph (this is
        // what "using a gradient" means in this file).
        //
        // Logically there are three steps:
        //  1 - For each function, find out if its body uses a gradient, without
        // looking into the user-defined functions it calls.
        //  2 - For each function, find out if it uses a gradient in its call graph,
        // from the result of step 1 and the callees given by the CallDAG.
        //  3 - For each control flow statement of each function, find out if it uses
        // a gradient in the function's body or calls a user-defined function that
        // uses a gradient.
        //
        // The call graph being a DAG, the three steps are done at once per function,
        // processing callees before callers: step 1 doesn't depend on other functions
        // and steps 2 and 3 depend only on callees.
        for (size_t functionIndex = 0; functionIndex < functionCount; ++functionIndex)
        {
            PullGradient gradientPass(&metadataList, functionIndex, callDag);
            gradientPass.traverse(callDag.getRecordFromIndex(functionIndex).node);
        }

        // Discontinuous loops, and the functions called inside them. Gradient usage
        // is "pulled" from callees, whereas "being used in a discontinuous loop" is
        // "pushed" to callees. We also need to know which ifs contain a discontinuous
        // loop, which requires the same kind of call graph analysis as the gradients.

        // Pull pass: find the discontinuous loops (no specific order is needed for
        // that part) and the ifs and functions using a discontinuous loop.
        for (size_t functionIndex = 0; functionIndex < functionCount; ++functionIndex)
        {
            PullComputeDiscontinuousLoops discontinuityPass(&metadataList, functionIndex, callDag);
            discontinuityPass.traverse(callDag.getRecordFromIndex(functionIndex).node);
        }

        // Push pass: in reverse index order, propagate the information to callees,
        // either from a local discontinuous loop or from the caller being itself
        // called in a discontinuous loop.
        for (size_t remaining = functionCount; remaining > 0; --remaining)
        {
            const size_t functionIndex = remaining - 1;
            PushDiscontinuousLoops pushPass(&metadataList, functionIndex, callDag);
            pushPass.traverse(callDag.getRecordFromIndex(functionIndex).node);
        }

        // A "Lod0" version of a function has its gradient operations replaced by
        // non-gradient operations, which makes the D3D compiler happier with
        // discontinuous loops. It is needed only for functions that use a gradient
        // and are called in a discontinuous loop.
        for (auto &entry : metadataList)
        {
            const bool usesGradient = entry.mUsesGradient;
            const bool inDiscontinuousLoop = entry.mCalledInDiscontinuousLoop;
            entry.mNeedsLod0 = inDiscontinuousLoop && usesGradient;
        }

        return metadataList;
    }