Edit

kc3-lang/angle/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp

Branch :

  • Show log

    Commit

  • Author : Shahbaz Youssefi
    Date : 2021-08-05 10:41:59
    Hash : 210773db
    Message : Translator: Be more explicit about precisions GLSL ES requires that every symbol (variable, block member, function parameter and return value) is appropriately qualified with a precision, either individually or through the global precision specifier. Some tree transformations however produced symbols with EbpUndefined precision. In text GLSL output, these would produce unqualified symbols which was often incorrect. In this change, the transformations are made to produce explicit / more consistent precisions. The validation (that caught these issues) is not included in this change as there are still a few corner cases left to address. Bug: angleproject:4889 Bug: angleproject:6132 Change-Id: Icca8a0a5476f8646226e7243aa8f501f44acc164 Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/3075127 Reviewed-by: Tim Van Patten <timvp@google.com> Reviewed-by: Jamie Madill <jmadill@chromium.org> Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org>

  • src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp
  • //
    // Copyright 2002 The ANGLE Project Authors. All rights reserved.
    // Use of this source code is governed by a BSD-style license that can be
    // found in the LICENSE file.
    //
    // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of non-SSBO vectors and
    // matrices, replacing them with calls to functions that choose which component to return or write.
    // A second entry point applies the same rewrite only to dynamic indexing of swizzled vectors.
    // We don't need to consider dynamic indexing in SSBO since it can be used directly as part of the
    // offset of RWByteAddressBuffer.
    //
    
    #include "compiler/translator/tree_ops/RemoveDynamicIndexing.h"
    
    #include "compiler/translator/Compiler.h"
    #include "compiler/translator/Diagnostics.h"
    #include "compiler/translator/InfoSink.h"
    #include "compiler/translator/StaticType.h"
    #include "compiler/translator/SymbolTable.h"
    #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
    #include "compiler/translator/tree_util/IntermNode_util.h"
    #include "compiler/translator/tree_util/IntermTraverse.h"
    
    namespace sh
    {
    
    namespace
    {
    
    // Predicate deciding which EOpIndexIndirect nodes are rewritten into helper function calls.
    using DynamicIndexingNodeMatcher = std::function<bool(TIntermBinary *)>;
    
    // Type of the "index" parameter of every generated helper: an "in" highp int. Unsigned indices
    // are converted to int at the call site (see EnsureSignedInt).
    const TType *kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqParamIn, 1, 1>();
    
    // Parameter names used by the generated dyn_index_* / dyn_index_write_* helper functions.
    constexpr const ImmutableString kBaseName("base");
    constexpr const ImmutableString kIndexName("index");
    constexpr const ImmutableString kValueName("value");
    
    // Builds the name of the indexing helper for |type|, e.g. "dyn_index_vec3",
    // "dyn_index_write_ivec2" or "dyn_index_mat3x2" (columns first, then rows). Vectors are expected
    // to have an int, bool, uint or float basic type; anything else is a programming error.
    std::string GetIndexFunctionName(const TType &type, bool write)
    {
        TInfoSinkBase nameSink;
        nameSink << (write ? "dyn_index_write_" : "dyn_index_");
    
        if (type.isMatrix())
        {
            nameSink << "mat" << type.getCols() << "x" << type.getRows();
            return nameSink.str();
        }
    
        // Empty default keeps the unreachable path appending nothing but the size.
        const char *vectorPrefix = "";
        switch (type.getBasicType())
        {
            case EbtInt:
                vectorPrefix = "ivec";
                break;
            case EbtBool:
                vectorPrefix = "bvec";
                break;
            case EbtUInt:
                vectorPrefix = "uvec";
                break;
            case EbtFloat:
                vectorPrefix = "vec";
                break;
            default:
                UNREACHABLE();
                break;
        }
        nameSink << vectorPrefix << type.getNominalSize();
        return nameSink.str();
    }
    
    // Returns a new highp int constant node holding |i| (used for switch case labels and the
    // "index < 0" comparison in the generated helpers).
    TIntermConstantUnion *CreateIntConstantNode(int i)
    {
        TConstantUnion *value = new TConstantUnion();
        value->setIConst(i);
        const TType highpInt(EbtInt, EbpHigh);
        return new TIntermConstantUnion(value, highpInt);
    }
    
    // Returns |node| unchanged when it is already an int; otherwise wraps it in an int(...)
    // constructor so it can be passed to the int-only indexing helpers.
    TIntermTyped *EnsureSignedInt(TIntermTyped *node)
    {
        if (node->getBasicType() != EbtInt)
        {
            TIntermSequence constructorArgs;
            constructorArgs.push_back(node);
            return TIntermAggregate::CreateConstructor(TType(EbtInt), &constructorArgs);
        }
        return node;
    }
    
    // Returns a newly allocated type describing one element of |indexedType|: its column type for a
    // matrix, or its component type for a vector. It is used as the return type of read helpers and
    // as the type of the "value" parameter of write helpers.
    TType *GetFieldType(const TType &indexedType)
    {
        TType *fieldType = new TType(indexedType);
        if (indexedType.isMatrix())
        {
            fieldType->toMatrixColumnType();
        }
        else
        {
            ASSERT(indexedType.isVector());
            fieldType->toComponentType();
        }
        // Conservatively use highp, matching the highp "base" parameter produced by GetBaseType().
        // A single helper is shared by all precision variants of a type, so inheriting the precision
        // of whichever variant was indexed first could make a highp value come back (or be written
        // back) as mediump. Indexed expressions may also lack a precision entirely, e.g. vec3(0)[i].
        fieldType->setPrecision(EbpHigh);
        return fieldType;
    }
    
    // Returns a newly allocated parameter type for the "base" argument of an indexing helper: a copy
    // of |type| promoted to highp, qualified "inout" for write helpers and "in" for read helpers.
    const TType *GetBaseType(const TType &type, bool write)
    {
        TType *paramType = new TType(type);
        // Always highp, regardless of the indexed value's precision: otherwise a mediump helper could
        // end up being used for a highp value when both precisions are indexed in one shader. HLSL
        // ignores precision, but this transformation may also feed other backends.
        paramType->setPrecision(EbpHigh);
        paramType->setQualifier(write ? EvqParamInOut : EvqParamIn);
        return paramType;
    }
    
    // Generates the definition of a read or write helper for one field of a vector/matrix type.
    // A switch handles every in-range index. Out-of-range indices are clamped: negative indices access
    // the first field, and indices past the end access the last one. This is consistent with how
    // ANGLE handles out-of-range indices in other places.
    // Note that indices can be either int or uint. We create only int versions of the functions,
    // and convert uint indices to int at the call site.
    //
    // |type|        the indexed vector/matrix type; must not be an array.
    // |write|       selects the write variant (assigns "value") instead of the read variant.
    // |func|        the already-created helper function. Its parameters (base, index and, for write
    //               helpers, value) are referenced by the generated body.
    // |symbolTable| currently unused.
    //
    // read function example:
    // float dyn_index_vec2(in vec2 base, in int index)
    // {
    //    switch(index)
    //    {
    //      case (0):
    //        return base[0];
    //      case (1):
    //        return base[1];
    //      default:
    //        break;
    //    }
    //    if (index < 0)
    //      return base[0];
    //    return base[1];
    // }
    // write function example:
    // void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
    // {
    //    switch(index)
    //    {
    //      case (0):
    //        base[0] = value;
    //        return;
    //      case (1):
    //        base[1] = value;
    //        return;
    //      default:
    //        break;
    //    }
    //    if (index < 0)
    //    {
    //      base[0] = value;
    //      return;
    //    }
    //    base[1] = value;
    // }
    // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
    TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
                                                          bool write,
                                                          const TFunction &func,
                                                          TSymbolTable *symbolTable)
    {
        ASSERT(!type.isArray());
    
        // A matrix is indexed by column; a vector by component.
        int numCases = 0;
        if (type.isMatrix())
        {
            numCases = type.getCols();
        }
        else
        {
            numCases = type.getNominalSize();
        }
    
        TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
    
        TIntermSymbol *baseParam  = new TIntermSymbol(func.getParam(0));
        TIntermSymbol *indexParam = new TIntermSymbol(func.getParam(1));
        TIntermSymbol *valueParam = nullptr;
        if (write)
        {
            valueParam = new TIntermSymbol(func.getParam(2));
        }
    
        // One case per in-range index; every case returns, so there is no fall-through.
        TIntermBlock *statementList = new TIntermBlock();
        for (int i = 0; i < numCases; ++i)
        {
            TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
            statementList->getSequence()->push_back(caseNode);
    
            TIntermBinary *indexNode =
                new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i));
            if (write)
            {
                TIntermBinary *assignNode =
                    new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy());
                statementList->getSequence()->push_back(assignNode);
                TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
                statementList->getSequence()->push_back(returnNode);
            }
            else
            {
                TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
                statementList->getSequence()->push_back(returnNode);
            }
        }
    
        // Default case: leave the switch and fall through to the clamping code below.
        TIntermCase *defaultNode = new TIntermCase(nullptr);
        statementList->getSequence()->push_back(defaultNode);
        TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
        statementList->getSequence()->push_back(breakNode);
    
        TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList);
    
        TIntermBlock *bodyNode = new TIntermBlock();
        bodyNode->getSequence()->push_back(switchNode);
    
        TIntermBinary *cond =
            new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
    
        // Two blocks: one accesses (either reads or writes) the first element and returns,
        // the other accesses the last element.
        TIntermBlock *useFirstBlock = new TIntermBlock();
        TIntermBlock *useLastBlock  = new TIntermBlock();
        TIntermBinary *indexFirstNode =
            new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0));
        TIntermBinary *indexLastNode =
            new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1));
        if (write)
        {
            TIntermBinary *assignFirstNode =
                new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy());
            useFirstBlock->getSequence()->push_back(assignFirstNode);
            TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
            useFirstBlock->getSequence()->push_back(returnNode);
    
            TIntermBinary *assignLastNode =
                new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy());
            useLastBlock->getSequence()->push_back(assignLastNode);
        }
        else
        {
            TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
            useFirstBlock->getSequence()->push_back(returnFirstNode);
    
            TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
            useLastBlock->getSequence()->push_back(returnLastNode);
        }
        // "if (index < 0) { first } last" -- the last-element block is appended unconditionally
        // rather than as an else branch.
        TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
        bodyNode->getSequence()->push_back(ifNode);
        bodyNode->getSequence()->push_back(useLastBlock);
    
        TIntermFunctionDefinition *indexingFunction =
            new TIntermFunctionDefinition(prototypeNode, bodyNode);
        return indexingFunction;
    }
    
    // Traverser that rewrites dynamically indexed vectors/matrices (as selected by a matcher) into
    // calls to generated dyn_index_* / dyn_index_write_* helpers. Writes require inserting
    // statements around the indexing expression, so at most one such insertion is done per pass and
    // the caller re-runs the traversal until no insertion happens.
    class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
    {
      public:
        RemoveDynamicIndexingTraverser(DynamicIndexingNodeMatcher &&matcher,
                                       TSymbolTable *symbolTable,
                                       PerformanceDiagnostics *perfDiagnostics);
    
        // Handles EOpIndexIndirect nodes; returns false to stop descending once a statement
        // insertion has been made in the current pass.
        bool visitBinary(Visit visit, TIntermBinary *node) override;
    
        // Prepends the definitions of every helper function created during traversal to |root|,
        // which must be a block.
        void insertHelperDefinitions(TIntermNode *root);
    
        // Resets per-pass state; call before each traversal.
        void nextIteration();
    
        // True if the last pass inserted statements, meaning another pass is needed.
        bool usedTreeInsertion() const { return mUsedTreeInsertion; }
    
      protected:
        // Maps of types that are indexed to the indexing functions used for them. Note that these
        // can not store multiple variants of the same type with different precisions - only one
        // precision gets stored.
        std::map<TType, TFunction *> mIndexedVecAndMatrixTypes;
        std::map<TType, TFunction *> mWrittenVecAndMatrixTypes;
    
        // Set once a statement has been inserted in the current pass.
        bool mUsedTreeInsertion;
    
        // When true, the traverser will remove side effects from any indexing expression.
        // This is done so that in code like
        //   V[j++][i]++.
        // where V is an array of vectors, j++ will only be evaluated once.
        bool mRemoveIndexSideEffectsInSubtree;
    
        // Selects which indirect indexing nodes are rewritten.
        DynamicIndexingNodeMatcher mMatcher;
        // Optional sink for performance warnings; may be null.
        PerformanceDiagnostics *mPerfDiagnostics;
    };
    
    // |matcher| selects the nodes to rewrite and is moved into the traverser. |perfDiagnostics| may
    // be null. The (true, false, false) base-class arguments are presumably the pre/in/post-visit
    // flags, i.e. pre-order visiting only -- confirm against IntermTraverse.h.
    RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
        DynamicIndexingNodeMatcher &&matcher,
        TSymbolTable *symbolTable,
        PerformanceDiagnostics *perfDiagnostics)
        : TLValueTrackingTraverser(true, false, false, symbolTable),
          mUsedTreeInsertion(false),
          mRemoveIndexSideEffectsInSubtree(false),
          // |matcher| is a named rvalue reference: move it instead of copying the std::function.
          mMatcher(std::move(matcher)),
          mPerfDiagnostics(perfDiagnostics)
    {}
    
    // Prepends to the root block the definitions of all read helpers followed by all write helpers
    // that were registered during traversal.
    void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
    {
        TIntermBlock *rootBlock = root->getAsBlock();
        ASSERT(rootBlock != nullptr);
    
        TIntermSequence helperDefinitions;
        auto appendDefinitions = [&](const std::map<TType, TFunction *> &helpers, bool write) {
            for (const auto &typeAndHelper : helpers)
            {
                helperDefinitions.push_back(GetIndexFunctionDefinition(
                    typeAndHelper.first, write, *typeAndHelper.second, mSymbolTable));
            }
        };
        appendDefinitions(mIndexedVecAndMatrixTypes, false);
        appendDefinitions(mWrittenVecAndMatrixTypes, true);
    
        rootBlock->insertChildNodes(0, helperDefinitions);
    }
    
    // Create a call to dyn_index_*() based on an indirect indexing op node: the call takes the
    // node's left operand itself (not a copy) and |index|, and inherits the node's source location.
    TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
                                              TIntermTyped *index,
                                              TFunction *indexingFunction)
    {
        ASSERT(node->getOp() == EOpIndexIndirect);
        TIntermTyped *indexedOperand = node->getLeft();
    
        TIntermSequence callArgs;
        callArgs.push_back(indexedOperand);
        callArgs.push_back(index);
    
        TIntermAggregate *call = TIntermAggregate::CreateFunctionCall(*indexingFunction, &callArgs);
        call->setLine(node->getLine());
        return call;
    }
    
    // Create a call to dyn_index_write_*(base, index, value) for an indirect indexing op node. The
    // base is a deep copy of the node's left operand, because the original operand is also used by the
    // read call; |index| and |writtenValue| are referenced through new symbol nodes.
    TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
                                                     TVariable *index,
                                                     TVariable *writtenValue,
                                                     TFunction *indexedWriteFunction)
    {
        ASSERT(node->getOp() == EOpIndexIndirect);
        TIntermTyped *baseCopy = node->getLeft()->deepCopy();
    
        TIntermSequence callArgs;
        callArgs.push_back(baseCopy);
        callArgs.push_back(CreateTempSymbolNode(index));
        callArgs.push_back(CreateTempSymbolNode(writtenValue));
    
        TIntermAggregate *writeCall =
            TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, &callArgs);
        writeCall->setLine(node->getLine());
        return writeCall;
    }
    
    bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
    {
        if (mUsedTreeInsertion)
            return false;
    
        if (node->getOp() == EOpIndexIndirect)
        {
            if (mRemoveIndexSideEffectsInSubtree)
            {
                ASSERT(node->getRight()->hasSideEffects());
                // In case we're just removing index side effects, convert
                //   v_expr[index_expr]
                // to this:
                //   int s0 = index_expr; v_expr[s0];
                // Now v_expr[s0] can be safely executed several times without unintended side effects.
                TIntermDeclaration *indexVariableDeclaration = nullptr;
                TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(),
                                                               EvqTemporary, &indexVariableDeclaration);
                insertStatementInParentBlock(indexVariableDeclaration);
                mUsedTreeInsertion = true;
    
                // Replace the index with the temp variable
                TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable);
                queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
            }
            else if (mMatcher(node))
            {
                if (mPerfDiagnostics)
                {
                    mPerfDiagnostics->warning(node->getLine(),
                                              "Performance: dynamic indexing of vectors and "
                                              "matrices is emulated and can be slow.",
                                              "[]");
                }
                bool write = isLValueRequiredHere();
    
    #if defined(ANGLE_ENABLE_ASSERTS)
                // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
                // implemented checks in this traverser.
                IntermNodePatternMatcher matcher(
                    IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
                ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
    #endif
    
                const TType &type = node->getLeft()->getType();
                ImmutableString indexingFunctionName(GetIndexFunctionName(type, false));
                TFunction *indexingFunction = nullptr;
                if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
                {
                    indexingFunction =
                        new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal,
                                      GetFieldType(type), true);
                    indexingFunction->addParameter(new TVariable(
                        mSymbolTable, kBaseName, GetBaseType(type, false), SymbolType::AngleInternal));
                    indexingFunction->addParameter(
                        new TVariable(mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
                    mIndexedVecAndMatrixTypes[type] = indexingFunction;
                }
                else
                {
                    indexingFunction = mIndexedVecAndMatrixTypes[type];
                }
    
                if (write)
                {
                    // Convert:
                    //   v_expr[index_expr]++;
                    // to this:
                    //   int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
                    //   dyn_index_write(v_expr, s0, s1);
                    // This works even if index_expr has some side effects.
                    if (node->getLeft()->hasSideEffects())
                    {
                        // If v_expr has side effects, those need to be removed before proceeding.
                        // Otherwise the side effects of v_expr would be evaluated twice.
                        // The only case where an l-value can have side effects is when it is
                        // indexing. For example, it can be V[j++] where V is an array of vectors.
                        mRemoveIndexSideEffectsInSubtree = true;
                        return true;
                    }
    
                    TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
                    if (leftBinary != nullptr && mMatcher(leftBinary))
                    {
                        // This is a case like:
                        // mat2 m;
                        // m[a][b]++;
                        // Process the child node m[a] first.
                        return true;
                    }
    
                    // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
                    // only writes it and doesn't need the previous value. http://anglebug.com/1116
    
                    TFunction *indexedWriteFunction = nullptr;
                    if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
                    {
                        ImmutableString functionName(
                            GetIndexFunctionName(node->getLeft()->getType(), true));
                        indexedWriteFunction =
                            new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal,
                                          StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
                        indexedWriteFunction->addParameter(new TVariable(mSymbolTable, kBaseName,
                                                                         GetBaseType(type, true),
                                                                         SymbolType::AngleInternal));
                        indexedWriteFunction->addParameter(new TVariable(
                            mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
                        TType *valueType = GetFieldType(type);
                        valueType->setQualifier(EvqParamIn);
                        indexedWriteFunction->addParameter(new TVariable(
                            mSymbolTable, kValueName, static_cast<const TType *>(valueType),
                            SymbolType::AngleInternal));
                        mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
                    }
                    else
                    {
                        indexedWriteFunction = mWrittenVecAndMatrixTypes[type];
                    }
    
                    TIntermSequence insertionsBefore;
                    TIntermSequence insertionsAfter;
    
                    // Store the index in a temporary signed int variable.
                    // s0 = index_expr;
                    TIntermTyped *indexInitializer               = EnsureSignedInt(node->getRight());
                    TIntermDeclaration *indexVariableDeclaration = nullptr;
                    TVariable *indexVariable                     = DeclareTempVariable(
                        mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration);
                    insertionsBefore.push_back(indexVariableDeclaration);
    
                    // s1 = dyn_index(v_expr, s0);
                    TIntermAggregate *indexingCall = CreateIndexFunctionCall(
                        node, CreateTempSymbolNode(indexVariable), indexingFunction);
                    TIntermDeclaration *fieldVariableDeclaration = nullptr;
                    TVariable *fieldVariable                     = DeclareTempVariable(
                        mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration);
                    insertionsBefore.push_back(fieldVariableDeclaration);
    
                    // dyn_index_write(v_expr, s0, s1);
                    TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
                        node, indexVariable, fieldVariable, indexedWriteFunction);
                    insertionsAfter.push_back(indexedWriteCall);
                    insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
    
                    // replace the node with s1
                    queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED);
                    mUsedTreeInsertion = true;
                }
                else
                {
                    // The indexed value is not being written, so we can simply convert
                    //   v_expr[index_expr]
                    // into
                    //   dyn_index(v_expr, index_expr)
                    // If the index_expr is unsigned, we'll convert it to signed.
                    ASSERT(!mRemoveIndexSideEffectsInSubtree);
                    TIntermAggregate *indexingCall = CreateIndexFunctionCall(
                        node, EnsureSignedInt(node->getRight()), indexingFunction);
                    queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
                }
            }
        }
        return !mUsedTreeInsertion;
    }
    
    // Clears the per-pass flags so the next traversal starts fresh; the registered helper maps are
    // kept across passes.
    void RemoveDynamicIndexingTraverser::nextIteration()
    {
        mRemoveIndexSideEffectsInSubtree = false;
        mUsedTreeInsertion               = false;
    }
    
    // Repeatedly traverses |root| with a RemoveDynamicIndexingTraverser driven by |matcher|, applying
    // queued replacements after every pass, until a pass performs no statement insertion. Then it adds
    // the generated helper definitions and validates the AST. Returns false if a tree update or the
    // final validation fails.
    bool RemoveDynamicIndexingIf(DynamicIndexingNodeMatcher &&matcher,
                                 TCompiler *compiler,
                                 TIntermNode *root,
                                 TSymbolTable *symbolTable,
                                 PerformanceDiagnostics *perfDiagnostics)
    {
        RemoveDynamicIndexingTraverser traverser(std::move(matcher), symbolTable, perfDiagnostics);
    
        bool needAnotherPass = true;
        while (needAnotherPass)
        {
            traverser.nextIteration();
            root->traverse(&traverser);
            if (!traverser.updateTree(compiler, root))
            {
                return false;
            }
            needAnotherPass = traverser.usedTreeInsertion();
        }
    
        // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
        // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
        // function call nodes with no corresponding definition nodes. This needs special handling in
        // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
        // superficial reading of the code.
        traverser.insertHelperDefinitions(root);
        return compiler->validateAST(root);
    }
    
    }  // namespace
    
    // Replaces dynamic indexing of vectors and matrices that are not in SSBOs with calls to generated
    // dyn_index_* helpers. Returns false if updating or validating the tree fails.
    ANGLE_NO_DISCARD bool RemoveDynamicIndexingOfNonSSBOVectorOrMatrix(
        TCompiler *compiler,
        TIntermNode *root,
        TSymbolTable *symbolTable,
        PerformanceDiagnostics *perfDiagnostics)
    {
        return RemoveDynamicIndexingIf(
            [](TIntermBinary *node) {
                return IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(node);
            },
            compiler, root, symbolTable, perfDiagnostics);
    }
    
    // Replaces dynamic indexing of swizzled vectors only with calls to generated dyn_index_* helpers.
    // Returns false if updating or validating the tree fails.
    ANGLE_NO_DISCARD bool RemoveDynamicIndexingOfSwizzledVector(TCompiler *compiler,
                                                                TIntermNode *root,
                                                                TSymbolTable *symbolTable,
                                                                PerformanceDiagnostics *perfDiagnostics)
    {
        return RemoveDynamicIndexingIf(
            [](TIntermBinary *node) {
                return IntermNodePatternMatcher::IsDynamicIndexingOfSwizzledVector(node);
            },
            compiler, root, symbolTable, perfDiagnostics);
    }
    
    }  // namespace sh