How to mitigate branch mispredictions due to virtual function calls in C++?

Question

Background: I'm working on a compiler for a complex language, so there are many different types of AST nodes. I need a way to traverse the AST, and currently I use virtual functions to create iterators for the different AST node types. For example, operator++ is implemented as:

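namespace ast {

class Iterator {
public:
  Iterator operator++() {
    if (++curr_it_)                 // still inside the current node
      return curr_it_;

    auto next_node = /* get the neighbor */;
    if (!next_node)                 // no neighbor: traversal is finished
      return nullptr;

    // Virtual call: the target depends on the dynamic type of next_node.
    return (curr_it_ = next_node->create_iterator());
  }

private:
  InternalIterator curr_it_;
};

// One override per AST node type.
InternalIterator BinaryExpr::create_iterator() { /* ... */ }
InternalIterator ArrayLiteral::create_iterator() { /* ... */ }
InternalIterator FunctionCall::create_iterator() { /* ... */ }

}  // namespace ast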

Problem: profiling shows that the call to create_iterator causes a 50% branch misprediction rate. What can I do to mitigate it?

P.S. I use valgrind --tool=cachegrind --branch-sim=yes to profile my compiler, and most of the mispredictions come from indirect function calls.

Answer 1

Score: 4

If you're trying to reduce branch mispredictions, there are only a few ways to do it:

  • Reduce conditional branches by using conditional instructions like CMOV, or use case-specific instruction sequences that don't involve branching. For instance, if your various iterators had the same behavior but different data strides, you could remove the virtual call and look up the stride in a lookup table (LUT) or similar. Probably not possible here.

  • Rearrange the data to help the branch predictor. Almost certainly not possible here.

  • Negate the branch condition if a branch is being mispredicted when it isn't in the branch cache. (This is one of the main things profile-guided optimization, PGO, would do.) Of course, this doesn't apply if your misprediction rate is 50%.

  • Rearrange the code so that the branch predictor works better. This is extremely fiddly and processor-specific (e.g. you'd be optimizing for a particular model of AMD processor) and unlikely to yield much benefit with modern branch predictors. The one exception is that short loops (but not too short) are easier to predict than long ones, but in this case you don't have much control over that.
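The lookup-table idea from the first bullet can be sketched as follows. This is a toy under a strong assumption: every node kind is traversed the same way and differs only in how many child slots it uses. The per-kind difference then lives in a table indexed by the kind tag, so creating and advancing the iterator involves a load instead of an indirect call. `Kind`, `Node`, `kChildSlots`, `ChildIterator` and `count_children` are hypothetical names, not part of the question's code, and the real iterators probably differ in more than a stride.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical node kinds; a real compiler has many more.
enum class Kind : std::uint8_t { BinaryExpr, ArrayLiteral, FunctionCall, Count };

// Toy layout (an assumption): every kind keeps its children in the same
// fixed-size array and differs only in how many slots it uses.
struct Node {
    Kind kind;
    std::array<Node*, 4> children{};
};

// Per-kind data instead of per-kind code: the "stride" lives in a table.
constexpr std::array<std::size_t, static_cast<std::size_t>(Kind::Count)> kChildSlots = {
    2,  // BinaryExpr: lhs, rhs
    4,  // ArrayLiteral: up to 4 elements in this toy layout
    3,  // FunctionCall: callee + up to 2 arguments in this toy layout
};

// Child iterator with no virtual call: the end position is a table lookup
// indexed by the kind tag, i.e. a load rather than an indirect branch.
class ChildIterator {
public:
    explicit ChildIterator(const Node& n)
        : node_(&n), end_(kChildSlots[static_cast<std::size_t>(n.kind)]) {}

    bool done() const { return pos_ >= end_; }
    Node* get() const { return node_->children[pos_]; }
    void next() { ++pos_; }

private:
    const Node* node_;
    std::size_t pos_ = 0;
    std::size_t end_;
};

// Counts the non-null children of a node using the table-driven iterator.
std::size_t count_children(const Node& n) {
    std::size_t count = 0;
    for (ChildIterator it(n); !it.done(); it.next())
        count += (it.get() != nullptr);  // adds 0 or 1 without an if-statement
    return count;
}
```

The loop's exit test is still a branch, but it is a regular, loop-shaped one that predictors handle well; the unpredictable part, picking behavior per node kind, has become data.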

That branch is predicted poorly because branching is what that code does. At a high level, branch prediction is a way to dynamically optimize the execution of low-entropy traces; put differently, it's a way for the CPU to notice and exploit patterns in which branches you take. If there are no patterns to recognize, branch prediction can't help you.
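To make "low-entropy trace" concrete: the indirect call in create_iterator is hard to predict because its sequence of targets follows the essentially random sequence of node kinds. In the rare case where rearranging the data is possible, for example a pass whose per-node work is independent of order, grouping nodes by kind turns that sequence into long runs of the same target. The sketch below is illustrative only; `Expr`, its subclasses, and `process_grouped` are hypothetical, and a real tree traversal usually cannot be reordered this way. Any gain should be checked with the same cachegrind setup.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical node hierarchy that keeps a virtual call, as in the question.
struct Expr {
    explicit Expr(int k) : kind(k) {}
    virtual ~Expr() = default;
    virtual int cost() const = 0;  // the indirect call we want to be predictable
    const int kind;                // plain data, so grouping needs no virtual call
};

struct BinaryExpr : Expr {
    BinaryExpr() : Expr(0) {}
    int cost() const override { return 3; }
};

struct FunctionCall : Expr {
    FunctionCall() : Expr(1) {}
    int cost() const override { return 5; }
};

// Sorting by kind turns an interleaved target sequence (A B A A B ...) into
// runs (A A A ... B B ...), which indirect-branch predictors usually handle
// much better. Only valid when per-node work is order-independent, and the
// sort itself is not free. Takes the vector by value so the caller's order
// is left untouched.
int process_grouped(std::vector<const Expr*> nodes) {
    std::stable_sort(nodes.begin(), nodes.end(),
                     [](const Expr* a, const Expr* b) { return a->kind < b->kind; });
    int total = 0;
    for (const Expr* e : nodes) total += e->cost();  // consecutive calls mostly share a target
    return total;
}
```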

Answer 2

Score: 1

I'd remove the virtual method calls completely and use a switch instead.

InternalIterator create_iterator(Node* node) {
    NodeId id = node->id;

    // Breaks the build whenever NODE_COUNT changes, as a reminder to update this switch.
    static_assert(NODE_COUNT == 10, "You forgot to add or remove a new case here");
    switch (id) {
        case BINARY_EXPR:   return create_binary_expr_iterator((BinaryExpr*)node);
        case ARRAY_LITERAL: return create_array_literal_iterator((ArrayLiteral*)node);
        case FUNCTION_CALL: return create_function_call_iterator((FunctionCall*)node);
        // The rest ....
    }
    return InternalIterator{};  // not reached if every NodeId has a case
}

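It's quite straightforward and could well improve performance, since there is no virtual dispatch and the compiler can inline each per-node function; the switch will probably be compiled to a jump table. Keep in mind that a jump through a table is still an indirect branch, so whether the misprediction rate itself goes down is something to measure.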
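A self-contained sketch of the tag-plus-switch idea, assuming each node carries a kind tag instead of a vtable pointer. `NodeKind`, `Node`, `ChildRange` and the per-kind functions are hypothetical stand-ins for the answer's NodeId and create_iterator; in particular, skipping the callee of a FunctionCall is an arbitrary choice made for illustration. Every case calls a concrete, non-virtual function that the compiler is free to inline.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tag; replaces the vtable as the source of per-kind behavior.
enum class NodeKind { BinaryExpr, ArrayLiteral, FunctionCall };

struct Node {
    NodeKind kind;
    std::vector<Node*> children;  // BinaryExpr: {lhs, rhs}; FunctionCall: {callee, args...}
};

// Stand-in for the answer's InternalIterator: a range over child pointers.
struct ChildRange {
    Node* const* first;
    Node* const* last;
};

// Per-kind "create iterator" functions: ordinary, non-virtual, inlinable.
// They assume well-formed nodes (two operands, a callee always present).
inline ChildRange binary_expr_range(const Node& n) {
    return {n.children.data(), n.children.data() + 2};
}
inline ChildRange array_literal_range(const Node& n) {
    return {n.children.data(), n.children.data() + n.children.size()};
}
inline ChildRange function_call_range(const Node& n) {  // arguments only, skips the callee
    return {n.children.data() + 1, n.children.data() + n.children.size()};
}

ChildRange create_iterator(const Node& n) {
    switch (n.kind) {
        case NodeKind::BinaryExpr:   return binary_expr_range(n);
        case NodeKind::ArrayLiteral: return array_literal_range(n);
        case NodeKind::FunctionCall: return function_call_range(n);
    }
    return {nullptr, nullptr};  // not reached when every NodeKind is handled
}

// Counts this node plus everything reachable through create_iterator().
std::size_t count_nodes(const Node& n) {
    std::size_t total = 1;
    ChildRange r = create_iterator(n);
    for (Node* const* p = r.first; p != r.last; ++p) total += count_nodes(**p);
    return total;
}
```

Compared with the virtual version, the dispatch decision is a single switch the compiler can see in full, which is what enables inlining and, with enough cases, a jump table.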

However, you can also traverse the AST using the visitor pattern instead of creating abstract iterators. The visitor pattern can be combined with CRTP to get static dispatch, which removes some branching, and it's quite convenient to work with.

Here's an example of how it could look:

template <class Visitor, class T = void*>
class AstVisitor {
    AstTree& tree;
public:
    explicit AstVisitor(AstTree& tree) : tree(tree) {}

    T begin() { return visit(tree.start); }

protected:
    T visit_ident(Node* node) {
        return static_cast<Visitor*>(this)->visit_ident(&node->ident);
    }

    T visit_lit(Node* node) {
        return static_cast<Visitor*>(this)->visit_lit(&node->lit);
    }

    T visit_type(Node* node) {
        return static_cast<Visitor*>(this)->visit_type(&node->type);
    }

    T visit_bin_op(Node* node) {
        return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
    }

    T visit_var_decl(Node* node) {
        return static_cast<Visitor*>(this)->visit_var_decl(&node->var_decl);
    }

    T visit_if_expr(Node* node) {
        return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
    }

    T visit_param_decl(Node* node) {
        return static_cast<Visitor*>(this)->visit_param_decl(&node->param_decl);
    }

    T visit_fn_decl(Node* node) {
        return static_cast<Visitor*>(this)->visit_fn_decl(&node->fn_decl);
    }

    T visit_block(Node* node) {
        return static_cast<Visitor*>(this)->visit_block(&node->block);
    }

    // Dispatches only the node kinds that can appear in expression position.
    T visit_expression(size_t i) {
        Node* node = node_at(i);
        auto  id   = node->id;
        switch (id) {
            case Node::IDENT:     return static_cast<Visitor*>(this)->visit_ident(&node->ident);
            case Node::LIT:       return static_cast<Visitor*>(this)->visit_lit(&node->lit);
            case Node::TYPE:      return static_cast<Visitor*>(this)->visit_type(&node->type);
            case Node::BIN_OP:    return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
            case Node::IF_EXPR:   return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
            default:  return T{};  // T{} works for any result type, not just pointers
        }
    }

    // The single switch on the node kind; each case is a static call into Visitor.
    T visit(size_t i) {
        Node* node = node_at(i);
        auto  id   = node->id;
        switch (id) {
            case Node::IDENT:     return static_cast<Visitor*>(this)->visit_ident(&node->ident);
            case Node::LIT:       return static_cast<Visitor*>(this)->visit_lit(&node->lit);
            case Node::TYPE:      return static_cast<Visitor*>(this)->visit_type(&node->type);
            case Node::BIN_OP:    return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
            case Node::VAR_DECL:  return static_cast<Visitor*>(this)->visit_var_decl(&node->var_decl);
            case Node::IF_EXPR:   return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
            case Node::PARAM_DECL: return static_cast<Visitor*>(this)->visit_param_decl(&node->param_decl);
            case Node::FN_DECL:   return static_cast<Visitor*>(this)->visit_fn_decl(&node->fn_decl);
            case Node::BLOCK:     return static_cast<Visitor*>(this)->visit_block(&node->block);
            case Node::MODULE:    return static_cast<Visitor*>(this)->visit_module(&node->module);
            default: return T{};
        }
    }

private:
    Node* node_at(size_t i) { return &tree.nodes.at(i); }
};

Implementation:

// name_at, view_at, Str, SV_FMT and SV_ARG are helpers from the author's codebase (not shown).
class AstPrinter : public AstVisitor<AstPrinter> {
public:
    int indentation = 0;

    explicit AstPrinter(AstTree& tree) : AstVisitor(tree) {}

    void indent() {
        printf("%.*s", indentation*4, "                                                                              ");
    }

    void* visit_ident(Node::Ident* node) {
        auto name = name_at(node->name);
        printf("Ident {" SV_FMT "}", SV_ARG(name));
        return nullptr;
    }

    void* visit_lit(Node::Lit* node) {
        auto print_value = [](Node::Lit* node) {
            switch (node->type) {
                case Type::VOID:    return printf("void");
                case Type::INT:     return printf("%lld", (long long)node->value.int64);
                case Type::REAL:    return printf("%f",   node->value.float64);
                case Type::NUMBER_OF_TYPE_TYPES: break;
            }
            return printf("<invalid>");  // also covers out-of-range values
        };

        auto type = name_at(node->type);
        printf("Lit { type=" SV_FMT ", value=", SV_ARG(type));
        print_value(node);
        printf("}");
        return nullptr;
    }

    void* visit_type(Node::Type* node) {
        auto type = name_at(node->type);
        printf("Type {" SV_FMT " }", SV_ARG(type));
        return nullptr;
    }

    void* visit_bin_op(Node::BinOp* node) {
        printf("BinOp { left=");
        visit(node->left);
        printf(", operation=%d, right=", node->op);
        visit(node->right);
        printf(" }");
        return nullptr;
    }

    void* visit_var_decl(Node::VarDecl* node) {
        auto name = name_at(node->name);
        // Assuming Str takes a length without the terminating '\0', hence the - 1.
        auto type = (node->type != INVALID_TYPE) ? name_at(node->type) : Str("<none>", sizeof("<none>") - 1);
        printf("VarDecl { name=" SV_FMT ", type=" SV_FMT ", expr=", SV_ARG(name), SV_ARG(type));
        visit_expression(node->expr);
        printf("}");
        return nullptr;
    }

    void* visit_if_expr(Node::IfExpr* node) {
        printf("IfExpr { condition=");
        visit(node->cond);
        printf(", left=");
        visit(node->left);
        printf(", right=");
        visit(node->right);
        printf("}");
        return nullptr;
    }

    void* visit_param_decl(Node::ParamDecl* node) {
        auto name = name_at(node->name);
        auto type = (node->type != INVALID_TYPE) ? name_at(node->type) : Str("<none>", sizeof("<none>") - 1);
        printf("ParamDecl { name=" SV_FMT ", type=" SV_FMT ", expr=", SV_ARG(name), SV_ARG(type));
        if (node->expr != NO_EXPR) { visit_expression(node->expr); } else { printf("<none>"); }
        printf("}");
        return nullptr;
    }

    void* visit_fn_decl(Node::FnDecl* node) {
        auto name = name_at(node->name);
        printf("FnDecl { name=" SV_FMT ", ", SV_ARG(name));

        auto params = view_at(node->params);
        for (u32 i = 0; i < params.size(); ++i) {
            auto param = params[i];
            printf("param[%u] = ", i);
            visit(param);
            printf(", ");
        }

        printf("body=");
        visit(node->block);
        printf(" }");

        return nullptr;
    }

    void* visit_block(Node::Block* node) {
        printf("Block {\n");
        indentation += 1;

        auto stmts = view_at(node->stmts);
        for (u32 i = 0; i < stmts.size(); ++i) {
            auto stmt = stmts[i];
            indent();
            printf("stmt[%u] = ", i);
            visit(stmt);
            printf("\n");
        }

        indentation -= 1;
        indent();
        printf("}");

        return nullptr;
    }

    void* visit_module(Node::Module* node) {
        auto name = name_at(node->name);

        printf("Module {\n");
        indentation += 1;

        indent();
        printf("name = " SV_FMT "\n", SV_ARG(name));

        auto stmts = view_at(node->stmts);
        for (u32 i = 0; i < stmts.size(); ++i) {
            auto stmt = stmts[i];
            indent();
            printf("stmt[%u] = ", i);
            visit(stmt);
            printf("\n");
        }

        indentation -= 1;
        indent();
        printf("}");
        return nullptr;
    }
};

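This replaces dynamic dispatch with a single switch on the node kind plus static dispatch through CRTP: the calls into the concrete visitor are direct calls that the compiler can inline, so the remaining branching is essentially the switch itself.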
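The example above depends on helpers from the author's codebase (name_at, view_at, SV_FMT and so on). A stripped-down, self-contained sketch of the same CRTP shape could look like the following, assuming a hypothetical two-kind AST stored in a flat vector and addressed by index, loosely modeled on the answer's tree.nodes. The base class owns the only switch on the node kind; each case reaches the derived visitor through static_cast<Derived*>(this), an ordinary direct call. `Node`, `AstVisitor` and `Summer` here are illustrative names, not the answer's actual types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical flat AST: nodes live in a vector and refer to each other by index.
struct Node {
    enum Kind { LIT, BIN_OP } id;
    std::int64_t value;       // LIT: the literal value
    std::size_t left, right;  // BIN_OP: indices of the operands
};

// CRTP base: one switch on the node kind, then a static call into Derived.
template <class Derived, class T>
class AstVisitor {
public:
    explicit AstVisitor(const std::vector<Node>& nodes) : nodes_(nodes) {}

    T visit(std::size_t i) {
        const Node& node = nodes_.at(i);
        auto* self = static_cast<Derived*>(this);
        switch (node.id) {
            case Node::LIT:    return self->visit_lit(node);
            case Node::BIN_OP: return self->visit_bin_op(node);
        }
        return T{};  // not reached for valid kinds
    }

private:
    const std::vector<Node>& nodes_;
};

// A concrete visitor: sums every literal, treating each BIN_OP as addition.
class Summer : public AstVisitor<Summer, std::int64_t> {
public:
    explicit Summer(const std::vector<Node>& nodes) : AstVisitor(nodes) {}

    std::int64_t visit_lit(const Node& n) { return n.value; }
    std::int64_t visit_bin_op(const Node& n) { return visit(n.left) + visit(n.right); }
};
```

No virtual functions are involved: which visit_lit or visit_bin_op runs is fixed at compile time by the template argument, and only the switch depends on the data.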

huangapple
  • Posted on 2023-07-27 21:21:29
  • Please keep this link when reposting: https://go.coder-hub.com/76780185.html