diff --git a/src/mongo/dbtests/query_stage_save_restore.cpp b/src/mongo/dbtests/query_stage_save_restore.cpp index 8c1a47e..8ff7633 100644 --- a/src/mongo/dbtests/query_stage_save_restore.cpp +++ b/src/mongo/dbtests/query_stage_save_restore.cpp @@ -37,6 +37,14 @@ #include "mongo/db/exec/delete.h" #include "mongo/db/exec/index_scan.h" #include "mongo/db/operation_context_impl.h" +#include "mongo/db/exec/fetch.h" +#include "mongo/db/exec/index_scan.h" +#include "mongo/db/exec/keep_mutations.h" +#include "mongo/db/exec/update.h" +#include "mongo/db/operation_context_impl.h" +#include "mongo/db/ops/update_driver.h" +#include "mongo/db/ops/update_lifecycle_impl.h" +#include "mongo/db/ops/update_request.h" #include "mongo/dbtests/dbtests.h" #include "mongo/dbtests/query_stage_save_restore_test_suite.h" #include "mongo/dbtests/query_stage_specific_save_restore_test.h" @@ -163,6 +171,118 @@ namespace QueryStageSaveRestore { }; // + // UpdateStage + // + + class UpdateStageSaveRestoreTest : public QueryStageSpecificSaveRestoreTest { + public: + UpdateStageSaveRestoreTest(OperationContextImpl* _txn, const char* ns) : + QueryStageSpecificSaveRestoreTest(_txn, ns), + _query(fromjson("{a: {$gte: 0}}")), + _updates(fromjson("{$inc: {y: 1}}")) + { } + + ~UpdateStageSaveRestoreTest() + { + delete _updateParams; + delete _request; + delete _driver; + delete _lifecycle; + } + + bool produce() { + /* Like the DeleteStage, an UpdateStage doesn't actually return PlanStage::ADVANCED + * once it has updated a document, instead we have to rely off the nMatched data to + * determine when it has advanced. Hence the custom produce method. + */ + Client::WriteContext ctx(_txn, _ns); + const UpdateStats* stats = static_cast(_tree->getSpecificStats()); + size_t matchedBefore = stats->nMatched; + WorkingSetID wsOut; + PlanStage::StageState state = PlanStage::NEED_TIME; + while (stats->nMatched <= matchedBefore) { + if (PlanStage::DEAD == state || + PlanStage::FAILURE == state || + _tree->isEOF()) { + return false; + } + state = _tree->work(&wsOut); + } + ++_nMatches; + return true; + } + + private: + void setupTree() { + Client::WriteContext ctx(_txn, _ns); + Collection* coll = ctx.ctx().db()->getCollection(_txn, _ns); + + // IndexScan for underneath our update + IndexScanParams childParams; + childParams.descriptor = getIndex(BSON("a" << 1)); + childParams.bounds.isSimpleRange = true; + childParams.bounds.startKey = BSON("" << 0); + childParams.bounds.endKey = BSONObj(); + childParams.bounds.endKeyInclusive = true; + + // Set up parameters for the UpdateStage. + Client& c = cc(); + CurOp& curOp = *c.curop(); + auto_ptr opDebug(&curOp.debug()); + NamespaceString nsString(ns()); + _driver = new UpdateDriver( (UpdateDriver::Options()) ); + _request = new UpdateRequest(nsString); + _lifecycle = new UpdateLifecycleImpl(false, nsString); + _request->setLifecycle(_lifecycle); + + _request->setMulti(); + _request->setQuery(_query); + _request->setUpdates(_updates); + + ASSERT_OK(_driver->parse(_request->getUpdates(), _request->isMulti())); + _updateParams = new UpdateStageParams(_request, _driver, opDebug.release()); + + /* The update stage needs an index scan underneath, followed by two intermediary + * states: A FetchStage to get the whole documents, not just the indexed part, and + * a KeepMutationsStage because ???. Then the root is the actual UpdateStage. + */ + auto_ptr keepMutationsStage( + new KeepMutationsStage(NULL, &wsTree, + new FetchStage( + _txn, &wsTree, new IndexScan(_txn, childParams, &wsTree, NULL), NULL, coll + ) + ) + ); + auto_ptr updateStage( + new UpdateStage(_txn, *_updateParams, &wsTree, coll, keepMutationsStage.release()) + ); + + _tree.reset(updateStage.release()); + } + + void setupRecordIds() { + // Make a dummy scan to get all RecordIds out + WorkingSet wsDummy; + IndexScanParams dummyParams; + dummyParams.descriptor = getIndex(BSON("a" << 1)); + dummyParams.bounds.isSimpleRange = true; + dummyParams.bounds.endKey = BSONObj(); + IndexScan dummy(_txn, dummyParams, &wsDummy, NULL); + getRecordIdsFromScan(&dummy, &wsDummy); + } + + // We need to have these objects stick around past when they are created so they will be + // available during the test. So keep a reference to them here. The pointers have to be + // deleted upon deconstruction. + BSONObj _query; + BSONObj _updates; + UpdateStageParams *_updateParams; + UpdateRequest *_request; + UpdateDriver *_driver; + UpdateLifecycleImpl *_lifecycle; + }; + + // // DBTests for each stage // @@ -186,6 +306,16 @@ namespace QueryStageSaveRestore { } }; + class UpdateStageSaveRestore { + public: + void run() { + OperationContextImpl txn; + UpdateStageSaveRestoreTest updateStageTests(&txn, ns()); + QueryStageSaveRestoreTestSuite baseline(&updateStageTests, &txn, ns()); + baseline.test(); + } + }; + class All : public Suite { public: All() : Suite("query_stage_save_restore") {} @@ -193,6 +323,7 @@ namespace QueryStageSaveRestore { void setupTests() { add(); add(); + add(); } };