commit 4d9e543e7e8aa93a18e8bf7668a42b28aa1177b5 Author: Charlie Swanson Date: Tue Feb 9 21:53:03 2021 -0500 WIP - some unit tests are passing. Need to add utilization within sharded_agg_helpers diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 1026bcae7b..8846ee1ef4 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -1496,7 +1496,7 @@ public: class ExpressionFieldPath final : public Expression { public: bool isRootFieldPath() const { - return _variable == Variables::kRootId; + return _variable == Variables::kRootId && _fieldPath.getPathLength() == 1; } boost::intrusive_ptr optimize() final; diff --git a/src/mongo/db/pipeline/semantic_analysis.cpp b/src/mongo/db/pipeline/semantic_analysis.cpp index e44b0af957..64820ed99a 100644 --- a/src/mongo/db/pipeline/semantic_analysis.cpp +++ b/src/mongo/db/pipeline/semantic_analysis.cpp @@ -30,6 +30,8 @@ #include "mongo/platform/basic.h" #include "mongo/db/matcher/expression_algo.h" +#include "mongo/db/pipeline/document_source_replace_root.h" +#include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/pipeline/semantic_analysis.h" @@ -98,47 +100,208 @@ StringMap invertRenameMap(const StringMap& originalMap return reversedMap; } +const ReplaceRootTransformation* isReplaceRoot(const DocumentSource* source) { + // DocumentSourceReplaceRoot is never materialized, so we cannot cast to it. $replaceRoot is + // a DocumentSourceSingleDocumentTransformation; we inspect its transformer instead. + auto singleDocTransform = + dynamic_cast(source); + if (!singleDocTransform) { + return nullptr; + } + return dynamic_cast(&singleDocTransform->getTransformer()); +} + +/** + * Detects if 'replaceRootTransform' represents the nesting of a field path. If it does, returns + * the name of that field path. 
For example, if 'replaceRootTransform' represents the transformation + * associated with {$replaceWith: {nested: "$$ROOT"}} or {$replaceRoot: {newRoot: {nested: + * "$$ROOT"}}}, returns "nested". + */ +boost::optional replaceRootNestsRoot( + const ReplaceRootTransformation* replaceRootTransform) { + auto expressionObject = + dynamic_cast(replaceRootTransform->getExpression().get()); + if (!expressionObject) { + return boost::none; + } + auto children = expressionObject->getChildExpressions(); + if (children.size() != 1u) { + return boost::none; + } + auto&& [nestedName, expression] = children[0]; + if (!dynamic_cast(expression.get()) || + !dynamic_cast(expression.get())->isRootFieldPath()) { + return boost::none; + } + return nestedName; +} + +/** + * Detects if 'replaceRootTransform' represents the unnesting of a field path. If it does, returns + * the name of that field path. For example, if 'replaceRootTransform' represents the transformation + * associated with {$replaceWith: "$x"} or {$replaceRoot: {newRoot: "$x"}}, returns "x". + */ +boost::optional replaceRootUnnestsPath( + const ReplaceRootTransformation* replaceRootTransform) { + auto expressionFieldPath = + dynamic_cast(replaceRootTransform->getExpression().get()); + if (!expressionFieldPath) { + return boost::none; + } + return expressionFieldPath->getFieldPathWithoutCurrentPrefix().fullPath(); +} + +/** + * Looks for a pattern where the user temporarily nests the whole object, does some computation, + * then unnests the object. Like so: + * [{$replaceWith: {nested: "$$ROOT"}}, ..., {$replaceWith: "$nested"}]. + * + * If this pattern is detected, returns an iterator to the 'second' replace root, whichever is later + * according to the traversal order. 
+ */ +template +boost::optional lookForNestUnnestPattern( + Iterator start, + Iterator end, + std::set pathsOfInterest, + const Direction& traversalDir, + boost::optional> additionalStageValidatorCallback) { + auto replaceRootTransform = isReplaceRoot((*start).get()); + if (!replaceRootTransform) { + return boost::none; + } + + auto targetName = traversalDir == Direction::kForward + ? replaceRootNestsRoot(replaceRootTransform) + : replaceRootUnnestsPath(replaceRootTransform); + if (!targetName || targetName->find(".") != std::string::npos) { + // Bail out early on dotted paths - we don't intend to deal with that complexity here, + // though we could in the future. + return boost::none; + } + auto nameTestCallback = + traversalDir == Direction::kForward ? replaceRootUnnestsPath : replaceRootNestsRoot; + + ++start; // Advance one to go past the first $replaceRoot we just looked at. + for (; start != end; ++start) { + replaceRootTransform = isReplaceRoot((*start).get()); + if (!replaceRootTransform) { + if (additionalStageValidatorCallback && + !((*additionalStageValidatorCallback)((*start).get()))) { + // There was an additional condition which failed - bail out. + return boost::none; + } + + auto renames = renamedPaths({*targetName}, **start, traversalDir); + if (!renames || + (renames->find(*targetName) != renames->end() && + (*renames)[*targetName] != *targetName)) { + // This stage is not a $replaceRoot - and it modifies our nested path + // ('targetName') somehow. + return boost::none; + } + // This is not a $replaceRoot - but it doesn't impact the nested path, so we continue + // searching for the unnester. + continue; + } + if (auto nestName = nameTestCallback(replaceRootTransform); + nestName && *nestName == *targetName) { + return start; + } else { + // If we have a replaceRoot which is not the one we're looking for - then it modifies + // the path we're trying to preserve. As a future enhancement, we maybe could recurse + // here. 
+ return boost::none; + } + } + return boost::none; +} + /** - * Computes and returns a rename mapping for 'pathsOfInterest' over multiple aggregation pipeline - * stages. The range of pipeline stages we compute renames over is represented by the iterators - * 'start' and 'end'. If both 'start' and 'end' are reverse iterators, then 'start' should come - * after 'end' in the pipeline, 'traversalDir' should be "kBackward," 'pathsOfInterest' should be - * valid path names after stage 'start,' and this template will compute a mapping from the given - * names of 'pathsOfInterest' to their names as they were directly after stage 'end.'If both 'start' - * and 'end' are forwards iterators, then 'start' should come before 'end' in the pipeline, - * 'traversalDir' should be "kForward," 'pathsOfInterest' should be valid path names before stage - * 'start,' and this template will compute a mapping from the given names of 'pathsOfInterest' to - * their names as they are directly before stage 'end.' + * Computes and returns a rename mapping for 'pathsOfInterest' over multiple aggregation + * pipeline stages. The range of pipeline stages we consider renames over is represented by the + * iterators 'start' and 'end'. + * + * If both 'start' and 'end' are reverse iterators, then 'start' should come after 'end' in the + * pipeline, and 'traversalDir' should be "kBackward," 'pathsOfInterest' should be valid path names + * after stage 'start.' + * + * If both 'start' and 'end' are forwards iterators, then 'start' should come before 'end' in the + * pipeline, 'traversalDir' should be "kForward," and 'pathsOfInterest' should be valid path names + * before stage 'start.' 
+ * + * This function will compute an iterator pointing to the "last" stage (farthest in the given + * direction) which preserves 'pathsOfInterest' allowing renames, and returns that iterator and a + * mapping from the given names of 'pathsOfInterest' to their names as they were directly "before" + * (just previous to, according to the direction) the result iterator. If all stages preserve the + * paths of interest, returns 'end.' + * + * An optional 'additionalStageValidatorCallback' function can be provided to short-circuit this + * process and return an iterator to the first stage which either (a) does not preserve + * 'pathsOfInterest,' as before, or (b) does not meet this callback function's criteria. * * This should only be used internally; callers who need to track path renames through an * aggregation pipeline should use one of the publically exposed options availible in the header. */ template -boost::optional> multiStageRenamedPaths( +std::pair> multiStageRenamedPaths( Iterator start, Iterator end, std::set pathsOfInterest, - const Direction& traversalDir) { - // The keys to this map will always be the original names of 'pathsOfInterest'. The values will - // be updated as we loop through the pipeline's stages to always be the most up-to-date name we - // know of for that path. + const Direction& traversalDir, + boost::optional> additionalStageValidatorCallback = + boost::none) { + // The keys to this map will always be the original names of 'pathsOfInterest'. The values + // will be updated as we loop through the pipeline's stages to always be the most up-to-date + // name we know of for that path. StringMap renameMap; for (auto&& path : pathsOfInterest) { renameMap[path] = path; } for (; start != end; ++start) { + if (additionalStageValidatorCallback && + !((*additionalStageValidatorCallback)((*start).get()))) { + // There was an additional condition which failed - bail out. 
+ return {start, renameMap}; + } + auto renamed = renamedPaths(pathsOfInterest, **start, traversalDir); if (!renamed) { - return boost::none; + if (auto finalReplaceRoot = lookForNestUnnestPattern( + start, end, pathsOfInterest, traversalDir, additionalStageValidatorCallback)) { + // We've just detected a pattern where the user temporarily nests the whole + // object, does some computation, then unnests the object. Like so: + // [{$replaceWith: {nested: "$$ROOT"}}, ..., {$replaceWith: "$nested"}]. + // This analysis makes sure that the middle stages don't modify 'nested' or + // whatever the nesting field path is. In this case, we can safely skip over all + // intervening stages and continue on our way. + start = *finalReplaceRoot; + continue; + } + return {start, renameMap}; } - //'pathsOfInterest' always holds the current names of the paths we're interested in, so it - // needs to be updated after each stage. + //'pathsOfInterest' always holds the current names of the paths we're interested in, so + // it needs to be updated after each stage. pathsOfInterest.clear(); for (auto it = renameMap.cbegin(); it != renameMap.cend(); ++it) { renameMap[it->first] = (*renamed)[it->second]; pathsOfInterest.emplace(it->second); } } + return {end, renameMap}; +} +template +boost::optional> renamedPathsFullPipeline( + Iterator start, + Iterator end, + std::set pathsOfInterest, + const Direction& traversalDir, + boost::optional> additionalStageValidatorCallback) { + auto [itr, renameMap] = multiStageRenamedPaths( + start, end, pathsOfInterest, traversalDir, additionalStageValidatorCallback); + if (itr != end) { + return boost::none; // The paths were not preserved to the very end. 
+ } return renameMap; } @@ -148,8 +311,7 @@ std::set extractModifiedDependencies(const std::set& d const std::set& preservedPaths) { std::set modifiedDependencies; - // The modified dependencies is *almost* the set difference 'dependencies' - 'preservedPaths', - // except that if p in 'preservedPaths' is a "path prefix" of d in 'dependencies', then 'd' + // Almost 'dependencies' minus 'preservedPaths', but if a path prefix of 'd' is preserved, 'd' // should not be included in the modified dependencies. for (auto&& dependency : dependencies) { bool preserved = false; @@ -232,15 +394,53 @@ boost::optional> renamedPaths(const std::set boost::optional> renamedPaths( const Pipeline::SourceContainer::const_iterator start, const Pipeline::SourceContainer::const_iterator end, - const std::set& pathsOfInterest) { - return multiStageRenamedPaths(start, end, pathsOfInterest, Direction::kForward); + const std::set& pathsOfInterest, + boost::optional> additionalStageValidatorCallback) { + return renamedPathsFullPipeline( + start, end, pathsOfInterest, Direction::kForward, additionalStageValidatorCallback); } boost::optional> renamedPaths( const Pipeline::SourceContainer::const_reverse_iterator start, const Pipeline::SourceContainer::const_reverse_iterator end, - const std::set& pathsOfInterest) { - return multiStageRenamedPaths(start, end, pathsOfInterest, Direction::kBackward); + const std::set& pathsOfInterest, + boost::optional> additionalStageValidatorCallback) { + return renamedPathsFullPipeline( + start, end, pathsOfInterest, Direction::kBackward, additionalStageValidatorCallback); +} + +std::pair> +findLongestViablePrefixPreservingPaths( + const Pipeline::SourceContainer::const_iterator start, + const Pipeline::SourceContainer::const_iterator end, + const std::set& pathsOfInterest, + boost::optional> additionalStageValidatorCallback) { + return multiStageRenamedPaths( + start, end, pathsOfInterest, Direction::kForward, additionalStageValidatorCallback); +} + +bool 
pathSetContainsOverlappingPath(const std::set& paths, + const std::string& targetPath) { + auto targetTopLevelField = [&]() { + auto dotPosition = targetPath.find("."); + if (dotPosition == std::string::npos) { + return targetPath; + } + return targetPath.substr(0, dotPosition); + }(); + // Find the lower bound of where the path would be, then keep iterating until we find a path + // which does not start with the top-level string. This will help constrain our search. + auto it = paths.lower_bound(targetTopLevelField); + while (it != paths.end() && str::startsWith(*it, targetTopLevelField)) { + // Be careful to check both directions: The 'targetPath' overlaps if either is a prefix of + // another. For example "a" overlaps with "a.b", and vice versa. + if (targetPath == *it || expression::isPathPrefixOf(*it, targetPath) || + expression::isPathPrefixOf(targetPath, *it)) { + return true; + } + ++it; + } + return false; } } // namespace mongo::semantic_analysis diff --git a/src/mongo/db/pipeline/semantic_analysis.h b/src/mongo/db/pipeline/semantic_analysis.h index 73739919e1..33a9851736 100644 --- a/src/mongo/db/pipeline/semantic_analysis.h +++ b/src/mongo/db/pipeline/semantic_analysis.h @@ -72,7 +72,9 @@ boost::optional> renamedPaths(const std::set boost::optional> renamedPaths( const Pipeline::SourceContainer::const_iterator start, const Pipeline::SourceContainer::const_iterator end, - const std::set& pathsOfInterest); + const std::set& pathsOfInterest, + boost::optional> additionalStageValidatorCallback = + boost::none); /** * Tracks renames by walking a pipeline backwards. 
Takes two reverse iterators that represent two @@ -87,7 +89,24 @@ boost::optional> renamedPaths( boost::optional> renamedPaths( const Pipeline::SourceContainer::const_reverse_iterator start, const Pipeline::SourceContainer::const_reverse_iterator end, - const std::set& pathsOfInterest); + const std::set& pathsOfInterest, + boost::optional> additionalStageValidatorCallback = + boost::none); + +/** + * Attempts to find a maximal prefix of the pipeline given by 'start' and 'end' which will preserve + * all paths in 'pathsOfInterest' and also have each DocumentSource satisfy + * 'additionalStageValidatorCallback'. + * + * Returns an iterator to the first stage which modifies one of the paths in 'pathsOfInterest' or + * fails 'additionalStageValidatorCallback', or returns 'end' if no such stage exists. + */ +std::pair> +findLongestViablePrefixPreservingPaths(const Pipeline::SourceContainer::const_iterator start, + const Pipeline::SourceContainer::const_iterator end, + const std::set& pathsOfInterest, + boost::optional> + additionalStageValidatorCallback = boost::none); /** * Given a set of paths 'dependencies', determines which of those paths will be modified if all @@ -99,4 +118,7 @@ boost::optional> renamedPaths( std::set extractModifiedDependencies(const std::set& dependencies, const std::set& preservedPaths); +bool pathSetContainsOverlappingPath(const std::set& paths, + const std::string& targetPath); + } // namespace mongo::semantic_analysis diff --git a/src/mongo/db/pipeline/semantic_analysis_test.cpp b/src/mongo/db/pipeline/semantic_analysis_test.cpp index a2fcf80832..66ba34c33d 100644 --- a/src/mongo/db/pipeline/semantic_analysis_test.cpp +++ b/src/mongo/db/pipeline/semantic_analysis_test.cpp @@ -439,5 +439,265 @@ TEST_F(SemanticAnalysisRenamedPaths, ReturnsNoneWhenModificationsAreNotKnown) { } } +TEST_F(SemanticAnalysisRenamedPaths, DetectsSimpleReplaceRootPattern) { + auto pipeline = Pipeline::parse( + {fromjson("{$replaceWith: {nested: '$$ROOT'}}"), 
fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"a"}); + ASSERT_TRUE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"b"}); + ASSERT_TRUE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_TRUE(static_cast(renames)); + } +} + +TEST_F(SemanticAnalysisRenamedPaths, DetectsReplaceRootPatternAllowsIntermediateStages) { + auto pipeline = + Pipeline::parse({fromjson("{$replaceWith: {nested: '$$ROOT'}}"), + fromjson("{$set: {bigEnough: {$gte: [{$bsonSize: '$nested'}, 300]}}}"), + fromjson("{$match: {bigEnough: true}}"), + fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"a"}); + ASSERT_TRUE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"b"}); + ASSERT_TRUE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_TRUE(static_cast(renames)); + } +} + +TEST_F(SemanticAnalysisRenamedPaths, DetectsReplaceRootPatternDisallowsIntermediateModification) { + auto pipeline = Pipeline::parse({fromjson("{$replaceWith: {nested: '$$ROOT'}}"), + fromjson("{$set: {'nested.field': 'anyNewValue'}}"), + fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"a"}); + ASSERT_FALSE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"b"}); + ASSERT_FALSE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), 
{"b"}); + ASSERT_FALSE(static_cast(renames)); + } +} + +TEST_F(SemanticAnalysisRenamedPaths, DoesNotDetectFalseReplaceRootIfTypoed) { + auto pipeline = Pipeline::parse( + {fromjson("{$replaceWith: {nested: '$$ROOT'}}"), fromjson("{$replaceWith: '$nestedTypo'}")}, + getExpCtx()); + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"a"}); + ASSERT_FALSE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_FALSE(static_cast(renames)); + } +} + +TEST_F(SemanticAnalysisRenamedPaths, DetectsReplaceRootPatternIfCurrentInsteadOfROOT) { + auto pipeline = Pipeline::parse( + {fromjson("{$replaceWith: {nested: '$$CURRENT'}}"), fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"a"}); + ASSERT_TRUE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_TRUE(static_cast(renames)); + } +} + +TEST_F(SemanticAnalysisRenamedPaths, DoesNotDetectFalseReplaceRootIfNoROOT) { + auto pipeline = Pipeline::parse( + {fromjson("{$replaceWith: {nested: '$subObj'}}"), fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + { + auto renames = + renamedPaths(pipeline->getSources().begin(), pipeline->getSources().end(), {"a"}); + ASSERT_FALSE(static_cast(renames)); + } + { + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_FALSE(static_cast(renames)); + } +} + +TEST_F(SemanticAnalysisRenamedPaths, DoesNotDetectFalseReplaceRootIfTargetPathIsRenamed) { + + { + auto pipeline = Pipeline::parse({fromjson("{$replaceWith: {nested: '$$ROOT'}}"), + fromjson("{$unset : 'nested'}"), + fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + auto renames = + renamedPaths(pipeline->getSources().begin(), 
pipeline->getSources().end(), {"a"}); + ASSERT_FALSE(static_cast(renames)); + } + { + auto pipeline = Pipeline::parse({fromjson("{$replaceWith: {nested: '$$ROOT'}}"), + fromjson("{$set : {nested: '$somethingElese'}}"), + fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_FALSE(static_cast(renames)); + } + { + // This case could someday work - we leave it as a future improvement. + auto pipeline = Pipeline::parse({fromjson("{$replaceWith: {nested: '$$ROOT'}}"), + fromjson("{$set : {somethingElse: '$nested'}}"), + fromjson("{$replaceWith: '$somethingElse'}")}, + getExpCtx()); + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_FALSE(static_cast(renames)); + } + { + // This is a tricky one. The pattern does exist, but it's doubly nested and only unnested + // once. + auto pipeline = Pipeline::parse({fromjson("{$replaceWith: {nested: '$$ROOT'}}"), + fromjson("{$replaceWith: {doubleNested: '$nested'}}"), + fromjson("{$replaceWith: '$doubleNested'}")}, + getExpCtx()); + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_FALSE(static_cast(renames)); + } + { + // Similar to above but double nested then double unnested. We could someday make this work, + // but leave it for a future improvement. 
+ auto pipeline = Pipeline::parse({fromjson("{$replaceWith: {nested: '$$ROOT'}}"), + fromjson("{$replaceWith: {doubleNested: '$nested'}}"), + fromjson("{$replaceWith: '$doubleNested'}"), + fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + auto renames = + renamedPaths(pipeline->getSources().rbegin(), pipeline->getSources().rend(), {"b"}); + ASSERT_FALSE(static_cast(renames)); + } +} + +using SemanticAnalysisFindLongestViablePrefix = AggregationContextFixture; +TEST_F(SemanticAnalysisFindLongestViablePrefix, AllowsReplaceRootPattern) { + auto pipeline = + Pipeline::parse({fromjson("{$replaceWith: {nested: '$$ROOT'}}"), + fromjson("{$set: {bigEnough: {$gte: [{$bsonSize: '$nested'}, 300]}}}"), + fromjson("{$match: {bigEnough: true}}"), + fromjson("{$replaceWith: '$nested'}")}, + getExpCtx()); + auto [itr, renames] = findLongestViablePrefixPreservingPaths( + pipeline->getSources().begin(), pipeline->getSources().end(), {"a"}); + ASSERT(itr == pipeline->getSources().end()); +} + +TEST_F(SemanticAnalysisFindLongestViablePrefix, FindsPrefixWithoutReplaceRoot) { + auto pipeline = Pipeline::parse({fromjson("{$match: {testing: true}}"), + fromjson("{$unset: 'unset'}"), + fromjson("{$set: {x: '$y'}}")}, + getExpCtx()); + { + auto [itr, renames] = findLongestViablePrefixPreservingPaths( + pipeline->getSources().begin(), pipeline->getSources().end(), {"a"}); + ASSERT(itr == pipeline->getSources().end()); + } + { + auto [itr, renames] = findLongestViablePrefixPreservingPaths( + pipeline->getSources().begin(), pipeline->getSources().end(), {"unset"}); + ASSERT(itr == std::next(pipeline->getSources().begin())); + } + { + auto [itr, renames] = findLongestViablePrefixPreservingPaths( + pipeline->getSources().begin(), pipeline->getSources().end(), {"y"}); + ASSERT(itr == pipeline->getSources().end()); + ASSERT(renames["y"] == "x"); + } + { + auto [itr, renames] = findLongestViablePrefixPreservingPaths( + pipeline->getSources().begin(), pipeline->getSources().end(), {"x"}); + 
ASSERT(itr == std::prev(pipeline->getSources().end())); + ASSERT(renames["x"] == "x"); + } +} + +TEST_F(SemanticAnalysisFindLongestViablePrefix, CorrectlyAnswersReshardingUseCase) { + auto expCtx = getExpCtx(); + auto lookupNss = NamespaceString{"config.cache.chunks.test"}; + expCtx->setResolvedNamespace(lookupNss, ExpressionContext::ResolvedNamespace{lookupNss, {}}); + auto pipeline = + Pipeline::parse({fromjson("{$replaceWith: {original: '$$ROOT'}}"), + fromjson("{$lookup: {from: {db: 'config', coll: 'cache.chunks.test'}, " + "pipeline: [], as: 'intersectingChunk'}}"), + fromjson("{$match: {intersectingChunk: {$ne: []}}}"), + fromjson("{$replaceWith: '$original'}")}, + getExpCtx()); + { + auto [itr, renames] = findLongestViablePrefixPreservingPaths( + pipeline->getSources().begin(), pipeline->getSources().end(), {"_id"}); + ASSERT(itr == pipeline->getSources().end()); + ASSERT(renames["_id"] == "_id"); + } +} + +TEST(PathSetContainsOverlappingPath, ReturnsTrueIfExactPathPresent) { + ASSERT_TRUE(pathSetContainsOverlappingPath({"a", "b"}, "a")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"a", "b"}, "b")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"abc", "xyz"}, "abc")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"ABC", "XYZ"}, "XYZ")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"ABC.DEF", "XYZ.123"}, "XYZ.123")); +} + +TEST(PathSetContainsOverlappingPath, ReturnsTrueIfPathPrefixPresent) { + ASSERT_TRUE(pathSetContainsOverlappingPath({"a", "b"}, "a.b")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"a", "b"}, "b.c")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"abc", "xyz"}, "abc.deep")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"abc.deep", "abcdeep", "abcDeep"}, "abc.deep")); +} + +TEST(PathSetContainsOverlappingPath, ReturnsTrueIfPathSuffixPresent) { + ASSERT_TRUE(pathSetContainsOverlappingPath({"a.b", "b.c"}, "a")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"a.b", "b.c"}, "b")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"abc.deep", 
"xyz.deep"}, "abc")); + ASSERT_TRUE(pathSetContainsOverlappingPath({"abc.deep", "abcdeep", "abcDeep"}, "abc")); +} + +TEST(PathSetContainsOverlappingPath, ReturnsFalseIfPathNotPresent) { + ASSERT_FALSE(pathSetContainsOverlappingPath({"a", "b"}, "c")); + ASSERT_FALSE(pathSetContainsOverlappingPath({"a", "b"}, "ab")); + ASSERT_FALSE(pathSetContainsOverlappingPath({"abc.deep", "xyz.deep"}, "abcdeep")); + ASSERT_FALSE(pathSetContainsOverlappingPath({"abcdeep", "abcDeep"}, "abc.deep")); +} + } // namespace } // namespace mongo diff --git a/src/mongo/db/pipeline/sharded_agg_helpers.cpp b/src/mongo/db/pipeline/sharded_agg_helpers.cpp index 3bec7700e3..0fb0a9c1b0 100644 --- a/src/mongo/db/pipeline/sharded_agg_helpers.cpp +++ b/src/mongo/db/pipeline/sharded_agg_helpers.cpp @@ -298,6 +298,37 @@ void moveFinalUnwindFromShardsToMerger(Pipeline* shardPipe, Pipeline* mergePipe) } } +/** + * WIP: Intended to move streaming stages from the front of 'mergePipe' onto the shards, ahead of + * the final $sort in 'shardPipe'. Currently it only checks whether the first merge stage is a + * candidate and never moves anything. 'sortPattern' is not consulted yet. + */ +void moveEligibleStreamingStagesBeforeSortOnShards(Pipeline* shardPipe, + Pipeline* mergePipe, + const BSONObj& sortPattern) { + tassert(5363800, + "Expected non-empty shardPipe consisting of at least a $sort stage", + !shardPipe->getSources().empty()); + tassert(5363801, + "Expected last stage on the shards to be a $sort", + dynamic_cast(shardPipe->getSources().back().get())); + while (!mergePipe->getSources().empty()) { + auto* firstMergeStage = mergePipe->getSources().front().get(); + if (firstMergeStage->distributedPlanLogic()) { + // This stage can't be moved onto the shards to run in parallel - bail out. 
+ return; + } + auto modifiedPaths = firstMergeStage->getModifiedPaths(); + if (modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kNotSupported) { + // We can't tell if this stage messes with the sort pattern so bail out. + return; + } + // TODO: Moving 'firstMergeStage' onto the shards is not implemented yet. Stop here so the + // loop terminates instead of inspecting the same, never-removed stage forever. + return; + } +} + /** * Returns true if the final stage of the pipeline limits the number of documents it could output * (such as a $limit stage). 
@@ -773,6 +804,10 @@ SplitPipeline splitPipeline(std::unique_ptr pipeline) // The order in which optimizations are applied can have significant impact on the efficiency of // the final pipeline. Be Careful! + if (inputsSort) { + moveEligibleStreamingStagesBeforeSortOnShards( + shardsPipeline.get(), mergePipeline.get(), *inputsSort); + } moveFinalUnwindFromShardsToMerger(shardsPipeline.get(), mergePipeline.get()); propagateDocLimitToShards(shardsPipeline.get(), mergePipeline.get()); limitFieldsSentFromShardsToMerger(shardsPipeline.get(), mergePipeline.get());