其他分享
首页 > 其他分享 > OneFlow: 从 Job 到 Plan

OneFlow: 从 Job 到 Plan

作者:互联网

前言

前面分析了如何从一个个 Op 变到 Job,这篇将分析如何从一个个 Job 变成一个 Plan。

Plan

首先来分析看看我们的目标是什么?我们的目标就是一个物理上可以执行的 Plan。OneFlow 在计算上的设计采用了 Actor 机制,计算图上的每个节点由一个 Actor 完成执行。那么 Plan 是如何为 Actor 机制做抽象的呢?我觉得 Actor 由计算和存储组成,计算需要考虑算子和 kernel,存储位置需要看 Regst。因此,作为 Actor 机制的上层抽象的 Plan 需要如何抽象呢?

我们来看看 Plan 这个数据结构。

// Plan: the top-level message of this snippet. It bundles the task list with
// memory block/chunk, job config, collective boxing and ctrl-regst information.
message Plan {
  // Zero or more TaskProto entries, one per task in the plan.
  repeated TaskProto task = 1;
  required MemBlockAndChunkList block_chunk_list = 2;
  // Note: field number 3 does not appear in this definition.
  required JobConfs job_confs = 4;
  required CollectiveBoxingPlan collective_boxing_plan= 5;
  required CtrlRegstDescInfo ctrl_regst_desc_info = 6;
  // int64 key -> OpAttributeRefTable; the field name indicates the key is a job id.
  map<int64, OpAttributeRefTable> job_id2op_attribute_ref_table = 7;
}

// JobConfs: wrapper message around a map of JobConfigProto values.
message JobConfs {
  // int64 key -> JobConfigProto; the field name indicates the key is a job id.
  map<int64, JobConfigProto> job_id2job_conf = 1;
}

TaskProto

前面我们看到 TaskProto 是 Plan 中一个可以重复的属性,就好像计算图中的计算节点一样。我们需要把关注的焦点放到计算和存储上。Actor 和 Actor 之间的关联不是显式地声明出来的,而是隐藏在 Actor 的机制中:一个 Actor 的输入和输出都准备好之后,就执行计算。因此,Actor 之间的关联并不需要显式声明,通过输入和输出关联起来即可。

// TaskProto: description of a single task. Fields 1-10 are grouped under
// "common"; field 1000 is marked as specific to compute tasks.
message TaskProto {
  // common
  required TaskType task_type = 1;
  // Integer identifiers; per the field names: machine, thread, task and job.
  required int64 machine_id = 2;
  required int64 thrd_id = 3;
  required int64 task_id = 4;
  required int64 job_id = 5;
  required TaskSetInfo task_set_info = 6;
  // The ExecSequence (a list of ExecNodeProto) attached to this task.
  required ExecSequence exec_sequence = 7;
  // string key -> RegstDescProto (full descriptor stored inline).
  map<string, RegstDescProto> produced_regst_desc = 8;
  // string key -> RegstDescIdSet (only ids are stored, not descriptors).
  map<string, RegstDescIdSet> consumed_regst_desc_id = 9;
  // Optional boolean hint; false when absent.
  optional bool all_register_num_eq_one_hint = 10 [default = false];
  // compute task
  optional ParallelContext parallel_ctx = 1000; // CompTask
};
// ExecNodeProto: one entry of an ExecSequence; pairs a kernel configuration
// with a string -> int64 lookup table.
message ExecNodeProto {
  required KernelConf kernel_conf = 1;
  // string key -> int64; per the field name: blob name in op -> regst desc id.
  map<string, int64> bn_in_op2regst_desc_id = 2;
}

// ExecSequence: a list of ExecNodeProto entries.
message ExecSequence {
  repeated ExecNodeProto exec_node = 1;
}

有个地方值得关注:在 OperatorConf 里面,有一个 op_type 属性,这个属性是一个 oneof,其中有一个是 UserOpConf,这个就是用户定义算子的配置:名字、输入、输出、属性。

// UserOpConf: configuration of a user-defined op: its type name plus
// input, output and attribute maps.
message UserOpConf {
  // Helper message holding a list of strings; used as the map value below.
  message ListString {
    repeated string s = 1;
  }
  required string op_type_name = 1;
  // string key -> list of strings, for inputs.
  map<string, ListString> input = 2;
  // string key -> list of strings, for outputs.
  map<string, ListString> output = 3;
  // string key -> AttrValue, for attributes.
  map<string, AttrValue> attr = 4;
}

整体流程

上一篇启动 Session 的时候,已经分析过整体流程了,没有深入细节。这里再简单复述一下。

调用流程

CompileJobsAndMergePlans 的主要工作如下所示:

后面主要关注单个 Job 的编译,MainJob 的生成、编译、链接。

编译单个 Job

// oneflow/core/job/oneflow.cpp: 203
// Compiles the current job into `plan`.
//
// When this process is the master it runs the compiler, generates the mem
// block/chunk info for the plan, logs the elapsed compile time and, in debug
// mode, dumps the plan to a "subplan_job_<job id>" log stream. Regardless of
// master status it then generates the collective boxing plan and the register
// hint for `plan`.
//
// @param job                job to compile
// @param plan               output plan, filled in place
// @param need_job_complete  forwarded unchanged to Compiler::Compile
// @return Maybe<void>::Ok() once all steps have run
Maybe<void> CompileCurJobOnMaster(Job* job, Plan* plan, bool need_job_complete) {
  // The elapsed time is divided by this to print seconds, i.e. the code treats
  // GetCurTime() differences as nanoseconds.
  constexpr double kNanosecondsPerSecond = 1000000000.0;

  const JobDesc& job_desc = GlobalJobDesc();
  const bool is_master = GlobalProcessCtx::IsThisProcessMaster();
  if (is_master) {
    const double compile_begin = GetCurTime();
    Compiler().Compile(job, plan, need_job_complete);
    PlanUtil::GenMemBlockAndChunk4Plan(plan);

    LOG(INFO) << "\njob_id: " << job_desc.job_id() << " , job_name: " << job_desc.job_name()
              << " , compile time: " << (GetCurTime() - compile_begin) / kNanosecondsPerSecond
              << " seconds.\n";

    const bool debug_mode = Global<ResourceDesc, ForSession>::Get()->enable_debug_mode();
    if (debug_mode) {
      TeePersistentLogStream::Create(StrCat("subplan_job_", job_desc.job_id()))->Write(*plan);
    }
  }

  // These two steps are outside the master-only branch on purpose.
  PlanUtil::GenCollectiveBoxingPlan(job, plan);
  PlanUtil::GenRegisterHint(plan);
  return Maybe<void>::Ok();
}

在 Compile 这个方法里面,通过注释可以看到编译分为五步。

// Compiles `job` into `plan` in five steps:
//   1. optionally complete the job (when `need_job_complete` is true);
//   2. create Global<OpGraph> from the job and, in debug or dry-run mode, dump
//      the optimized job and the op graph;
//   3. build the task graph and run its per-node / per-edge passes;
//   4. serialize every meaningful task node into `plan` using a thread pool;
//   5. record the job conf in `plan`, run the memory-block passes and delete
//      Global<OpGraph>.
//
// @param job                job to compile (may be modified by step 1)
// @param plan               output plan, filled in place
// @param need_job_complete  whether JobCompleter must run first
void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
  // Step1: ensure job is completed.
  if (need_job_complete) { CHECK_JUST(JobCompleter().Complete(job)); }

  // Step2: new Global<OpGraph> and set log configs.
  Global<OpGraph>::New(*job);
  const JobDesc& job_desc = GlobalJobDesc();
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()
      || Global<ResourceDesc, ForSession>::Get()->enable_dry_run()) {
    TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(*job);
    Global<OpGraph>::Get()->ToDotWithFilePath("optimized_dlnet_" + std::to_string(job_desc.job_id())
                                              + "_op_graph.dot");
  }

  // Step3: build task_gph.
  // TODO(levi): we can rewrite this part of code in visitor pattern.
  auto task_gph = std::make_unique<TaskGraph>();
  using std::placeholders::_1;
  task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
  task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
  task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
  task_gph->TopoForEachNode(&TaskNode::Build);
  task_gph->RemoveEmptyRegsts();
  task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();
  // Built here, but also reused by the memory-sharing pass in Step5.
  auto IsReachable = Global<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
  if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); }
  task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
  task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });

  // Step4: put information from task_gph into plan.
  const int64_t node_num = task_gph->node_num();
  const int64_t cpu_num = std::thread::hardware_concurrency();
  // std::thread::hardware_concurrency() may return 0 when the value is not
  // computable; never create a pool with zero workers while work is queued.
  const int64_t thread_pool_size = std::max<int64_t>(1, std::min(node_num, cpu_num));
  BlockingCounter counter(node_num);
  std::mutex mtx;
  ThreadPool thread_pool(thread_pool_size);
  task_gph->ForEachNode([&](TaskNode* task_node) {
    thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() {
      if (!task_node->IsMeaningLess()) {
        // Serialization happens outside the lock; only the plan mutation is guarded.
        TaskProto task_proto;
        task_node->ToProto(&task_proto);
        {
          std::unique_lock<std::mutex> guard(mtx);
          if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
              || task_node->GetTaskType() == kAcc) {
            CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
          }
          plan->mutable_task()->Add(std::move(task_proto));
        }  // guard(mtx)
      }
      // Decrease for every node (meaningless ones too) so the wait below terminates.
      counter.Decrease();
    } /* thread_pool.AddWork */);
  } /* task_gph->ForEachNode */);
  counter.WaitUntilCntEqualZero();
  // NOTE(levi): release task_gph here to decrease memory peak.
  task_gph.reset();

  // Step5: post-process for plan and delete Global<OpGraph>.
  auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf();
  (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf();
  // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl
  IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable);
  PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);
  Global<OpGraph>::Delete();
}

MainJob

图片来源:https://zhuanlan.zhihu.com/p/337851255

MainJob 的作用是什么呢?

生成、编译、链接

MainJob 如何来的呢?主要有三个步骤:生成、编译、链接。

生成过程

使用 JobBuilder 来构建 MainJob。执行过程大致如下:

// oneflow/core/job/oneflow.cpp: 457
// Adds one main-job component to `job_builder` and returns its lock back edge.
//
// Ops are added in this order, all with the same ParallelConf (device tag
// "cpu", device name "@<machine_id_range.begin()>:0"):
//   1. a reentrant lock op started by `wait_and_send_ids_lbn`;
//   2. a case op with one output per critical section;
//   3. per critical section: one source tick op, then for every machine id in
//      `machine_id_range` an identity tick, a callback sink tick and a sink tick;
//   4. an esac op consuming the outputs of all sink ticks.
//
// @param wait_and_send_ids_lbn   value stored as the lock op's "start"
// @param machine_id_range        machine ids to create per-machine tick ops for
// @param job_builder             receives every op created here
// @param identity_tick_op_names  [critical section] -> {machine id -> identity tick op name};
//                                accessed with at(), so it must already have an entry per section
// @param cb_sink_tick_op_names   same layout, for callback sink tick op names
// @return the ReentrantLockBackEdge holding the lock op name and the esac op's
//         "out" lbi; an error if adding an op fails or a machine id is duplicated
Maybe<ReentrantLockBackEdge> MakeMainJobComponent(
    const std::string& wait_and_send_ids_lbn, const Range& machine_id_range,
    JobBuilder* job_builder, std::vector<std::map<int64_t, std::string>>* identity_tick_op_names,
    std::vector<std::map<int64_t, std::string>>* cb_sink_tick_op_names) {
  ParallelConf parallel_conf;
  parallel_conf.set_device_tag("cpu");
  {
    const std::string first_machine = std::to_string(machine_id_range.begin());
    parallel_conf.add_device_name(std::string("@") + first_machine + ":0");
  }
  auto back_edge = std::make_shared<ReentrantLockBackEdge>();

  // reentrant lock op conf
  OperatorConf reentrant_lock_op_conf;
  {
    back_edge->reentrant_lock_op_name = std::string("System-Main-ReentrantLock_") + NewUniqueId();
    reentrant_lock_op_conf.set_name(back_edge->reentrant_lock_op_name);
    auto* lock_conf = reentrant_lock_op_conf.mutable_reentrant_lock_conf();
    lock_conf->set_start(wait_and_send_ids_lbn);
    // The "end" input is deliberately left unset: it is wired up after the plan
    // has been generated so that the job itself contains no cycle.
    lock_conf->set_out("out");
    Global<CriticalSectionDesc>::Get()->DumpCriticalSectionId2IntersectinIds(
        lock_conf->mutable_lock_id2intersecting_lock_ids());
    JUST(job_builder->AddOp(parallel_conf, reentrant_lock_op_conf));
  }

  // critical section case op conf
  OperatorConf cs_case_op_conf;
  {
    cs_case_op_conf.set_name(std::string("System-Main-Case_") + NewUniqueId());
    auto* case_conf = cs_case_op_conf.mutable_case_conf();
    case_conf->set_in(reentrant_lock_op_conf.name() + "/out");
    const int64_t case_out_num = Global<CriticalSectionDesc>::Get()->CriticalSectionNum();
    for (int64_t out_idx = 0; out_idx < case_out_num; ++out_idx) {
      case_conf->add_out(GenRepeatedBn("out", out_idx));
    }
    JUST(job_builder->AddOp(parallel_conf, cs_case_op_conf));
  }

  const int64_t critical_section_num = Global<CriticalSectionDesc>::Get()->CriticalSectionNum();
  std::vector<std::string> sink_tick_names;
  for (int64_t cs_idx = 0; cs_idx < critical_section_num; ++cs_idx) {
    const std::string cs_idx_str = std::to_string(cs_idx);

    // source tick: fed by the case op's cs_idx-th output
    OperatorConf src_tick_op_conf;
    {
      src_tick_op_conf.set_name("System-Main-SourceTick_CriticalSection_" + cs_idx_str + "_"
                                + NewUniqueId());
      auto* tick_conf = src_tick_op_conf.mutable_tick_conf();
      tick_conf->add_tick(cs_case_op_conf.name() + "/" + GenRepeatedBn("out", cs_idx));
      tick_conf->set_out("out");
      JUST(job_builder->AddOp(parallel_conf, src_tick_op_conf));
    }

    auto* cur_cb_sink_tick_op_names = &cb_sink_tick_op_names->at(cs_idx);
    for (int64_t machine_id = machine_id_range.begin(); machine_id < machine_id_range.end();
         ++machine_id) {
      // identity tick: fed by the source tick
      OperatorConf identity_tick_op_conf;
      {
        identity_tick_op_conf.set_name("System-Main-Tick_CriticalSection_" + cs_idx_str + "_"
                                       + NewUniqueId());
        auto* tick_conf = identity_tick_op_conf.mutable_tick_conf();
        tick_conf->add_tick(src_tick_op_conf.name() + "/out");
        tick_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, identity_tick_op_conf));
        auto* cur_id_tick_op_names = &identity_tick_op_names->at(cs_idx);
        CHECK_OR_RETURN(
            cur_id_tick_op_names->emplace(machine_id, identity_tick_op_conf.name()).second);
      }
      // callback sink tick: fed by the identity tick
      {
        OperatorConf cb_sink_tick_op_conf;
        // NOTE(review): unlike the source/identity tick names, no "_" separates the
        // section index from the unique id here - confirm whether that is intended.
        cb_sink_tick_op_conf.set_name("System-Main-CallbackSinkTick_" + cs_idx_str
                                      + NewUniqueId());
        auto* sink_conf = cb_sink_tick_op_conf.mutable_sink_tick_conf();
        sink_conf->add_tick(identity_tick_op_conf.name() + "/out");
        sink_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, cb_sink_tick_op_conf));
        CHECK_OR_RETURN(
            cur_cb_sink_tick_op_names->emplace(machine_id, cb_sink_tick_op_conf.name()).second);
      }
      // sink tick: also fed by the identity tick; its name is collected for the esac op
      {
        OperatorConf snk_tick_op_conf;
        snk_tick_op_conf.set_name("System-Main-SinkTick_CriticalSection_" + cs_idx_str
                                  + NewUniqueId());
        auto* sink_conf = snk_tick_op_conf.mutable_sink_tick_conf();
        sink_conf->add_tick(identity_tick_op_conf.name() + "/out");
        sink_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf));
        sink_tick_names.push_back(snk_tick_op_conf.name());
      }
    }
  }

  // critical section esac op conf
  OperatorConf cs_esac_op_conf;
  {
    cs_esac_op_conf.set_name(std::string("System-Main-Esac_") + NewUniqueId());
    // cs_esac_op_conf.set_pass_tag("main");
    auto* esac_conf = cs_esac_op_conf.mutable_esac_conf();
    for (const std::string& sink_tick_name : sink_tick_names) {
      esac_conf->add_in(sink_tick_name + "/out");
    }
    esac_conf->set_out("out");
    esac_conf->set_data_type(DataType::kInt32);
    JUST(job_builder->AddOp(parallel_conf, cs_esac_op_conf));
  }

  back_edge->critical_section_sink_lbi.set_op_name(cs_esac_op_conf.name());
  back_edge->critical_section_sink_lbi.set_blob_name("out");
  return back_edge;
}

编译

同样只在 Master 上面编译。设置编译的 scope,然后编译。

// oneflow/core/job/oneflow.cpp: 732
// Compiles the main job into `main_plan` and then connects every lock back edge.
//
// Must be called on the master process; otherwise an error is returned. The
// GlobalJobDescScope built from the main job's conf and `job_id` lives only for
// the duration of the compile call.
//
// @param main_job         the main job to compile
// @param lock_back_edges  back edges to connect in `main_plan` after compiling
// @param job_id           job id passed to the GlobalJobDescScope
// @param main_plan        output plan
// @return Maybe<void>::Ok() on success, or the first error encountered
Maybe<void> CompileMainJob(Job* main_job, const std::vector<ReentrantLockBackEdge>& lock_back_edges,
                           int64_t job_id, Plan* main_plan) {
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());

  {
    // Destroyed at the end of this block, i.e. before the back edges are connected.
    const auto job_desc_scope =
        std::make_unique<GlobalJobDescScope>(main_job->job_conf(), job_id);
    JUST(CompileCurJobOnMaster(main_job, main_plan, false));
  }

  for (auto edge_it = lock_back_edges.begin(); edge_it != lock_back_edges.end(); ++edge_it) {
    const auto& lock_back_edge = *edge_it;
    JUST(ConnectCriticalSectionEndToReentrantLockEnd(main_plan, lock_back_edge));
  }
  return Maybe<void>::Ok();
}

链接

从效果来看,将所有其他的 Job 的临界区,加入到 Main Plan 里面,构成一个大的 Plan。

// oneflow/core/job/oneflow.cpp: 306
// Merges `main_plan` into `plan` and links the tick tasks of both.
//
// Steps:
//   1. remember the ids of all kTick tasks of `main_plan` (before it is moved);
//   2. merge `main_plan` into `plan`;
//   3. index every "interface tick" task of the merged plan by its op name
//      (a main-plan kTick task, or a single-exec-node task whose op is a
//      source tick or sink tick);
//   4. for every critical section and every process rank, link the identity
//      tick task with that section's source tick and sink tick tasks;
//   5. erase all kSourceTick tasks from the merged plan.
//
// @param plan                    destination plan; modified in place
// @param main_plan               plan of the main job; consumed by this call
// @param identity_tick_op_names  [critical section] -> {machine id -> identity tick op name}
void LinkMainPlan(Plan* plan, Plan&& main_plan,
                  const std::vector<std::map<int64_t, std::string>>& identity_tick_op_names) {
  std::function<bool(const TaskProto*)> IsInterfaceTickTockTask;
  {
    // Collected now because main_plan is moved from by MergePlan below.
    auto task_ids = std::make_shared<HashSet<int64_t>>();
    for (const auto& task : main_plan.task()) {
      if (task.task_type() == TaskType::kTick) { CHECK(task_ids->emplace(task.task_id()).second); }
    }
    IsInterfaceTickTockTask = [task_ids, plan](const TaskProto* task) {
      if (task_ids->find(task->task_id()) != task_ids->end()) { return true; }
      if (task->exec_sequence().exec_node_size() != 1) { return false; }
      const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
      OperatorConf::OpTypeCase op_type_case =
          PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().op_type_case();
      return op_type_case == OperatorConf::kSourceTickConf
             || op_type_case == OperatorConf::kSinkTickConf;
    };
  }
  MergePlan(plan, std::move(main_plan));

  // op name -> the single task running that tick op; duplicates are a fatal error.
  HashMap<std::string, TaskProto*> sole_tick_op_name2sole_task;
  FOR_RANGE(int64_t, i, 0, plan->task_size()) {
    TaskProto* task = plan->mutable_task(i);
    if (!IsInterfaceTickTockTask(task)) { continue; }
    // Main-plan kTick tasks pass the predicate without an exec-node count check,
    // so guard the exec_node(0) access below.
    CHECK(task->exec_sequence().exec_node_size() > 0);
    const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
    const auto& op_name =
        PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();
    CHECK(sole_tick_op_name2sole_task.emplace(op_name, task).second);
  }

  const auto& process_ranks = Global<ResourceDesc, ForSession>::Get()->process_ranks();
  FOR_RANGE(int32_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
    const CriticalSection& cs = Global<CriticalSectionDesc>::Get()->GetCriticalSection(i);
    for (int64_t machine_id : process_ranks) {
      TaskProto* identity_tick =
          sole_tick_op_name2sole_task.at(identity_tick_op_names.at(i).at(machine_id));
      LinkTickTaskProto(
          plan, identity_tick,
          sole_tick_op_name2sole_task.at(cs.machine_id2source_tick_op_name().at(machine_id)),
          sole_tick_op_name2sole_task.at(cs.machine_id2sink_tick_op_name().at(machine_id)));
    }
  }
  {
    // erase source_tick task_proto
    // The name set is only used to verify that every erased task is one of the
    // critical sections' source ticks.
    HashSet<std::string> source_tick_op_names;
    FOR_RANGE(int32_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
      const CriticalSection& cs = Global<CriticalSectionDesc>::Get()->GetCriticalSection(i);
      for (int64_t machine_id : process_ranks) {
        const auto& src_tick_op_name = cs.machine_id2source_tick_op_name().at(machine_id);
        CHECK(source_tick_op_names.emplace(src_tick_op_name).second);
      }
    }
    Erase<PbRpf<TaskProto>>(*plan->mutable_task(), [&](const TaskProto& task) {
      if (task.task_type() == TaskType::kSourceTick) {
        CHECK(task.exec_sequence().exec_node_size() == 1);
        const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf();
        const auto& op_conf = PlanUtil::GetOpAttribute(plan, task.job_id(), kernel_conf).op_conf();
        CHECK(op_conf.has_source_tick_conf());
        CHECK(source_tick_op_names.find(op_conf.name()) != source_tick_op_names.end());
        return true;
      } else {
        return false;
      }
    });
  }
}

标签:OneFlow,name,job,Job,task,Plan,conf,tick,op
来源: https://www.cnblogs.com/zzk0/p/15222259.html