其他分享
首页 > 其他分享 > OneFlow: 启动 Runtime

OneFlow: 启动 Runtime

作者:互联网

前言

前面的文章介绍了从 Op 到 Job、再从 Job 到 Plan 的编译过程,这篇文章分析运行时(Runtime)的启动,以及 Actor 是如何被构造和启动的。运行时在启动 Session 时启动:Job 被编译成物理上可执行的 Plan 之后,就可以按照 Plan 启动运行时并启动 Actor。

流程回顾

运行时 Runtime 是什么时候启动的呢?是在 Python 调用 StartLazyGlobalSession 的时候:这个方法会初始化全局的 OneFlow 对象,将 JobSet 编译成 Plan,再用这个 Plan 启动 Runtime。

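Runtime 的初始化流程如下。我们知道 Plan 是物理上可以执行的计算图,Plan 中的节点 TaskProto 对应计算图上的节点,一个 Task 对应一个 Actor。Runtime 启动时,先把 Plan 注册到各个全局管理器(RegstMgr、ThreadMgr 等)。然后只挑出属于本机(本 rank)的 Task,分成源 Task(不消费任何非控制 Regst)和其他 Task 两组,调用 HandoutTasks 分发出去并构造 Actor。等本机所有 Actor 构造完毕、并经过一次所有进程间的 barrier 之后,再向源 Actor 发送 kStart 命令,整个计算图开始运转。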

// oneflow/core/job/runtime.cpp
// Starts the runtime for `plan` in this process:
//   1. registers the plan with every runtime-wide global manager (regsts, threads, job descs,
//      collective boxing executor);
//   2. collects the tasks placed on this rank, splits them into source tasks (no consumed
//      non-ctrl regst) and the rest, hands both groups to their threads and waits until every
//      local actor has been constructed;
//   3. after a barrier across all processes, creates one running-actor counter per job (seeded
//      with that job's local actor count) and sends kStart to the source actors.
Runtime::Runtime(const Plan& plan, const HashMap<std::string, Blob*>& variable_op_name2eager_blob) {
  {
    // NOTE(chengcheng): All runtime Global objects AddPlan
    Global<RegstMgr>::Get()->AddPlan(plan, variable_op_name2eager_blob);
    Global<ThreadMgr>::Get()->AddPlan(plan);
    Global<RuntimeJobDescs>::Get()->AddPlan(plan);
    collective_boxing_executor_plan_token_ =
        Global<boxing::collective::CollectiveBoxingExecutor>::Get()->AddPlan(plan);
  }
  std::vector<const TaskProto*> source_tasks;
  std::vector<const TaskProto*> other_tasks;
  int64_t this_machine_task_num = 0;
  for (const TaskProto& task : plan.task()) {
    // Only tasks placed on this rank become actors in this process.
    if (task.machine_id() != GlobalProcessCtx::Rank()) { continue; }
    // Source tasks are the only actors that receive an explicit kStart command below.
    if (!HasNonCtrlConsumedRegstDescId(task)) {
      source_tasks.push_back(&task);
    } else {
      other_tasks.push_back(&task);
    }
    // Count local actors per job; the count seeds that job's running-actor counter.
    auto it = job_id2actor_size_.find(task.job_id());
    if (it == job_id2actor_size_.end()) {
      auto emplace_ret_pair = job_id2actor_size_.emplace(task.job_id(), 0);
      CHECK(emplace_ret_pair.second);
      it = emplace_ret_pair.first;
    }
    it->second++;
    this_machine_task_num++;
  }
  RuntimeCtx* runtime_ctx = Global<RuntimeCtx>::Get();
  // Each Thread::ConstructActor call decrements this counter once.
  runtime_ctx->NewCounter("constructing_actor_cnt", this_machine_task_num);
  HandoutTasks(source_tasks);
  HandoutTasks(other_tasks);
  runtime_ctx->WaitUntilCntEqualZero("constructing_actor_cnt");
  LOG(INFO) << "Actors on this machine constructed";
  // Wait until the actors of every process exist before starting any of them.
  OF_SESSION_BARRIER();
  LOG(INFO) << "Actors on every machine constructed";
  // Iterate by const reference: the map entries only need to be read, not copied.
  for (const auto& pair : job_id2actor_size_) {
    runtime_ctx->NewCounter(GetRunningActorCountKeyByJobId(pair.first), pair.second);
  }
  SendCmdMsg(source_tasks, ActorCmd::kStart);
}

HandoutTasks 接受 Task 数组作为参数。它先把每个 Task 添加到它所属的 Thread 中,再通过基于消息机制的 ActorMsgBus 向每个 Task 发送 kConstructActor 命令,由对应的线程收到消息后构造 Actor。

// oneflow/core/job/runtime.cpp: 36
// Sends the command `cmd` to the actor of every task in `tasks` (the actor id is the task id),
// in the order the tasks appear, through the global ActorMsgBus.
void SendCmdMsg(const std::vector<const TaskProto*>& tasks, ActorCmd cmd) {
  ActorMsgBus* const msg_bus = Global<ActorMsgBus>::Get();
  for (const TaskProto* task : tasks) {
    msg_bus->SendMsg(ActorMsg::BuildCommandMsg(task->task_id(), cmd));
  }
}

// Registers every task with the Thread that will host it, then asks each of those actors to be
// constructed by sending it a kConstructActor command.
void HandoutTasks(const std::vector<const TaskProto*>& tasks) {
  ThreadMgr* const thread_mgr = Global<ThreadMgr>::Get();
  for (const TaskProto* task : tasks) {
    Thread* const host_thread = thread_mgr->GetThrd(task->thrd_id());
    host_thread->AddTask(*task);
  }
  SendCmdMsg(tasks, ActorCmd::kConstructActor);
}

ThreadMgr

下面是 ThreadMgr 的头文件,它提供了两个成员方法(AddPlan 和 GetThrd)以及两个自由函数(SingleThreadLoop 和 MultiThreadLoop)。前面启动 Runtime 时,会把包含全局信息的 Plan 加入到 ThreadMgr 中,ThreadMgr 据此为本机用到的每个 Stream 创建一个线程。

// oneflow/core/thread/thread_manager.h
namespace oneflow {

class Plan;

// Owns the actor threads of this process, keyed by thread id (the serialized StreamId of the
// stream the thread serves; see ThreadMgr::AddPlan).
class ThreadMgr final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ThreadMgr);
  ThreadMgr() = default;
  // Stops and destroys every owned thread.
  ~ThreadMgr();

  // Creates a Thread for each stream used by `plan` on this rank that does not have one yet.
  void AddPlan(const Plan& plan);
  // Returns the (non-owning) Thread for `thrd_id`; CHECK-fails if no such thread exists.
  Thread* GetThrd(int64_t thrd_id);

 private:
  friend class Global<ThreadMgr>;

  // thread id -> owned Thread.
  HashMap<int64_t, std::unique_ptr<Thread>> threads_;
};

// Calls Callback(0), ..., Callback(num - 1) sequentially on the calling thread.
void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback);
// Calls Callback(i) for every i in [0, num) on the global ThreadPool; blocks until all are done.
void MultiThreadLoop(size_t num, std::function<void(size_t i)> Callback);

// Registers `creator` (a callable taking `const StreamId&` and returning `Thread*`) as the
// Thread factory for `device`; ThreadMgr::AddPlan looks it up via
// NewObj<int, Thread, const StreamId&>(device_type, stream_id).
#define REGISTER_DEVICE_THREAD_CREATOR_WITH_STREAM_ID(device, creator) \
  REGISTER_CLASS_CREATOR(int, device, Thread, creator, const StreamId&)

}  // namespace oneflow

// oneflow/core/thread/thread_manager.cpp: 28
namespace oneflow {

// Shuts down every actor thread, one after another. For each one it sends a kStopThread command
// directly into the thread's channel; PollMsgChannel leaves its loop on that command. It then
// destroys the Thread, whose destructor joins the OS thread.
ThreadMgr::~ThreadMgr() {
  for (auto& thread_pair : threads_) {
    // Actor id -1: kStopThread is consumed by the thread itself, never delivered to an actor.
    ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);
    thread_pair.second->GetMsgChannelPtr()->Send(msg);
    thread_pair.second.reset();
    LOG(INFO) << "actor thread " << thread_pair.first << " finish";
  }
}

// Looks up the Thread registered under `thrd_id`. The returned pointer is non-owning; a missing
// id is a fatal error.
Thread* ThreadMgr::GetThrd(int64_t thrd_id) {
  const auto found = threads_.find(thrd_id);
  CHECK(found != threads_.end()) << "thread " << thrd_id << " not found";
  return found->second.get();
}

// Creates one Thread for every distinct stream that some task of `plan` places on this rank.
// Streams that already have a Thread are left untouched. NewObj picks the concrete Thread
// subclass from the stream's device type; the thread id is the serialized StreamId.
void ThreadMgr::AddPlan(const Plan& plan) {
  const int64_t this_rank = GlobalProcessCtx::Rank();
  for (const TaskProto& task : plan.task()) {
    TaskId task_id = DeserializeTaskIdFromInt64(task.task_id());
    StreamId stream_id = task_id.stream_id();
    const bool on_this_rank = stream_id.device_id().rank() == this_rank;
    if (!on_this_rank) { continue; }
    const int64_t thrd_id = SerializeStreamIdToInt64(stream_id);
    const bool already_created = threads_.count(thrd_id) > 0;
    if (already_created) { continue; }
    Thread* created =
        NewObj<int, Thread, const StreamId&>(stream_id.device_id().device_type(), stream_id);
    CHECK_NOTNULL(created);
    threads_.emplace(thrd_id, std::unique_ptr<Thread>(created));
  }
}

// Sequential counterpart of MultiThreadLoop: invokes Callback(0), ..., Callback(num - 1) in
// increasing order on the calling thread.
void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback) {
  for (size_t idx = 0; idx < num; ++idx) { Callback(idx); }
}

// Runs Callback(i) for every i in [0, num) on the global ThreadPool and blocks until all calls
// have returned. BalancedSplitter cuts the index range into min(num, pool thread count) chunks.
// Each chunk becomes one ThreadPool work item that calls Callback sequentially over its range.
void MultiThreadLoop(size_t num, std::function<void(size_t i)> Callback) {
  // Nothing to run. Returning early also avoids asking BalancedSplitter for a zero-way split
  // (thread_num below would be min(0, pool size) == 0).
  if (num == 0) { return; }
  size_t thread_num = Global<ThreadPool>::Get()->thread_num();
  thread_num = std::min(num, thread_num);
  BalancedSplitter bs(num, thread_num);
  BlockingCounter bc(thread_num);
  FOR_RANGE(size_t, range_id, 0, thread_num) {
    // Capturing bs and Callback by reference is safe: this function does not return before bc
    // reaches zero, i.e. before every work item is done with them. It also avoids copying the
    // std::function once per work item.
    Global<ThreadPool>::Get()->AddWork([&bc, &bs, &Callback, range_id] {
      FOR_RANGE(size_t, i, bs.At(range_id).begin(), bs.At(range_id).end()) { Callback(i); }
      bc.Decrease();
    });
  }
  bc.WaitUntilCntEqualZero();
}

}  // namespace oneflow

Thread

接下来考察 Thread 这个类。从接口来看,它允许添加 Task、给 Actor 发送消息。从成员变量来看,它需要存储若干映射(task id 到 TaskProto、actor id 到 Actor 对象、actor id 到 job id),以及线程对象和 mutex、当前线程 id。此外还有两个开关:是否使用本地消息队列、是否开启 light actor。

// oneflow/core/thread/thread.h
namespace oneflow {

// Base class of an actor thread. Each Thread owns one std::thread (actor_thread_). Concrete
// subclasses (e.g. CpuThread) start it in their constructors and run PollMsgChannel() on it.
// Messages for the actors hosted here arrive through msg_channel_ or, when enabled and posted
// from the actor thread itself, through local_msg_queue_.
class Thread {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Thread);
  virtual ~Thread();

  // Registers a task so that a later kConstructActor message addressed to task.task_id() can
  // build its actor. Guarded by id2task_mtx_.
  void AddTask(const TaskProto&);

  Channel<ActorMsg>* GetMsgChannelPtr() { return &msg_channel_; }

  // Delivers `msg` to this thread. If the local queue is enabled and the caller already runs on
  // this Thread's actor thread, the message is pushed straight onto local_msg_queue_; otherwise
  // it goes through msg_channel_.
  inline void EnqueueActorMsg(const ActorMsg& msg) {
    if (UseLocalMsgQueue()) {
      local_msg_queue_.push(msg);
    } else {
      msg_channel_.Send(msg);
    }
  }

  // Range version of EnqueueActorMsg; all messages take the same path.
  template<typename InputIt>
  inline void EnqueueActorMsg(InputIt first, InputIt last) {
    if (UseLocalMsgQueue()) {
      for (auto it = first; it != last; ++it) { local_msg_queue_.push(*it); }
    } else {
      for (auto it = first; it != last; ++it) { msg_channel_.Send(*it); }
    }
  }

  // Blocks until the actor thread exits.
  void JoinAllActor() { actor_thread_.join(); }

 protected:
  Thread();
  std::thread& mut_actor_thread() { return actor_thread_; }
  // Message loop executed on actor_thread_; returns after kStopThread.
  void PollMsgChannel(const ThreadCtx& thread_ctx);
  void set_thrd_id(int64_t val) { thrd_id_ = val; }

 private:
  void ConstructActor(int64_t actor_id, const ThreadCtx& thread_ctx);

  // True only when the local queue is enabled AND the caller is this Thread's actor thread,
  // the only thread that reads local_msg_queue_.
  inline bool UseLocalMsgQueue() const {
    return local_msg_queue_enabled_ && std::this_thread::get_id() == actor_thread_.get_id();
  }

  // Tasks added but whose actors are not constructed yet; shared with AddTask callers.
  HashMap<int64_t, TaskProto> id2task_;
  std::mutex id2task_mtx_;

  std::thread actor_thread_;
  Channel<ActorMsg> msg_channel_;
  // Live actors hosted on this thread, and the job each of them belongs to.
  HashMap<int64_t, std::unique_ptr<ActorBase>> id2actor_ptr_;
  HashMap<int64_t, int64_t> id2job_id_;
  std::queue<ActorMsg> local_msg_queue_;
  // Default member initializers: thrd_id_ is only assigned by subclasses via set_thrd_id() but
  // is read by log statements, so it must never be left indeterminate.
  bool local_msg_queue_enabled_ = false;
  int64_t thrd_id_ = -1;
  bool light_actor_enabled_ = false;
};

}  // namespace oneflow

Thread 的方法是如何实现的呢?

// oneflow/core/thread/thread.cpp
namespace oneflow {

// Reads the two feature switches from the environment (both default to false):
//   ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE -> local_msg_queue_enabled_
//   ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR          -> light_actor_enabled_
Thread::Thread()
    : local_msg_queue_enabled_(
        ParseBooleanFromEnv("ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE", false)),
      light_actor_enabled_(ParseBooleanFromEnv("ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR", false)) {}

// Waits for the actor thread to exit; it leaves PollMsgChannel on kStopThread. Then checks that
// every added task was turned into an actor (ConstructActor erases it from id2task_) and closes
// the message channel.
Thread::~Thread() {
  actor_thread_.join();
  CHECK(id2task_.empty());
  msg_channel_.Close();
}

// Stores a copy of `task`, keyed by its task id, until ConstructActor consumes it. Adding the
// same task id twice is a fatal error.
void Thread::AddTask(const TaskProto& task) {
  std::lock_guard<std::mutex> guard(id2task_mtx_);
  const bool inserted = id2task_.emplace(task.task_id(), task).second;
  CHECK(inserted);
}

// Message loop of the actor thread. It pops messages from local_msg_queue_ and refills that
// queue from msg_channel_ (ReceiveMany) whenever it is drained. Messages are handled as follows:
//   - kStopThread: leaves the loop; legal only once every actor here has finished;
//   - kConstructActor: builds the actor for msg.dst_actor_id();
//   - anything else (including kStart) goes to the destination actor's ProcessMsg(). An actor
//     that reports completion (return value 1) is destroyed, and its job's running-actor
//     counter is decremented.
void Thread::PollMsgChannel(const ThreadCtx& thread_ctx) {
  while (true) {
    if (local_msg_queue_.empty()) {
      CHECK_EQ(msg_channel_.ReceiveMany(&local_msg_queue_), kChannelStatusSuccess);
    }
    ActorMsg msg = std::move(local_msg_queue_.front());
    local_msg_queue_.pop();
    if (msg.msg_type() == ActorMsgType::kCmdMsg) {
      if (msg.actor_cmd() == ActorCmd::kStopThread) {
        CHECK(id2actor_ptr_.empty());
        break;
      } else if (msg.actor_cmd() == ActorCmd::kConstructActor) {
        ConstructActor(msg.dst_actor_id(), thread_ctx);
        continue;
      } else {
        // do nothing: other commands (e.g. kStart) are delivered to the actor below
      }
    }
    int64_t actor_id = msg.dst_actor_id();
    auto actor_it = id2actor_ptr_.find(actor_id);
    CHECK(actor_it != id2actor_ptr_.end());
    int process_msg_ret = actor_it->second->ProcessMsg(msg);
    if (process_msg_ret == 1) {
      LOG(INFO) << "thread " << thrd_id_ << " deconstruct actor " << actor_id;
      auto job_id_it = id2job_id_.find(actor_id);
      // Never dereference end(): every constructed actor must have a recorded job id.
      CHECK(job_id_it != id2job_id_.end())
          << "thread " << thrd_id_ << " has no job id for actor " << actor_id;
      const int64_t job_id = job_id_it->second;
      id2job_id_.erase(job_id_it);
      id2actor_ptr_.erase(actor_it);
      Global<RuntimeCtx>::Get()->DecreaseCounter(GetRunningActorCountKeyByJobId(job_id));
    } else {
      CHECK_EQ(process_msg_ret, 0);
    }
  }
}

// Builds the actor for the task previously registered with AddTask under `actor_id` (actor id ==
// task id). When light actors are enabled, a LightActor is tried first, falling back to
// NewActor(). It records the actor and its job id, drops the pending task and decrements the
// runtime's "constructing_actor_cnt" counter. Holds id2task_mtx_ for the whole construction.
void Thread::ConstructActor(int64_t actor_id, const ThreadCtx& thread_ctx) {
  std::unique_lock<std::mutex> lck(id2task_mtx_);
  auto task_it = id2task_.find(actor_id);
  // A construct command for an unknown task must fail loudly instead of dereferencing end().
  CHECK(task_it != id2task_.end()) << "thread " << thrd_id_ << " has no task " << actor_id;
  const TaskProto& task = task_it->second;
  std::unique_ptr<ActorBase> actor_ptr;
  if (light_actor_enabled_) { actor_ptr = TryNewLightActor(task, thread_ctx); }
  if (!actor_ptr) {
    actor_ptr = NewActor(task, thread_ctx);
    LOG(INFO) << "Thread " << thrd_id_ << " construct Actor " << TaskType_Name(task.task_type())
              << " " << actor_id;
  } else {
    LOG(INFO) << "Thread " << thrd_id_ << " construct LightActor "
              << TaskType_Name(task.task_type()) << " " << actor_id;
  }
  CHECK(id2actor_ptr_.emplace(actor_id, std::move(actor_ptr)).second);
  CHECK(id2job_id_.emplace(actor_id, task.job_id()).second);
  // `task` refers into id2task_; it must not be used after this erase.
  id2task_.erase(task_it);
  Global<RuntimeCtx>::Get()->DecreaseCounter("constructing_actor_cnt");
}

}  // namespace oneflow

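搜索代码,看看哪些地方调用了 PollMsgChannel:主要是 CpuThread 和 GpuThread 的构造函数。

CpuThread 和 GpuThread 两者的结构类似(下面以 CpuThread 为例):构造函数用 std::thread 启动一个运行 PollMsgChannel 的线程,该线程不断从消息队列中拉取消息并处理。那么 CpuThread 和 GpuThread 又是在哪里创建的呢?答案是 ThreadMgr 的 AddPlan。它按 Stream 的设备类型调用 NewObj,NewObj 会找到通过 REGISTER_DEVICE_THREAD_CREATOR_WITH_STREAM_ID 注册的创建函数。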

// oneflow/core/thread/cpu_thread.cpp
namespace oneflow {

// Starts the CPU actor thread with id `thrd_id`. The new OS thread names itself for the
// profiler and builds a ThreadCtx with no CUDA callback-event channel. It then serves messages
// in PollMsgChannel until it receives kStopThread.
CpuThread::CpuThread(int64_t thrd_id) {
  set_thrd_id(thrd_id);
  auto actor_loop = [this, thrd_id]() {
    OF_PROFILER_NAME_THIS_HOST_THREAD("CPU Actor : (" + std::to_string(thrd_id) + ")");
    ThreadCtx thread_ctx;
#ifdef WITH_CUDA
    thread_ctx.cb_event_chan = nullptr;
#endif  // WITH_CUDA
    PollMsgChannel(thread_ctx);
  };
  mut_actor_thread() = std::thread(std::move(actor_loop));
}

// Registers CpuThread as the Thread implementation for DeviceType::kCPU streams. The thread id
// handed to CpuThread is the serialized StreamId, the same key ThreadMgr uses.
REGISTER_DEVICE_THREAD_CREATOR_WITH_STREAM_ID(DeviceType::kCPU,
                                              ([](const StreamId& stream_id) -> Thread* {
                                                return new CpuThread(
                                                    SerializeStreamIdToInt64(stream_id));
                                              }));

}  // namespace oneflow

Actor

前面分析了线程是如何产生的,线程运行的核心是 Actor。一个线程上有多个 Actor,线程通过轮询消息队列,然后将消息发送给不同的 Actor 来执行。真正干活的 Actor 是如何构造,如何执行的呢?

Actor 的构造很简单,通过 TaskProto 上面的类型,去选择一个对应的 Actor 进行初始化。

// oneflow/core/actor/actor_base.cpp
std::unique_ptr<ActorBase> NewActor(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
  ActorBase* rptr = NewObj<int32_t, ActorBase>(task_proto.task_type());
  const auto& job_descs = *Global<RuntimeJobDescs>::Get();
  rptr->Init(&job_descs.job_desc(task_proto.job_id()), task_proto, thread_ctx);
  return std::unique_ptr<ActorBase>(rptr);
}

Actor 的执行通过 ProcessMsg 方法来进行。前面我们已经看到,线程会轮询消息队列拉取消息,然后把消息交给对应的 Actor 处理。下面的分析可能有点零碎,核心是抓住一条主线:Actor 从拿到消息到启动 Kernel 的过程。

// Fragments of the actor class declarations, shown here out of their class context.
// ProcessMsg forwards every message to the currently installed member-function handler.
// Return value:
// 1: success, and actor finish
// 0: success, and actor not finish
int ProcessMsg(const ActorMsg& msg) override { return (this->*msg_handler_)(msg); }
  // Msg Handler
  void set_msg_handler(MsgHandler val) { msg_handler_ = val; }
// Installs a new message handler and logs the switch. The do { } while (0) wrapper makes the
// macro behave like a single statement.
#define OF_SET_MSG_HANDLER(val)                                   \
  do {                                                            \
    LOG(INFO) << "actor " << actor_id() << " switch to " << #val; \
    set_msg_handler(static_cast<MsgHandler>(val));                \
  } while (0)
// NaiveActor installs HandlerNormal (inherited from Actor) as its initial message handler.
void NaiveActor::VirtualActorInit(const TaskProto&) {
  OF_SET_MSG_HANDLER(&NaiveActor::HandlerNormal);
}
// oneflow/core/actor/actor.cpp: 258
// Default message handler of an Actor. It accepts three message types:
//   - kEordMsg: an upstream regst desc has reached end-of-read-data (EORD);
//   - kRegstMsg: carries a regst, either one this actor consumes or (presumably) one of its own
//     produced regsts handed back by a consumer;
//   - kCmdMsg: only kStart is accepted.
// After a regst or start message the actor acts as long as it can (ActUntilFail). It then
// checks whether its inputs are exhausted. If so, it returns customized readable regsts and
// sends EORD for all produced regst descs. It finishes (returns 1, handler cleared) when no
// EORD and no reader (total_reading_cnt_) is outstanding; otherwise it switches to
// HandlerZombie and returns 0.
int Actor::HandlerNormal(const ActorMsg& msg) {
  if (msg.msg_type() == ActorMsgType::kEordMsg) {
    // One fewer EORD still expected; each regst desc may send its EORD only once.
    remaining_eord_cnt_ -= 1;
    CHECK(eord_regst_desc_ids_.insert(msg.eord_regst_desc_id()).second);
    if (naive_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) {
      is_naive_consumed_eord_ = true;
    } else if (inplace_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) {
      is_inplace_consumed_eord_ = true;
    } else {
      NormalProcessCustomizedEordMsg(msg);
    }
  } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
    // Regst message sent from within this process.
    if (msg.SrcMachineId() == GlobalProcessCtx::Rank()) {
      Regst* regst = msg.regst();
      if (naive_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {
        // Naively consumed input: queue it; data regsts also get the subclass hook.
        CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst));
        const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id());
        CHECK(rdeq.empty() == false);
        if (rdeq.front()->regst_desc()->regst_desc_type().has_data_regst_desc()) {
          NormalProcessNaiveReadableDataRegstMsg(rdeq);
        }
      } else if (inplace_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {
        // Inplace input: its sole blob must share memory with the paired produced regst.
        CHECK_EQ(0, inplace_consumed_rs_.TryPushBackRegst(regst));
        int64_t out_regst_desc_id = inplace_regst_desc_id_in2out_.at(regst->regst_desc_id());
        CHECK(regst->GetSoleBlob()->dptr()
              == inplace_produced_rs_.Front(out_regst_desc_id)->GetSoleBlob()->dptr());
      } else if (TryUpdtStateAsProducedRegst(regst) == 0) {
        // do nothing
      } else {
        NormalProcessCustomizedReadableRegstMsg(msg);
      }
    } else {
      // Regst message from another rank: the subclass gets the first chance to handle it.
      if (NormalTryProcessReadableMsgFromOtherMachine(msg) == false) {
        // process ctrl msg from other rank
        if (IsConsumedCtrlRegstDescId(msg.regst_desc_id())) {
          Regst* regst = msg.regst();
          CHECK(naive_consumed_rs_.HasRegstDescId(msg.regst_desc_id()));
          CHECK(Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(msg.regst_desc_id()));
          CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst, msg.regst_desc_id()));
          const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(msg.regst_desc_id());
          CHECK(rdeq.empty() == false);
        } else {
          CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst()), 0);
        }
      }
    }
    ActUntilFail();
  } else if (msg.msg_type() == ActorMsgType::kCmdMsg) {
    CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);
    ActUntilFail();
  } else {
    UNIMPLEMENTED();
  }
  // handler halts
  // Inputs are exhausted when either:
  //   - the actor has naive/inplace inputs, one of them saw EORD, and none has a regst left; or
  //   - it has no naive/inplace inputs and the subclass says its reads can never be ready again.
  bool has_naive_or_inplace = naive_consumed_rs_.total_regst_desc_cnt() != 0
                              || inplace_consumed_rs_.total_regst_desc_cnt() != 0;
  bool naive_or_inplace_eord_and_empty =
      (is_naive_consumed_eord_ || is_inplace_consumed_eord_)
      && (naive_consumed_rs_.available_regst_desc_cnt() == 0
          && inplace_consumed_rs_.available_regst_desc_cnt() == 0);
  bool customized_eord = IsCustomizedReadAlwaysUnReadyFromNow();
  if ((has_naive_or_inplace && naive_or_inplace_eord_and_empty)
      || (!has_naive_or_inplace && customized_eord)) {
    CHECK_EQ(naive_consumed_rs_.available_regst_desc_cnt(), 0);
    AsyncReturnAllCustomizedReadableRegst();
    AsyncSendEORDMsgForAllProducedRegstDesc();
    if (remaining_eord_cnt_ == 0 && total_reading_cnt_ == 0) {
      OF_SET_MSG_HANDLER(nullptr);
      return 1;
    } else {
      OF_SET_MSG_HANDLER(&Actor::HandlerZombie);
      return 0;
    }
  }
  return 0;
}
// oneflow/core/actor/actor.cpp
// Repeatedly runs Act() while both IsReadReady() and IsWriteReady() hold. After each Act() it:
//   - hands produced regsts (customized, naive, inplace) to their consumers;
//   - returns consumed regsts to their producers (inplace consumed regsts via
//     AsyncRetInplaceConsumedRegstIfNoConsumer, presumably only when unused downstream;
//     TODO confirm);
//   - flushes the queued outgoing messages.
// The order of these calls is significant and must be preserved.
void Actor::ActUntilFail() {
  while (IsReadReady() && IsWriteReady()) {
    Act();

    AsyncSendCustomizedProducedRegstMsgToConsumer();
    AsyncSendNaiveProducedRegstMsgToConsumer();
    AsyncSendInplaceProducedRegstMsgToConsumer();

    AsyncSendCustomizedConsumedRegstMsgToProducer();
    AsyncSendNaiveConsumedRegstMsgToProducer();
    AsyncRetInplaceConsumedRegstIfNoConsumer();

    AsyncSendQueuedMsg();
  }
  // NOTE(liujuncheng): return inplace consumed
  AsyncSendQueuedMsg();
}
void NaiveActor::Act() {
  KernelCtx kernel_ctx = GenDefaultKernelCtx();
  AsyncLaunchKernel(kernel_ctx, [&](int64_t regst_desc_id) -> Regst* { return nullptr; });
}
// oneflow/core/actor/actor.h: 58
// A kernel owned by an actor, plus the information needed to resolve each of the kernel's blob
// names (bn_in_op) to a Blob at launch time (see Actor::AsyncLaunchKernel).
struct ExecKernel {
  std::unique_ptr<const Kernel> kernel;
  HashMap<std::string, BlobInfo> bn_in_op2blob_info;
};

// oneflow/core/actor/actor.cpp: 470
// Launches every kernel of this actor, in exec_kernel_vec_ order. Each kernel resolves its blob
// names through a callback that maps bn_in_op -> BlobInfo -> Regst -> Blob:
//   - unknown names, and names whose regst_desc_id is -1, resolve to nullptr;
//   - the regst is taken from the BlobInfo's RegstSlot when it has one, otherwise from
//     Regst4RegstDescId; a null regst also yields nullptr;
//   - within the regst, the blob is picked by ordinal when ordinal >= 0, else by lbi.
void Actor::AsyncLaunchKernel(const KernelCtx& kernel_ctx,
                              std::function<Regst*(int64_t)> Regst4RegstDescId) {
  for (const ExecKernel& exec_kernel : exec_kernel_vec_) {
    const auto& name2info = exec_kernel.bn_in_op2blob_info;
    auto ResolveBlob = [&](const std::string& bn_in_op) -> Blob* {
      const auto found = name2info.find(bn_in_op);
      if (found == name2info.cend()) { return nullptr; }
      const BlobInfo& blob_info = found->second;
      if (blob_info.regst_desc_id == -1) { return nullptr; }
      Regst* const regst = blob_info.rs == nullptr ? Regst4RegstDescId(blob_info.regst_desc_id)
                                                   : blob_info.rs->Front(blob_info.regst_desc_id);
      if (regst == nullptr) { return nullptr; }
      if (blob_info.ordinal < 0) { return regst->GetBlobByLbi(blob_info.lbi); }
      return regst->GetBlobByOrdinal(blob_info.ordinal);
    };
    exec_kernel.kernel->Launch(kernel_ctx, ResolveBlob);
  }
}
// oneflow/core/kernel/kernel.cpp: 43
// Entry point used by actors to run a kernel. Launching currently means exactly one forward
// pass, with `ctx` and `BnInOp2Blob` passed through unchanged.
void Kernel::Launch(const KernelCtx& ctx,
                    const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
  this->Forward(ctx, BnInOp2Blob);
}

// Runs one forward step of the kernel:
//   1. unless blob access checking is disabled, install the output "producer infer" access
//      checker, then run ForwardHeader;
//   2. return early, skipping the data computation, when not all blobs are static, every output
//      blob is empty and the kernel is stateless;
//   3. otherwise install the output "producer compute" checker and run ForwardDataContent
//      between the profiler's start/end trace hooks. Finally, install the output "consumer"
//      checker (again unless checking is disabled).
void Kernel::Forward(const KernelCtx& ctx,
                     const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
  if (!blob_access_checker_disabled_) { SetOutputBlobProducerInferAccessChecker(BnInOp2Blob); }
  ForwardHeader(ctx, BnInOp2Blob);
  if ((!kernel_conf_.all_blobs_are_static())
      && IsAllBlobEmpty(op_attribute().output_bns(), BnInOp2Blob) && IsStateless()) {
    return;
  }
  if (!blob_access_checker_disabled_) { SetOutputBlobProducerComputeAccessChecker(BnInOp2Blob); }
  OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentStart(this, ctx, BnInOp2Blob));
  ForwardDataContent(ctx, BnInOp2Blob);
  OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentEnd(this, ctx, BnInOp2Blob));
  if (!blob_access_checker_disabled_) { SetOutputBlobConsumerAccessChecker(BnInOp2Blob); }
}
// Computes the data part of a user-op kernel. The KernelCtx argument is not consulted here. The
// work is delegated to ForwardUserKernel together with this kernel's OpKernelState (if any).
void UserKernel::ForwardDataContent(
    const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
  auto* const state = opkernel_state_.get();
  this->ForwardUserKernel(BnInOp2Blob, state);
}

// Executes the wrapped user-op kernel. ctx_->UpdateTensorWithCorrBlob() binds the user-op
// tensors to the blobs returned by BnInOp2Blob and reports whether anything changed
// (presumably blob pointers/shapes; TODO confirm).
// When built WITH_CUDA_GRAPHS and a cuda_graph_ctx_ is present:
//   - if a graph was already captured and nothing changed, the captured graph is replayed and
//     Compute() is skipped;
//   - otherwise, unless a capture is already in progress, Compute() is recorded into a new
//     capture, which is then ended and launched.
void UserKernel::ForwardUserKernel(const std::function<Blob*(const std::string&)>& BnInOp2Blob,
                                   user_op::OpKernelState* opkernel_state) const {
  const bool updated = ctx_->UpdateTensorWithCorrBlob(BnInOp2Blob);

#ifdef WITH_CUDA_GRAPHS
  bool capturing = false;
  if (cuda_graph_ctx_) {
    if (!cuda_graph_ctx_->IsCapturing()) {
      // Fast path: replay the captured graph when the bindings are unchanged.
      if (cuda_graph_ctx_->IsCaptured() && (!updated)) {
        cuda_graph_ctx_->Launch();
        return;
      }
      capturing = true;
      cuda_graph_ctx_->BeginCapture();
    }
  }
#endif  // WITH_CUDA_GRAPHS

  kernel_->Compute(ctx_.get(), opkernel_state);

#ifdef WITH_CUDA_GRAPHS
  // Only the call that began a capture ends it and launches the resulting graph.
  if (cuda_graph_ctx_ && capturing) {
    cuda_graph_ctx_->EndCapture();
    cuda_graph_ctx_->Launch();
  }
#endif  // WITH_CUDA_GRAPHS
}
// oneflow/user/kernels/add_n_kernel.cpp: 22
// Element-wise sum of all input buffers:
//   out[i] = in[0][i] + in[1][i] + ... + in[k-1][i]   for i in [0, n).
// Inputs are accumulated left to right for each element. Nothing is touched when n <= 0.
// Otherwise `in` must be non-empty (in.at(0) throws std::out_of_range), and every buffer must
// hold at least n elements.
template<typename T>
void cpu_add(const int64_t n, T* out, const std::vector<const T*>& in) {
  // `<`, not `!=`: a negative n must not run the loop past the end of the buffers.
  if (n <= 0) { return; }
  const T* const first = in.at(0);
  const size_t in_num = in.size();
  for (int64_t i = 0; i < n; ++i) {
    out[i] = first[i];
    // size_t index matches in.size() (no signed/unsigned comparison).
    for (size_t j = 1; j < in_num; ++j) { out[i] += in[j][i]; }
  }
}

// oneflow/user/kernels/add_n_kernel.cpp: 32
// CPU implementation of the user op "add_n": out = in_0 + in_1 + ... + in_{k-1}, element-wise,
// where k is the number of input tensors of this op instance.
template<typename T>
class CpuAddNKernel : public user_op::OpKernel {
 public:
  CpuAddNKernel() = default;
  ~CpuAddNKernel() = default;

  // Returning false presumably lets the framework skip Compute() when every output is empty.
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    // NOTE(review): assumes every input of this op is an "in" slot, so inputs().size() equals
    // the number of "in" tensors. TODO confirm against the op definition.
    const size_t in_num = ctx->inputs().size();

    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    // The element count comes from `out`; the inputs are assumed (not checked here) to be at
    // least that large.
    const int64_t n = out->shape().elem_cnt();
    T* out_dptr = out->mut_dptr<T>();

    std::vector<const T*> in_dptrs(in_num);
    // Loop with size_t to match in_num; the tensor API takes an int32_t index.
    for (size_t i = 0; i < in_num; ++i) {
      in_dptrs[i] = ctx->Tensor4ArgNameAndIndex("in", static_cast<int32_t>(i))->dptr<T>();
    }

    cpu_add<T>(n, out_dptr, in_dptrs);
  }
};

总结

这篇文章从 Runtime 启动开始,讲了如何启动线程,启动 Actor。线程通过轮询消息队列拉取消息,将消息转发给对应的 Actor 去执行。Actor 将启动 Kernel,Kernel 从 KernelComputeContext 获取输入和输出的信息,最后执行运算。

标签:OneFlow,const,thread,启动,regst,_.,msg,Runtime,id
来源: https://www.cnblogs.com/zzk0/p/15226851.html