8. 通用 Awaiter

Benny Huo大约 13 分钟

8. 通用 Awaiter

每次我们新增功能时,都需要修改 TaskPromise 增加对应的 await_transform 函数,这显然是个设计问题。

问题背景

我们前面在实现无阻塞 sleep 和 Channel 的时候都需要专门实现对应的 Awaiter 类型,并且在 TaskPromise 当中添加相应的 await_transform 函数。增加新类型这没什么问题,但如果每增加一个新功能就要对原有的 TaskPromise 类型做修改,这说明 TaskPromise 的扩展性不够好。

当然,有读者会说,如果我们把所有的 await_transform 函数都去掉,改成给对应的类型实现 operator co_await 来获取 Awaiter(例如 sleep 的例子当中通过 duration 转 Awaiter) 或者干脆就自己就定义成 Awaiter(例如 Channel 当中的 ReadAwaiter),这样我们就不用总是修改 TaskPromise 了。话虽如此,但完全由外部定义 Awaiter 对象的获取会使得调度器无法被包装正确使用,甚至我们在定义 TaskPromise 的时候把调度器定义成私有成员,因为我们根本不希望外部能够轻易获取到调度器的实例。

使用 await_transform 本质上就是为了保证调度器的正确应用,却带来了扩展上的问题,那这是说 C++ 协程的设计有问题吗?当然也不是。我们完全可以定义一个 Awaiter 类型,外部只需要继承这个 Awaiter 在受限的范围内自定义逻辑,完成自己的需求同时也能保证调度器的调度。

通用的 await_transform

了解了需求背景之后,我们只需要在 TaskPromise 当中定义一个更加通用版本的 await_transform,来为 Awaiter 提供调度器:

template<typename ResultType, typename Executor>
struct TaskPromise {
  template<typename AwaiterImpl>
  AwaiterImpl await_transform(AwaiterImpl awaiter) {
    awaiter.install_executor(&executor);
    return awaiter;
  }
  ...
}

你看得没错,我们真的只是给这个通用的 Awaiter 添加了当前协程的调度器。

Awaiter 的定义

既然 Awaiter 的核心是调度器,我们可以直接给出它的基本定义:

template<typename R>
struct Awaiter {
  ...

  void install_executor(AbstractExecutor *executor) {
    _executor = executor;
  }

 private:
  AbstractExecutor *_executor = nullptr;
  ...

  // 方便用调度器调度任意逻辑,这里也处理了调度器为空的情况
  void dispatch(std::function<void()> &&f) {
    if (_executor) {
      _executor->execute(std::move(f));
    } else {
      f();
    }
  }
};

作为 Awaiter 本身,当然也得有标准当中定义的基本的三个函数要求:

template<typename R>
struct Awaiter {

  // 简单处理,永远挂起,当然这也是协程调度的一个潜在的优化点
  bool await_ready() const { return false; }

  void await_suspend(std::coroutine_handle<> handle) {
    // 记录当前协程的 handle,方面后面恢复
    this->_handle = handle;
    ...
  }

  R await_resume() {
    ...
    // 返回 co_await 的结果,当然对于 void 的情况,我们也会有特化版本
    return _result->get_or_throw();
  }
 protected:
  // 结果对子类可见,方便灵活操作
  std::optional<Result<R>> _result{}; 

 private:
  AbstractExecutor *_executor = nullptr;
  // 保存协程的 handle,恢复时会用到,私有化这个成员目的是将其逻辑封装,避免滥用
  std::coroutine_handle<> _handle = nullptr;
  ...
}

这几个函数是协程在挂起和恢复时调用的。我们将协程 handle 的保存和结果的返回逻辑固化,因为几乎所有的 Awaiter 都有这样的需求。不过协程的挂起后和恢复前是两个非常重要的时间点,扩展 Awaiter 时经常需要在这两个时间点实现定义化的业务逻辑,因此我们需要定义两个虚函数让子类按需实现:

template<typename R>
struct Awaiter {

  bool await_ready() const { return false; }

  void await_suspend(std::coroutine_handle<> handle) {
    this->_handle = handle;
    // 调用 after_suspend,子类可以自定义这个函数来处理需要的逻辑
    after_suspend();
  }

  R await_resume() {
    // 调用 before_resume,子类可以自定义这个函数来处理需要的逻辑
    before_resume();
    return _result->get_or_throw();
  }
  ...

 protected:
  std::optional<Result<R>> _result{};

  virtual void after_suspend() {}

  virtual void before_resume() {}

  ...
}

剩下的就是协程的恢复了,这时候我们要求必须使用调度器进行调度。为了防止外部不按要求处理调度逻辑,我们将调度器和协程的 handle 都定义为私有成员,因此我们也需要提供相应的函数来封装协程恢复的逻辑:

template<typename R>
struct Awaiter {

  ...

  // 协程恢复时,co_await 表达式返回 value
  void resume(R value) {
    dispatch([this, value]() {
      // 将 value 封装到 _result 当中,await_resume 时会返回 value
      _result = Result<R>(static_cast<R>(value));
      _handle.resume();
    });
  }

  // 不提供 value,但也要恢复协程,这种情况需要子类在 before_resume 当中写入 _result,或者抛出异常
  // 我们将会在 Channel 关闭时用到这个函数
  void resume_unsafe() {
    dispatch([this]() { _handle.resume(); });
  }

  // 挂起点出现异常,用异常来恢复协程
  void resume_exception(std::exception_ptr &&e) {
    dispatch([this, e]() {
      _result = Result<R>(static_cast<std::exception_ptr>(e));
      _handle.resume();
    });
  }
  ...

}

这样一来,如果我们想要扩展新功能,只需要继承 Awaiter,在 after_suspend 当中或者之后找个合适的时机调用 resume/resume_unsafe/resume_exception 三个函数当中的任意一个来恢复协程即可。如果在恢复前有其他逻辑需要处理,也可以覆写 before_resume 来实现。

Awaiter 的应用

接下来我们使用 Awaiter 对现有的几个 awaiter 类型做重构,之后再尝试基于 Awaiter 做一点小小的扩展。

重构 SleepAwaiter

SleepAwaiter 是最简单的一个。我们当初为了让无阻塞的 sleep 看上去更加自然,直接对 duration 做了支持,于是可以写出下面的代码:

Task<void, LooperExecutor> task() {
  co_await 300ms;
  ...
}

duration 的支持源自于在 TaskPromise 当中添加了 durationSleepAwaiterawaiter_transform 函数:

template<typename _Rep, typename _Period>
SleepAwaiter await_transform(std::chrono::duration<_Rep, _Period> &&duration) {
  return SleepAwaiter(&executor, std::chrono::duration_cast<std::chrono::milliseconds>(duration).count());
}

如果不要求对 duration 直接支持的话,我们其实也可以这么设计:

template<typename _Rep, typename _Period>
SleepAwaiter await_transform(SleepAwaiter awaiter) {
  // 保存调度器,后面调度用
  awaiter._executor = &executor;
  return awaiter;
}

这与我们前面给出的通用 Awaiter 版本的 await_transform 如出一辙:

template<typename AwaiterImpl>
AwaiterImpl await_transform(AwaiterImpl awaiter) {
  // 传入调度器,后面调度用
  awaiter.install_executor(&executor);
  return awaiter;
}

因此我们可以使用通用的 Awaiter 重构 SleepAwaiter,下面我们给出重构前和重构后的对比:

重构前

struct SleepAwaiter {

  explicit SleepAwaiter(AbstractExecutor *executor, long long duration) noexcept
      : _executor(executor), _duration(duration) {}

  bool await_ready() const { return false; }

  void await_suspend(std::coroutine_handle<> handle) const {
    static Scheduler scheduler;

    scheduler.execute([this, handle]() {
      _executor->execute([handle]() {
        handle.resume();
      });
    }, _duration);
  }

  void await_resume() {}

 private:
  AbstractExecutor *_executor;
  long long _duration;
};

重构后

struct SleepAwaiter : Awaiter<void> {

  explicit SleepAwaiter(long long duration) noexcept
      : _duration(duration) {}

  // 新增一个支持 duration 的构造器,方便外部使用
  template<typename _Rep, typename _Period>
  explicit SleepAwaiter(std::chrono::duration<_Rep, _Period> &&duration) noexcept
      : _duration(std::chrono::duration_cast<std::chrono::milliseconds>(duration).count()) {}

  void after_suspend() override {
    // 这部分逻辑以前写在 await_suspend 当中
    // 现在我们写在覆写的 after_suspend 当中
    // 调用位置实际上没有变化,但我们不用再关心 handle 和 调度器了
    static Scheduler scheduler;
    scheduler.execute([this] { resume(); }, _duration);
  }

 private:
  long long _duration;
};

重构之后,我们无需单独为 SleepAwaiter 添加 await_transform 的支持,就可以写出下面的代码:

Task<void, LooperExecutor> task()) {
    // co_await 300ms;
    // 等价于前面的 co_await 300ms
    co_await SleepAwaiter(300ms);
  }
}

如果觉得不够美观,也可以定义一个协程版本的函数 sleep_for:

template<typename _Rep, typename _Period>
SleepAwaiter sleep_for(std::chrono::duration<_Rep, _Period> &&duration) {
  return SleepAwaiter(duration);
}

这样写出来的代码就变成了:

Task<void, LooperExecutor> task()) {
    // co_await 300ms;
    // 等价于前面的 co_await 300ms
    co_await sleep_for(300ms);
  }
}

重构 Channel 的 Awaiter

Channel 有两个 Awaiter,分别是 ReaderAwaiterWriterAwaiter,以前者为例:

重构前

template<typename ValueType>
struct ReaderAwaiter {
  Channel<ValueType> *channel;
  AbstractExecutor *executor = nullptr;
  ValueType _value;
  ValueType *p_value = nullptr;
  std::coroutine_handle<> handle;

  explicit ReaderAwaiter(Channel<ValueType> *channel) : channel(channel) {}

  ReaderAwaiter(ReaderAwaiter &&other) noexcept
      : channel(std::exchange(other.channel, nullptr)),
        executor(std::exchange(other.executor, nullptr)),
        _value(other._value),
        p_value(std::exchange(other.p_value, nullptr)),
        handle(other.handle) {}

  bool await_ready() { return false; }

  auto await_suspend(std::coroutine_handle<> coroutine_handle) {
    this->handle = coroutine_handle;
    channel->try_push_reader(this);
  }

  int await_resume() {
    channel->check_closed();
    channel = nullptr;
    return _value;
  }

  void resume(ValueType value) {
    this->_value = value;
    if (p_value) {
      *p_value = value;
    }
    resume();
  }

  void resume() {
    if (executor) {
      executor->execute([this]() { handle.resume(); });
    } else {
      handle.resume();
    }
  }

  ~ReaderAwaiter() {
    if (channel) channel->remove_reader(this);
  }
};

这代码大家已经见过,这里同样贴出来只是为了让大家能够直接对比:

重构后

template<typename ValueType>
struct ReaderAwaiter : public Awaiter<ValueType> {
  Channel<ValueType> *channel;
  ValueType *p_value = nullptr;

  explicit ReaderAwaiter(Channel<ValueType> *channel) : Awaiter<ValueType>(), channel(channel) {}

  ReaderAwaiter(ReaderAwaiter &&other) noexcept
      : Awaiter<ValueType>(other),
        channel(std::exchange(other.channel, nullptr)),
        p_value(std::exchange(other.p_value, nullptr)) {}

  void after_suspend() override {
    channel->try_push_reader(this);
  }

  void before_resume() override {
    channel->check_closed();
    if (p_value) {
      *p_value = this->_result->get_or_throw();
    }
    channel = nullptr;
  }

  ~ReaderAwaiter() {
    if (channel) channel->remove_reader(this);
  }
};

可以看到,调度的逻辑统一抽象到父类 Awaiter 当中,代码的逻辑更加紧凑了。不仅如此,之前在 TaskPromise 当中定义的 await_transform 也不需要了:

// 不再需要
template<typename _ValueType>
auto await_transform(ReaderAwaiter<_ValueType> reader_awaiter) {
  reader_awaiter.executor = &executor;
  return reader_awaiter;
}

WriterAwaiter 同理,不再赘述。

重构 TaskAwaiter

TaskAwaiter 是用来等待其他 Task 的执行完成的。它同样可以用前面的通用 Awaiter 改造:

重构前

template<typename Result, typename Executor>
struct TaskAwaiter {
  explicit TaskAwaiter(AbstractExecutor *executor, Task<Result, Executor> &&task) noexcept
      : _executor(executor), task(std::move(task)) {}

  TaskAwaiter(TaskAwaiter &&completion) noexcept
      : _executor(completion._executor), task(std::exchange(completion.task, {})) {}

  TaskAwaiter(TaskAwaiter &) = delete;

  TaskAwaiter &operator=(TaskAwaiter &) = delete;

  constexpr bool await_ready() const noexcept {
    return false;
  }

  void await_suspend(std::coroutine_handle<> handle) noexcept {
    task.finally([handle, this]() {
      _executor->execute([handle]() {
        handle.resume();
      });
    });
  }

  Result await_resume() noexcept {
    return task.get_result();
  }

 private:
  Task<Result, Executor> task;
  AbstractExecutor *_executor;

};

作为对比,重构后的代码同样变得简洁:

template<typename R, typename Executor>
struct TaskAwaiter : public Awaiter<R> {
  explicit TaskAwaiter(Task<R, Executor> &&task) noexcept
      : task(std::move(task)) {}

  TaskAwaiter(TaskAwaiter &&awaiter) noexcept
      : Awaiter<R>(awaiter), task(std::move(awaiter.task)) {}

  TaskAwaiter(TaskAwaiter &) = delete;

  TaskAwaiter &operator=(TaskAwaiter &) = delete;

 protected:
  void after_suspend() override {
    task.finally([this]() {
      // 先不去获取结果,原因是除了正常的返回值以外,还可能是异常
      this->resume_unsafe();
    });
  }

  void before_resume() override {
    // 如果有返回值,则赋值给 _result,否则直接抛异常
    this->_result = Result(task.get_result());
  }

 private:
  Task<R, Executor> task;
};

改造完成之后,如果不希望为 Task 增加特权支持的话,之前对 TaskAwaiterawait_transform 同样可以删除掉:

// 直接删掉
template<typename _ResultType, typename _Executor>
TaskAwaiter<_ResultType, _Executor> await_transform(Task<_ResultType, _Executor> &&task) {
  return TaskAwaiter<_ResultType, _Executor>(&executor, std::move(task));
}

然后为 Task 类型增加一个函数来获取 TaskAwaiter

template<typename ResultType, typename Executor = NoopExecutor>
struct Task {

  auto as_awaiter() {
    return TaskAwaiter<ResultType, Executor>(std::move(*this));
  }
  ...
}

一旦调用 as_awaiter,我们就会将 Task 的内容全部转移到新创建的 TaskAwaiter 当中,并且返回给外部使用:

Task<int, LooperExecutor> simple_task() {
  // 如果要删除 TaskAwaiter<> await_transform(Task<>)
  // 可以采用以下方式在外部将 Task 转成 TaskAwaiter,然后再 co_await
  auto result2 = co_await simple_task2().as_awaiter();
  ...
}

当然,在我们自己实现的这套 Task 框架当中,Task 自然是“特权阶层”,我们不会真的删除为 Task 定制的 await_transform。但也不难看出,经过改造的 Awaiter 的子类代码量和复杂度都有降低;同时也不再需要定义专门的 await_transform 函数来明确支持 TaskAwaiter,避免了扩展性不强的尴尬。

添加对 std::future 的扩展支持

按照 C++ 标准的发展趋势来看,std::future 应该在将来会支持类似于 Task::then 这样的函数回调,那时候我们完全不需要自己独立定义一套 Task,只需要基于 std::future 进行扩展即可。

当然这都是后话了。现在 std::future 还不支持回调,我们可以另起一个线程来阻塞得等待它的结果,并在结果返回之后恢复协程的执行,这样一来,我们的 Task 框架也就能够支持形如 co_await as_awaiter(future) 这样的写法了。

想要做到这一点,我们只需要基于前面的 Awaiter 来依样画葫芦:

template<typename R>
struct FutureAwaiter : public Awaiter<R> {
  explicit FutureAwaiter(std::future<R> &&future) noexcept
      : _future(std::move(future)) {}

  FutureAwaiter(FutureAwaiter &&awaiter) noexcept
      : Awaiter<R>(awaiter), _future(std::move(awaiter._future)) {}

  FutureAwaiter(FutureAwaiter &) = delete;

  FutureAwaiter &operator=(FutureAwaiter &) = delete;

 protected:
  void after_suspend() override {
    // std::future::get 会阻塞等待结果的返回,因此我们新起一个线程等待结果的返回
    // 如果后续 std::future 增加了回调,这里直接注册回调即可
    std::thread([this](){
      // 获取结果,并恢复协程
      this->resume(this->_future.get());
    }).detach(); 
    // std::thread 必须 detach 或者 join 二选一
    // 也可以使用 std::jthread 
  }

 private:
  std::future<R> _future;
};

FutureAwaiterTaskAwaiter 除了 after_suspendbefore_resume 处有些不同之外,几乎完全一样(当然除了这俩函数以外也基本上没有其他逻辑了)。

如果你愿意,你也可以定义一个 as_awaiter 函数:

template<typename R>
FutureAwaiter<R> as_awaiter(std::future<R> &&future) {
  return FutureAwaiter(std::move(future));
}

这样我们在协程当中就可以使用 co_await 来等待 std::future 的返回了:

Task<void> task() {
  auto result = co_await as_awaiter(std::async([]() {
    std::this_thread::sleep_for(1s);
    return 1000;
  }));
  ...
}

AwaiterImpl 的类型约束

本文给出的通用的 await_transform 有个小小的漏洞,我们不妨再次观察一下这个函数的定义:

template<typename AwaiterImpl>
AwaiterImpl await_transform(AwaiterImpl awaiter) {
  awaiter.install_executor(&executor);
  return awaiter;
}

不难发现,只要 AwaiterImpl 类型定义了协程的 Awaiter 类型的三个函数,并且定义有 install_executor 函数,在这里就可以蒙混过关,例如:

struct FakeAwaiter {
  bool await_ready() { return false; }

  void await_suspend(std::coroutine_handle<> handle) {}

  void await_resume() {}

  void install_executor(AbstractExecutor *) {}
};

Task<void> task()) {

  co_await FakeAwaiter();

}

这个 FakeAwaiter 的定义符合前面的模板类型 AwaiteImpl 的要求,但却不符合我们的预期。为了避免这种情况发生,我们必须想办法要求 AwaiterImpl 只能是 Awaiter 或者它的子类。

这如果是在 Java 当中,我们可以很轻松地指定泛型的上界来达到目的。但 C++ 的模板显然与 Java 泛型的设计相差较大,不能直接在定义模板参数时指定上界。不过 C++ 20 的 concept 可以用来为模板参数限定父类。

我们需要定义一个用来检查类关系的 concept:

template<typename AwaiterImpl, typename R>
concept AwaiterImplRestriction = std::is_base_of<Awaiter<R>, AwaiterImpl>::value;

接下来我们只需要在 await_transform 的模板声明后面加上这个 concept 即可:

template<typename AwaiterImpl>
// ??? 是 AwaiterImpl 继承父类 Awaiter 时传入的模板参数,但我们现在还不知道是什么
requires AwaiterImplRestriction<AwaiterImpl, ???>
AwaiterImpl await_transform(AwaiterImpl awaiter) {  ...  }

不过这里有个问题,我们其实并不知道 AwaiterImpl 的实际类型在继承 Awaiter 时到底用了什么类型的模板参数,这怎么办呢?

有一个简单的办法,那就是为 Awaiter 声明一个内部类型 ResultType

template<typename R>
struct Awaiter {

  using ResultType = R;

  ...
}

这样我们就可以使用 Awaiter::ResultType 来获取这个类型:

template<typename AwaiterImpl>
requires AwaiterImplRestriction<AwaiterImpl, typename AwaiterImpl::ResultType>
AwaiterImpl await_transform(AwaiterImpl awaiter) {
  ...
}

这样像前面提到的 FakeAwaiter 那样的类型,就不能作为 co_await 表达式的参数了。即便我们为 FakeAwaiter 声明 ResultType 也不行,co_await FakeAwaiter() 的报错信息如下:

candidate template ignored: constraints not satisfied [with AwaiterImpl = FakeAwaiter] 
because 'AwaiterImplRestriction<FakeAwaiter, typename FakeAwaiter::ResultType>' evaluated to false 
because 'std::is_base_of<Awaiter<void>, FakeAwaiter>::value' evaluated to false call to 'await_transform' implicitly required by 'co_await' here

可见 FakeAwaiter 并不能满足与 Awaiter 的父子类关系,因此无法作为 AwaiterImpl 的模板实参。

小结

本文介绍了一种实现较为通用的 Awaiter 的方法,目的在于增加现有 Task 框架的扩展性,避免通过频繁改动 TaskPromise 来新增功能。

关于作者

霍丙乾 bennyhuo,Google 开发者专家(Kotlin 方向);《深入理解 Kotlin 协程》 作者(机械工业出版社,2020.6);《深入实践 Kotlin 元编程》 作者(机械工业出版社,2023.8);前腾讯高级工程师,现就职于猿辅导

上次编辑于:
贡献者: bennyhuo