86 const VdfRequest &computeRequest,
89 schedule, computeRequest, errorLogger,
98 template <
typename Callback >
101 const VdfRequest &computeRequest,
103 Callback &&callback);
107 typedef typename DataManager::DataHandle _DataHandle;
110 typedef uint32_t _EvaluationStage;
113 template <
typename Callback >
121 const size_t requestedIndex,
122 Callback &callback) :
126 _requestedIndex(requestedIndex),
138 const size_t _requestedIndex;
140 _EvaluationStage _evaluationStage;
155 _taskIndex(taskIndex),
167 _EvaluationStage _evaluationStage;
182 _taskIndex(taskIndex),
194 _EvaluationStage _evaluationStage;
209 _taskIndex(taskIndex),
221 _EvaluationStage _evaluationStage;
276 template <
typename Callback >
280 const size_t requestedIndex,
287 void _SpawnRequestedTasks(
306 template <
typename Callback >
307 bool _ProcessLeafTask(
311 const size_t requestedIndex,
313 _EvaluationStage *evaluationStage,
320 bool _ProcessComputeTask(
325 _EvaluationStage *evaluationStage,
332 bool _ProcessInputsTask(
337 _EvaluationStage *evaluationStage,
344 bool _ProcessKeepTask(
348 _EvaluationStage *evaluationStage,
355 bool _InvokeKeepTask(
365 void _InvokeTouchTask(
373 bool _InvokeComputeTask(
381 template <
typename Iterable >
382 bool _InvokeComputeTasks(
383 const Iterable &tasks,
391 bool _IsInputDependencyCached(
399 bool _InvokeComputeOrKeepTasks(
409 bool _InvokeComputeOrKeepTasks(
419 bool _InvokePrereqInputs(
430 bool _InvokeOptionalInputs(
442 bool _InvokeRequiredInputs(
452 bool _InvokeInputsTask(
463 bool _InvokePrepTask(
483 void _CreateScratchCache(
485 const _DataHandle dataHandle,
504 void _PassThroughNode(
515 const _DataHandle dataHandle,
516 const bool hasAssociatedInput,
521 void _PrepareReadWriteBuffer(
530 void _PassOrCopyBuffer(
553 void _PublishScratchBuffers(
562 const _DataHandle dataHandle,
568 bool _DetectInterruption(
574 bool _HasDetectedInterruption()
const;
578 void _TransportErrors(
const TfErrorMark &errorMark);
581 void _PostTransportedErrors();
585 return *
static_cast<Derived *
>(
this);
592 DataManager *_dataManager;
602 std::unique_ptr<std::atomic<uint8_t>[]> _dependencyState;
608 std::atomic<bool> _resetState;
615 tbb::concurrent_vector<TfErrorTransport> _errors;
618 std::atomic<bool> _isInterrupted;
623template <
typename Derived,
typename DataManager >
627 DataManager *dataManager) :
629 _dataManager(dataManager),
631 _computeTasks(&_taskGraph),
632 _inputsTasks(&_taskGraph),
633 _prepTasks(&_taskGraph),
634 _keepTasks(&_taskGraph),
639template <
typename Derived,
typename DataManager >
645template <
typename Derived,
typename DataManager >
646template <
typename Callback >
650 const VdfRequest &computeRequest,
654 TRACE_SCOPE(
"VdfParallelExecutorEngineBase::RunSchedule");
657 TF_PY_ALLOW_THREADS_IN_SCOPE();
663 _resetState.store(
false, std::memory_order_relaxed);
670 VdfRequest::IndexedView view(computeRequest);
676 _isolatingDispatcher.Run([engine, &state, &view, &callback] {
685 [engine, &state, &view, &callback, &taskLists]
686 (
size_t b,
size_t e) {
687 WorkTaskGraph::TaskList *taskList = &taskLists.local();
688 for (size_t i = b; i != e; ++i) {
689 if (const VdfMaskedOutput *maskedOutput = view.Get(i)) {
691 state, *maskedOutput, i, callback, taskList);
700 engine->_taskGraph.
RunLists(taskLists);
705 "VdfParallelExecutorEngineBase::RunSchedule "
706 "(wait for parallel tasks)");
707 engine->_taskGraph.
Wait();
711 _isolatingDispatcher.Wait();
715 _Self()._FinalizeEvaluation();
718 _isInterrupted.store(
false, std::memory_order_relaxed);
721 _PostTransportedErrors();
724template <
typename Derived,
typename DataManager >
736 _dependencyState.reset(
new std::atomic<uint8_t>[numUniqueDeps]);
737 char *
const dependencyState =
738 reinterpret_cast<char*
>(_dependencyState.get());
739 memset(dependencyState, 0,
740 sizeof(std::atomic<uint8_t>) * numUniqueDeps);
750template <
typename Derived,
typename DataManager >
751template <
typename Callback >
756 const size_t requestedIndex,
768 if (_executor.GetOutputValue(output, mask)) {
769 callback(maskedOutput, requestedIndex);
776 bool isReset = _resetState.load(std::memory_order_relaxed);
777 if (!isReset && _resetState.compare_exchange_strong(isReset,
true)) {
784 _taskGraph.AllocateTask< _LeafTask<Callback> >(
785 this, state, maskedOutput, requestedIndex, callback);
786 taskList->push_back(task);
789template <
typename Derived,
typename DataManager >
790template <
typename Callback >
805 if (_engine->_ProcessLeafTask(
806 this, _state, _output, _requestedIndex, _callback, &_evaluationStage,
808 _RecycleAsContinuation();
815 RemoveChildReference();
822template <
typename Derived,
typename DataManager >
844 if (_engine->_ProcessComputeTask(
845 this, _state, _node, scheduleTask, &_evaluationStage, &bypass)) {
853 _engine->_computeTasks.
MarkDone(_taskIndex);
860 _engine->_TransportErrors(errorMark);
867template <
typename Derived,
typename DataManager >
881 _state.GetSchedule().GetInputsTask(_taskIndex);
885 if (_engine->_ProcessInputsTask(
886 this, _state, _node, scheduleTask, &_evaluationStage, &bypass)) {
887 _RecycleAsContinuation();
893 _engine->_inputsTasks.MarkDone(_taskIndex);
894 RemoveChildReference();
901template <
typename Derived,
typename DataManager >
915 if (_engine->_ProcessKeepTask(
916 this, _state, _node, &_evaluationStage, &bypass)) {
917 _RecycleAsContinuation();
923 _engine->_keepTasks.MarkDone(_taskIndex);
924 RemoveChildReference();
931template <
typename Derived,
typename DataManager >
938 while (output && output != &_source) {
939 _engine->_Self()._Touch(*output);
947template <
typename Derived,
typename DataManager >
961 const bool invoked = _engine->_InvokeComputeTasks(
962 _state.GetSchedule().GetComputeTaskIds(_node),
963 _state, _node,
this,
nullptr);
970 _RecycleAsContinuation();
977 RemoveChildReference();
984template <
typename Derived,
typename DataManager >
1006 if (tasks.size() > 1) {
1010 if (_keepTasks.
Claim(keepTaskIndex, successor) ==
1014 this, state, node, keepTaskIndex);
1015 _SpawnOrBypass(task, bypass);
1023 if (_computeTasks.
Claim(computeTaskIndex, successor) ==
1025 _ComputeTask *task =
1027 this, state, node, computeTaskIndex);
1028 _SpawnOrBypass(task, bypass);
1033template <
typename Derived,
typename DataManager >
1043 if (!bypass || *bypass) {
1050template <
typename Derived,
typename DataManager >
1051template <
typename Callback >
1057 const size_t requestedIndex,
1059 _EvaluationStage *evaluationStage,
1064 EvaluationStageSpawn,
1065 EvaluationStageCallback
1069 switch (*evaluationStage) {
1074 case EvaluationStageSpawn: {
1076 _SpawnRequestedTasks(state, node, task, bypass);
1077 *evaluationStage = EvaluationStageCallback;
1083 case EvaluationStageCallback: {
1084 callback(maskedOutput, requestedIndex);
1091template <
typename Derived,
typename DataManager >
1098 _EvaluationStage *evaluationStage,
1103 EvaluationStageInputs,
1104 EvaluationStagePrepNode,
1105 EvaluationStageEvaluateNode
1109 switch (*evaluationStage) {
1112 case EvaluationStageInputs: {
1116 if (_DetectInterruption(state, node)) {
1123 node, VdfExecutionStats::NodeRequiredInputsEvent);
1126 const bool invokedRequireds =
1127 _InvokeRequiredInputs(scheduleTask, state, task, bypass);
1128 const bool invokedInputsTask =
1129 _InvokeInputsTask(scheduleTask, state, node, task, bypass);
1133 if (invokedRequireds || invokedInputsTask) {
1134 *evaluationStage = EvaluationStagePrepNode;
1140 case EvaluationStagePrepNode: {
1145 if (_DetectInterruption(state, node)) {
1151 if (_InvokePrepTask(scheduleTask, state, node, task)) {
1152 *evaluationStage = EvaluationStageEvaluateNode;
1158 case EvaluationStageEvaluateNode: {
1163 if (_HasDetectedInterruption()) {
1168 _EvaluateNode(scheduleTask, state, node, task);
1176template <
typename Derived,
typename DataManager >
1183 _EvaluationStage *evaluationStage,
1188 EvaluationStagePrereqs,
1189 EvaluationStageOptionals,
1196 node, VdfExecutionStats::NodeInputsTaskEvent);
1199 switch (*evaluationStage) {
1202 case EvaluationStagePrereqs: {
1206 if (_InvokePrereqInputs(
1207 scheduleTask, state, task, bypass)) {
1208 *evaluationStage = EvaluationStageOptionals;
1214 case EvaluationStageOptionals: {
1220 if (_HasDetectedInterruption()) {
1227 if (_InvokeOptionalInputs(
1228 scheduleTask, state, node, task, bypass)) {
1229 *evaluationStage = EvaluationStageDone;
1239template <
typename Derived,
typename DataManager >
1245 _EvaluationStage *evaluationStage,
1250 EvaluationStageKeep,
1251 EvaluationStagePublish
1259 switch (*evaluationStage) {
1262 case EvaluationStageKeep: {
1269 bool invoked =
false;
1276 if (computeTask.flags.hasKeep) {
1277 invoked |= _InvokeComputeTask(
1278 taskId, state, node, task, bypass);
1285 *evaluationStage = EvaluationStagePublish;
1291 case EvaluationStagePublish: {
1293 if (_HasDetectedInterruption()) {
1298 _PublishScratchBuffers(schedule, node);
1305template <
typename Derived,
typename DataManager >
1321 this, state, node, idx);
1322 _SpawnOrBypass(task, bypass);
1327 return claimState != VdfParallelTaskSync::State::Done;
1330template <
typename Derived,
typename DataManager >
1340 _TouchTask *task = _taskGraph.
AllocateTask<_TouchTask>(
1341 this, dest, source);
1345template <
typename Derived,
typename DataManager >
1356 claimState = _computeTasks.
Claim(taskIndex, successor);
1361 _ComputeTask *task =
1363 this, state, node, taskIndex);
1364 _SpawnOrBypass(task, bypass);
1369 return claimState != VdfParallelTaskSync::State::Done;
1372template <
typename Derived,
typename DataManager >
1373template <
typename Iterable >
1376 const Iterable &tasks,
1383 bool invoked =
false;
1385 invoked |= _InvokeComputeTask(taskId, state, node, successor, bypass);
1392template <
typename Derived,
typename DataManager >
1406 std::atomic<uint8_t> *state = &_dependencyState[uniqueIndex];
1407 uint8_t currentState = state->load(std::memory_order_relaxed);
1411 if (currentState == StateUndecided) {
1414 const uint8_t newState = isCached ? StateCached : StateUncached;
1419 if (state->compare_exchange_strong(currentState, newState)) {
1425 return currentState == StateCached;
1428template <
typename Derived,
typename DataManager >
1446 if (_IsInputDependencyCached(input.uniqueIndex, input.output, input.mask)) {
1460 bool invoked = _InvokeComputeTasks(tasks, state, node, successor, bypass);
1466 invoked |= _InvokeKeepTask(keepTask, node, state, successor, bypass);
1473template <
typename Derived,
typename DataManager >
1498 const VdfNode &node = output.GetNode();
1504 bool invoked = _InvokeComputeTasks(tasks, state, node, successor, bypass);
1510 invoked |= _InvokeKeepTask(keepTask, node, state, successor, bypass);
1517template <
typename Derived,
typename DataManager >
1525 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_InvokePrereqInputs");
1528 if (!scheduleTask.prereqsNum) {
1538 bool invoked =
false;
1540 invoked |= _InvokeComputeOrKeepTasks(i, state, successor, bypass);
1547template <
typename Derived,
typename DataManager >
1556 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_InvokeOptionalInputs");
1559 if (!scheduleTask.optionalsNum) {
1580 bool invoked =
false;
1587 invoked |= _InvokeComputeOrKeepTasks(i, state, successor, bypass);
1594 PEE_TRACE_SCOPE(
"Task Inversion");
1603 invoked |= _InvokeComputeOrKeepTasks(
1604 *scheduleInput.source, state, successor, bypass);
1613template <
typename Derived,
typename DataManager >
1621 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_InvokeRequiredInputs");
1631 bool invoked =
false;
1633 invoked |= _InvokeComputeOrKeepTasks(i, state, successor, bypass);
1640template <
typename Derived,
typename DataManager >
1649 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_InvokeInputsTask");
1660 _inputsTasks.
Claim(inputsTaskIndex, successor);
1667 this, state, node, inputsTaskIndex);
1668 _SpawnOrBypass(task, bypass);
1673 return claimState != VdfParallelTaskSync::State::Done;
1676template <
typename Derived,
typename DataManager >
1684 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_InvokePrepTask");
1695 _PrepareNode(state, node);
1699 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_InvokePrepTask (task)");
1703 _prepTasks.
Claim(prepTaskIndex, successor);
1710 _PrepareNode(state, node);
1711 _prepTasks.
MarkDone(prepTaskIndex);
1721template <
typename Derived,
typename DataManager >
1727 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_PrepareNode");
1732 node, VdfExecutionStats::NodePrepareEvent);
1736 VDF_FOR_EACH_SCHEDULED_OUTPUT_ID(outputId, schedule, node) {
1737 _PrepareOutput(schedule, outputId);
1741template <
typename Derived,
typename DataManager >
1753 _Self()._Touch(output);
1756 _DataHandle dataHandle =
1757 _dataManager->GetOrCreateDataHandle(output.GetId());
1762 _dataManager->GetPrivateBufferData(dataHandle);
1768 if (output.GetAssociatedInput()) {
1769 _PrepareReadWriteBuffer(
1770 output, outputId, requestMask, schedule, privateBuffer);
1776 _dataManager->GetScratchBufferData(dataHandle);
1785 _CreateScratchCache(output, dataHandle, keepMask, scratchBuffer);
1789template <
typename Derived,
typename DataManager >
1793 const _DataHandle dataHandle,
1798 _dataManager->GetPublicBufferData(dataHandle);
1804 _dataManager->CreateOutputCache(output, scratchBuffer, mask.
GetBits());
1815 _dataManager->CreateOutputCache(output, scratchBuffer, unionBits);
1819template <
typename Derived,
typename DataManager >
1827 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_EvaluateNode");
1832 node, VdfExecutionStats::NodeEvaluateEvent);
1835 if (scheduleTask.flags.isAffective) {
1836 _ComputeNode(scheduleTask, state, node);
1842 _PassThroughNode(scheduleTask, state, node);
1846template <
typename Derived,
typename DataManager >
1853 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_ComputeNode");
1857 stats->LogTimestamp(VdfExecutionStats::NodeDidComputeEvent, node);
1869 if (_DetectInterruption(state, node)) {
1877 VDF_FOR_EACH_SCHEDULED_OUTPUT_ID(outputId, schedule, node) {
1881 _DataHandle dataHandle = _dataManager->GetDataHandle(output.GetId());
1882 TF_DEV_AXIOM(_dataManager->IsValidDataHandle(dataHandle));
1886 _dataManager->GetPrivateBufferData(dataHandle);
1896 "No value set for output " + output.GetDebugName() +
1897 " of type " + output.GetSpec().GetType().GetTypeName() +
1898 " named " + output.GetName().GetString());
1902 output.GetSpec().GetType(),
1904 _dataManager->GetOrCreateOutputValueForWriting(
1905 output, dataHandle));
1911 const bool hasAssociatedInput = output.GetAssociatedInput();
1913 scheduleTask, state, output, outputId, dataHandle,
1914 hasAssociatedInput, privateBuffer);
1918template <
typename Derived,
typename DataManager >
1925 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_PassThroughNode");
1930 VDF_FOR_EACH_SCHEDULED_OUTPUT_ID(outputId, schedule, node) {
1934 _DataHandle dataHandle = _dataManager->GetDataHandle(output.GetId());
1935 TF_DEV_AXIOM(_dataManager->IsValidDataHandle(dataHandle));
1939 _dataManager->GetPrivateBufferData(dataHandle);
1944 const bool hasAssociatedInput = output.GetAssociatedInput();
1946 scheduleTask, state, output, outputId, dataHandle,
1947 hasAssociatedInput, privateBuffer);
1951template <
typename Derived,
typename DataManager >
1958 const _DataHandle dataHandle,
1959 const bool hasAssociatedInput,
1965 const bool hasMultipleInvocations =
1974 _Self()._FinalizeOutput(
1975 state, output, outputId, dataHandle, invocationIndex, passToOutput);
1980 if (!passToOutput) {
1993 _AbsorbPublicBuffer(output, dataHandle, privateMask);
1998 if (scratchBuffer) {
1999 scratchBuffer->
Merge(
2001 _dataManager->PublishScratchBufferData(dataHandle);
2007 _dataManager->PublishPrivateBufferData(dataHandle);
2015 _dataManager->GetScratchBufferData(dataHandle);
2022 const VdfMask &keepMask = hasMultipleInvocations
2033 "VdfParallelExecutorEngineBase::_FinalizeOutput (keep)");
2034 scratchValue->Merge(
2041 if (!hasMultipleInvocations) {
2042 _AbsorbPublicBuffer(
2044 _dataManager->PublishScratchBufferData(dataHandle);
2050template <
typename Derived,
typename DataManager >
2063 _PassOrCopyBuffer(output, *source, mask, schedule, privateBuffer);
2064 _InvokeTouchTask(output, *source);
2071 const VdfInput *input = output.GetAssociatedInput();
2075 if (numInputNodes == 1 && !(*input)[0].GetMask().IsAllZeros()) {
2076 const VdfOutput &source = (*input)[0].GetSourceOutput();
2077 _PassOrCopyBuffer(output, source, mask, schedule, privateBuffer);
2083 _dataManager->CreateOutputCache(output, privateBuffer);
2086template <
typename Derived,
typename DataManager >
2096 bool passBuffer =
false;
2099 _DataHandle sourceHandle = _dataManager->GetDataHandle(source.
GetId());
2100 if (_dataManager->IsValidDataHandle(sourceHandle)) {
2113 passBuffer = !_IsInputDependencyCached(
2114 uniqueIndex, source, inputMask);
2121 _dataManager->GetPrivateBufferData(sourceHandle);
2122 _PassBuffer(sourcePrivateBuffer, privateBuffer);
2127 _CopyBuffer(output, source, inputMask, privateBuffer);
2131template <
typename Derived,
typename DataManager >
2146template <
typename Derived,
typename DataManager >
2154 PEE_TRACE_SCOPE(
"VdfParallelExecutorEngineBase::_CopyBuffer");
2161 if (!sourceVector) {
2168 VdfVector *destValue = _dataManager->CreateOutputCache(output, toBuffer);
2169 destValue->
Copy(*sourceVector, fromMask);
2172template <
typename Derived,
typename DataManager >
2179 VDF_FOR_EACH_SCHEDULED_OUTPUT_ID(outputId, schedule, node) {
2183 const _DataHandle dataHandle =
2184 _dataManager->GetDataHandle(output.GetId());
2185 TF_DEV_AXIOM(_dataManager->IsValidDataHandle(dataHandle));
2189 _dataManager->GetScratchBufferData(dataHandle);
2194 _AbsorbPublicBuffer(
2196 _dataManager->PublishScratchBufferData(dataHandle);
2201template <
typename Derived,
typename DataManager >
2205 const _DataHandle dataHandle,
2210 _dataManager->GetPublicBufferData(dataHandle);
2216 if (!publicValue || publicMask.
IsEmpty() || publicMask == haveMask) {
2229 _dataManager->GetScratchBufferData(dataHandle);
2234 const VdfMask extendedMask = publicMask | haveMask;
2235 if (!scratchValue) {
2236 scratchValue = _dataManager->CreateOutputCache(
2237 output, scratchBuffer, extendedMask.
GetBits());
2243 scratchValue->
Merge(*publicValue, mergeBits);
2245 return scratchValue;
2248template <
typename Derived,
typename DataManager >
2257 const bool hasCycle = _Self()._DetectCycle(state, node);
2265 _isInterrupted.store(
true, std::memory_order_relaxed);
2270 return _HasDetectedInterruption();
2273template <
typename Derived,
typename DataManager >
2278 return _isInterrupted.load(std::memory_order_relaxed);
2281template <
typename Derived,
typename DataManager >
2287 _errors.grow_by(1)->swap(transport);
2290template <
typename Derived,
typename DataManager >
2294 if (_errors.empty()) {
2300 errorTransport.Post();
2307PXR_NAMESPACE_CLOSE_SCOPE