Skip to content

Commit

Permalink
Run shuffle task in TaskCenter
Browse files Browse the repository at this point in the history
  • Loading branch information
tillrohrmann committed Mar 6, 2024
1 parent 4494d82 commit c31bd06
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 103 deletions.
1 change: 1 addition & 0 deletions crates/core/src/task_center_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pub enum TaskKind {
PartitionProcessor,
#[strum(props(OnError = "log"))]
ConnectionReactor,
Shuffle,
// -- Bifrost Tasks
/// A background task that the system needs for its operation. The task requires a system
/// shutdown on errors and the system will wait for its graceful cancellation on shutdown.
Expand Down
110 changes: 23 additions & 87 deletions crates/worker/src/partition/leadership/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@ use crate::partition::shuffle::{HintSender, Shuffle, ShuffleMetadata};
use crate::partition::{shuffle, storage};
use assert2::let_assert;
use futures::{future, StreamExt};
use restate_core::metadata;
use restate_core::{metadata, task_center, ShutdownError, TaskId, TaskKind};
use restate_invoker_api::InvokeInputJournal;
use restate_timer::TokioClock;
use std::fmt::Debug;
use std::ops::{Deref, RangeInclusive};

use bytes::Bytes;
use futures::future::OptionFuture;
use prost::Message;
use std::panic;
use std::pin::Pin;
use tokio::sync::mpsc;
use tokio::task;
use tokio::task::JoinError;
use tracing::trace;

mod action_collector;
Expand All @@ -34,6 +33,7 @@ use crate::partition::services::non_deterministic;
use crate::partition::services::non_deterministic::ServiceInvoker;
use crate::partition::state_machine::Action;
pub(crate) use action_collector::{ActionEffect, ActionEffectStream};
use restate_bifrost::{bifrost, with_bifrost};
use restate_errors::NotRunningError;
use restate_ingress_dispatcher::{IngressDispatcherInput, IngressDispatcherInputSender};
use restate_schema_impl::Schemas;
Expand All @@ -51,9 +51,8 @@ type TimerService = restate_timer::TimerService<TimerValue, TokioClock, Partitio

pub(crate) struct LeaderState {
leader_epoch: LeaderEpoch,
shutdown_signal: drain::Signal,
shuffle_hint_tx: HintSender,
shuffle_handle: task::JoinHandle<Result<(), anyhow::Error>>,
shuffle_task_id: TaskId,
timer_service: Pin<Box<TimerService>>,
non_deterministic_service_invoker: non_deterministic::ServiceInvoker,
action_effect_handler: ActionEffectHandler,
Expand All @@ -72,12 +71,12 @@ pub(crate) struct FollowerState<I> {
pub(crate) enum Error {
#[error("invoker is unreachable. This indicates a bug or the system is shutting down: {0}")]
Invoker(NotRunningError),
#[error("shuffle failed. This indicates a bug or the system is shutting down: {0}")]
FailedShuffleTask(anyhow::Error),
#[error(transparent)]
Storage(#[from] restate_storage_api::StorageError),
#[error("action effect handler failed: {0}")]
ActionEffectHandler(anyhow::Error),
#[error(transparent)]
Shutdown(#[from] ShutdownError),
}

pub(crate) enum LeadershipState<InvokerInputSender> {
Expand Down Expand Up @@ -180,9 +179,12 @@ where

let shuffle_hint_tx = shuffle.create_hint_sender();

let (shutdown_signal, shutdown_watch) = drain::channel();

let shuffle_handle = tokio::spawn(shuffle.run(shutdown_watch));
let shuffle_task_id = task_center().spawn_child(
TaskKind::Shuffle,
"shuffle",
Some(follower_state.partition_id),
with_bifrost(shuffle.run(), bifrost()),
)?;

let action_effect_handler = ActionEffectHandler::new(
follower_state.partition_id,
Expand All @@ -195,9 +197,8 @@ where
follower_state,
leader_state: LeaderState {
leader_epoch,
shutdown_signal,
shuffle_task_id,
shuffle_hint_tx,
shuffle_handle,
timer_service,
action_effect_handler,
non_deterministic_service_invoker: service_invoker,
Expand Down Expand Up @@ -301,14 +302,12 @@ where
leader_state:
LeaderState {
leader_epoch,
shutdown_signal,
shuffle_handle,
shuffle_task_id,
..
},
} = self
{
// trigger shut down of all leadership tasks
shutdown_signal.drain().await;
let shuffle_handle = OptionFuture::from(task_center().cancel_task(shuffle_task_id));

let (shuffle_result, abort_result) = tokio::join!(
shuffle_handle,
Expand All @@ -317,7 +316,11 @@ where

abort_result.map_err(Error::Invoker)?;

Self::unwrap_task_result(shuffle_result).map_err(Error::FailedShuffleTask)?;
if let Some(Err(err)) = shuffle_result {
if err.is_panic() {
panic::resume_unwind(err.into_panic());
}
}

Ok(Self::follower(
partition_id,
Expand All @@ -332,61 +335,13 @@ where
}
}

fn unwrap_task_result<E>(result: Result<Result<(), E>, JoinError>) -> Result<(), E> {
if let Err(err) = result {
if err.is_panic() {
panic::resume_unwind(err.into_panic());
}

Ok(())
} else {
result.unwrap()
}
}

pub(crate) async fn run_tasks(&mut self) -> TaskResult {
pub(crate) async fn run_timer(&mut self) -> TimerValue {
match self {
LeadershipState::Follower { .. } => future::pending().await,
LeadershipState::Leader {
leader_state:
LeaderState {
shuffle_handle,
timer_service,
..
},
leader_state: LeaderState { timer_service, .. },
..
} => {
tokio::select! {
result = shuffle_handle => TaskResult::TerminatedTask(Self::into_task_result("shuffle", result)),
timer = timer_service.as_mut().next_timer() => TaskResult::Timer(timer)
}
}
}
}

fn into_task_result<E: Into<anyhow::Error>>(
name: &'static str,
result: Result<Result<(), E>, JoinError>,
) -> TokioTaskResult {
if let Err(err) = result {
if err.is_panic() {
panic::resume_unwind(err.into_panic());
}

TokioTaskResult::FailedTask {
name,
error: TaskError::Cancelled,
}
} else {
let result = result.unwrap();

result
.err()
.map(|err| TokioTaskResult::FailedTask {
name,
error: TaskError::Error(err.into()),
})
.unwrap_or(TokioTaskResult::TerminatedTask(name))
} => timer_service.as_mut().next_timer().await,
}
}

Expand Down Expand Up @@ -588,27 +543,8 @@ where
}
}

#[derive(Debug)]
pub(crate) enum TaskResult {
TerminatedTask(TokioTaskResult),
Timer(TimerValue),
}

#[derive(Debug, thiserror::Error)]
pub(crate) enum TokioTaskResult {
#[error("task '{0}' terminated unexpectedly")]
TerminatedTask(&'static str),
#[error("task '{name}' failed: {error}")]
FailedTask {
name: &'static str,
error: TaskError,
},
}

#[derive(Debug, thiserror::Error)]
pub(crate) enum TaskError {
#[error("task was cancelled")]
Cancelled,
#[error(transparent)]
Error(#[from] anyhow::Error),
}
15 changes: 4 additions & 11 deletions crates/worker/src/partition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// by the Apache License, Version 2.0.

use crate::metric_definitions::{PARTITION_ACTUATOR_HANDLED, PARTITION_TIMER_DUE_HANDLED};
use crate::partition::leadership::{ActionEffect, LeadershipState, TaskResult};
use crate::partition::leadership::{ActionEffect, LeadershipState};
use crate::partition::state_machine::{ActionCollector, Effects, StateMachine};
use crate::partition::storage::{DedupSequenceNumberResolver, PartitionStorage, Transaction};
use assert2::let_assert;
Expand Down Expand Up @@ -189,16 +189,9 @@ where
let action_effect = action_effect.ok_or_else(|| anyhow::anyhow!("action effect stream is closed"))?;
state.handle_action_effect(action_effect).await?;
},
task_result = state.run_tasks() => {
match task_result {
TaskResult::Timer(timer) => {
counter!(PARTITION_TIMER_DUE_HANDLED).increment(1);
state.handle_action_effect(ActionEffect::Timer(timer)).await?;
},
TaskResult::TerminatedTask(result) => {
Err(result)?
}
}
timer = state.run_timer() => {
counter!(PARTITION_TIMER_DUE_HANDLED).increment(1);
state.handle_action_effect(ActionEffect::Timer(timer)).await?;
},
}
}
Expand Down
8 changes: 3 additions & 5 deletions crates/worker/src/partition/shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::partition::shuffle::state_machine::StateMachine;
use assert2::let_assert;
use async_channel::{TryRecvError, TrySendError};
use restate_bifrost::bifrost;
use restate_core::cancellation_watcher;
use restate_storage_api::outbox_table::OutboxMessage;
use restate_types::dedup::DedupInformation;
use restate_types::identifiers::{LeaderEpoch, PartitionId, PartitionKey, WithPartitionKey};
Expand Down Expand Up @@ -235,7 +236,7 @@ where
HintSender::new(self.hint_tx.clone(), self.hint_rx.clone())
}

pub(super) async fn run(self, shutdown_watch: drain::Watch) -> anyhow::Result<()> {
pub(super) async fn run(self) -> anyhow::Result<()> {
let Self {
metadata,
mut hint_rx,
Expand All @@ -246,9 +247,6 @@ where

debug!(restate.node = %metadata.node_id, restate.partition.id = %metadata.partition_id, "Running shuffle");

let shutdown = shutdown_watch.signaled();
tokio::pin!(shutdown);

let node_id = metadata.node_id;
let state_machine = StateMachine::new(
metadata,
Expand Down Expand Up @@ -282,7 +280,7 @@ where
// this is just a hint which we can drop
let _ = truncation_tx.try_send(OutboxTruncation::new(shuffled_message_index));
},
_ = &mut shutdown => {
_ = cancellation_watcher() => {
break;
}
}
Expand Down

0 comments on commit c31bd06

Please sign in to comment.