diff --git a/src/mongo/db/query/optimizer/algebra/algebra_test.cpp b/src/mongo/db/query/optimizer/algebra/algebra_test.cpp index 0f4964eacf..fd761ab838 100644 --- a/src/mongo/db/query/optimizer/algebra/algebra_test.cpp +++ b/src/mongo/db/query/optimizer/algebra/algebra_test.cpp @@ -30,6 +30,7 @@ #include "mongo/db/query/optimizer/algebra/operator.h" #include "mongo/db/query/optimizer/algebra/polyvalue.h" #include "mongo/unittest/unittest.h" +#include "mongo/util/visit_helper.h" namespace mongo::optimizer::algebra { @@ -549,5 +550,81 @@ TEST(PolyValueTest, WalkerBasic) { ASSERT(walk(tree.cast()->get<1>(), walker)); } +TEST(PolyValueTest, LambdaIsLeaf) { + auto tree = Tree::make(Tree::make(1.0), Tree::make(2.0)); + + ASSERT( + !walk(tree, + visit_helper::Overloaded{ + [](AtLeastBinaryNode&, std::vector&, Tree&, Tree&) { return false; }, + [](Leaf&) { return true; }, + [](BinaryNode&, Tree&, Tree&) { return false; }, + [](NaryNode&, std::vector&) { return false; }})); + ASSERT( + walk(tree.cast()->get<0>(), + visit_helper::Overloaded{ + [](AtLeastBinaryNode&, std::vector&, Tree&, Tree&) { return false; }, + [](Leaf&) { return true; }, + [](BinaryNode&, Tree&, Tree&) { return false; }, + [](NaryNode&, std::vector&) { return false; }})); +} + +TEST(PolyValueTest, LambdaWithDefault) { + auto tree = Tree::make(42); + ASSERT_EQ(42, + walk(tree, + visit_helper::Overloaded{ + [](auto&& node, ...) { return 42; }, + })); + ASSERT_EQ(42.0, + walk(tree, + visit_helper::Overloaded{ + [](Leaf& node) { return node.x; }, + [](auto&& node, ...) { return 43.0; }, + })); +} + +TEST(PolyValueTest, LambdaCountLeafs) { + auto tree = Tree::make(Tree::make(1.0), Tree::make(2.0)); + + auto count = 0; + transport(tree, visit_helper::Overloaded{[&](Leaf&) { count++; }, [](auto&&, ...) {}}); + ASSERT_EQ(2, count); + + // Now include the initial reference to the tree itself. + count = 0; + transport(tree, + visit_helper::Overloaded{[&](Tree&, Leaf&) { count++; }, [](auto&&, ...) {}}); + ASSERT_EQ(2, count); +} + +TEST(PolyValueTest, LambdaWithReturnType) { + auto tree = Tree::make(Tree::make(1.0), Tree::make(2.0)); + + double res = transport( + tree, + visit_helper::Overloaded{ + [](Leaf& node) { return node.x; }, + [](BinaryNode& node, double child0, double child1) { return child0 + child1; }, + [](auto&&, ...) { return 0.0; }}); + ASSERT_EQ(3.0, res); +} + +TEST(PolyValueTest, LambdaWithExtraArg) { + auto tree = Tree::make(Tree::make(1.0), Tree::make(2.0)); + + double multArg = 2.0; + double res = + transport(tree, + visit_helper::Overloaded{ + [](Leaf& node, double multiplier) { return node.x * multiplier; }, + [](BinaryNode& node, double multiplier, double child0, double child1) { + return (child0 + child1) * multiplier; + }, + [](auto&&, ...) { return 0.0; }}, + multArg); + ASSERT_EQ((1.0 * multArg + 2.0 * multArg) * multArg, res); +} + } // namespace } // namespace mongo::optimizer::algebra diff --git a/src/mongo/db/query/optimizer/algebra/operator.h b/src/mongo/db/query/optimizer/algebra/operator.h index e09c00580c..39b7ba562e 100644 --- a/src/mongo/db/query/optimizer/algebra/operator.h +++ b/src/mongo/db/query/optimizer/algebra/operator.h @@ -29,9 +29,11 @@ #pragma once +#include #include #include "mongo/db/query/optimizer/algebra/polyvalue.h" +#include "mongo/util/concepts.h" namespace mongo::optimizer { namespace algebra { @@ -67,17 +69,18 @@ class OpSpecificArity : public OpNodeStorage { using Base = OpNodeStorage; public: - template - OpSpecificArity(Ts&&... vals) : Base({std::forward(vals)...}) { - static_assert(sizeof...(Ts) == Arity, "constructor paramaters do not match"); - } + TEMPLATE(typename... Ts) + REQUIRES(sizeof...(Ts) == Arity) + OpSpecificArity(Ts&&... vals) : Base({std::forward(vals)...}) {} - template = 0 && I < Arity), int> = 0> + TEMPLATE(int I) + REQUIRES(I >= 0 && I < Arity) auto& get() noexcept { return this->_nodes[I]; } - template = 0 && I < Arity), int> = 0> + TEMPLATE(int I) + REQUIRES(I >= 0 && I < Arity) const auto& get() const noexcept { return this->_nodes[I]; } @@ -151,28 +154,61 @@ template using OpConcreteType = typename std::remove_reference_t::template get_t<0>; } // namespace detail +MONGO_MAKE_BOOL_TRAIT(IsCallable, + (typename Func, typename... Args), + (Func, Args...), + (Func & fn, Args&&... args), + // + fn(std::forward(args)...)); + template class OpTransporter { - D& _domain; + D _domain; - template + template struct Deducer {}; - template - struct Deducer { + + template + struct Deducer { using type = - decltype(std::declval().transport(std::declval(), - std::declval&>(), + decltype(std::declval().transport(std::declval(), + std::declval&>(), std::declval()...)); }; - template - struct Deducer { + + template + struct Deducer { using type = decltype(std::declval().transport( - std::declval&>(), std::declval()...)); + std::declval&>(), std::declval()...)); + }; + + template + struct Deducer { + using type = decltype(std::declval()(std::declval(), + std::declval&>(), + std::declval()...)); + }; + + template + struct Deducer { + using type = decltype( + std::declval()(std::declval&>(), std::declval()...)); }; - template - using deduced_t = typename Deducer::type; + template + using deduced_t = typename Deducer, Args...>::type; + + TEMPLATE(typename N, typename T, typename... Ts) + REQUIRES(IsCallable) + auto transformStep(N&& slot, T&& op, Ts&&... args) { + if constexpr (withSlot) { + return _domain(std::forward(slot), std::forward(op), std::forward(args)...); + } else { + return _domain(std::forward(op), std::forward(args)...); + } + } - template + TEMPLATE(typename N, typename T, typename... Ts) + REQUIRES(!IsCallable) auto transformStep(N&& slot, T&& op, Ts&&... args) { if constexpr (withSlot) { return _domain.transport( @@ -223,9 +259,9 @@ class OpTransporter { } public: - OpTransporter(D& domain) : _domain(domain) {} + OpTransporter(D&& domain) : _domain(std::forward(domain)) {} - template > + template > R operator()(N&& slot, T&& op, Args&&... args) { // N is either `PolyValue&` or `const PolyValue&` i.e. reference // T is either `A&` or `const A&` where A is one of Ts @@ -273,9 +309,20 @@ public: template class OpWalker { - D& _domain; + D _domain; + + TEMPLATE(typename N, typename T, typename... Ts) + REQUIRES(std::is_invocable_v) + auto walkStep(N&& slot, T&& op, Ts&&... args) { + if constexpr (withSlot) { + return _domain(std::forward(slot), std::forward(op), std::forward(args)...); + } else { + return _domain(std::forward(op), std::forward(args)...); + } + } - template + TEMPLATE(typename N, typename T, typename... Ts) + REQUIRES(!std::is_invocable_v) auto walkStep(N&& slot, T&& op, Ts&&... args) { if constexpr (withSlot) { return _domain.walk( @@ -302,7 +349,7 @@ class OpWalker { } public: - OpWalker(D& domain) : _domain(domain) {} + OpWalker(D&& domain) : _domain(std::forward(domain)) {} template auto operator()(N&& slot, T&& op, Args&&... args) { @@ -328,13 +375,14 @@ public: }; template -auto transport(N&& node, D& domain, Args&&... args) { - return node.visit(OpTransporter{domain}, std::forward(args)...); +auto transport(N&& node, D&& domain, Args&&... args) { + return node.visit(OpTransporter{std::forward(domain)}, + std::forward(args)...); } template -auto walk(N&& node, D& domain, Args&&... args) { - return node.visit(OpWalker{domain}, std::forward(args)...); +auto walk(N&& node, D&& domain, Args&&... args) { + return node.visit(OpWalker{std::forward(domain)}, std::forward(args)...); } } // namespace algebra