diff --git a/Cargo.lock b/Cargo.lock index d846af6..2d97e92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,9 +160,9 @@ dependencies = [ [[package]] name = "aws-sdk-ec2" -version = "1.103.0" +version = "1.104.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bee0579a75067317dbddfc8bfb48ada916d4c44cb3c603af236c3512fa6c618f" +checksum = "adc49f53bf7e2f3fc23c1f37dca224d4e94b295f87eacb90a1874d56074655df" dependencies = [ "aws-credential-types", "aws-runtime", @@ -184,9 +184,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.54.0" +version = "1.55.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "921a13ed6aabe2d1258f65ef7804946255c799224440774c30e1a2c65cdf983a" +checksum = "33993c0b054f4251ff2946941b56c26b582677303eeca34087594eb901ece022" dependencies = [ "aws-credential-types", "aws-runtime", @@ -206,9 +206,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.55.0" +version = "1.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196c952738b05dfc917d82a3e9b5ba850822a6d6a86d677afda2a156cc172ceb" +checksum = "3bd3ceba74a584337a8f3839c818f14f1a2288bfd24235120ff22d7e17a0dd54" dependencies = [ "aws-credential-types", "aws-runtime", @@ -228,9 +228,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.55.0" +version = "1.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ef5b73a927ed80b44096f8c20fb4abae65469af15198367e179ae267256e9d" +checksum = "07835598e52dd354368429cb2abf447ce523ea446d0a533a63cb42cd0d2d9280" dependencies = [ "aws-credential-types", "aws-runtime", @@ -494,9 +494,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -504,9 +504,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -617,10 +617,7 @@ dependencies = [ "aws-sdk-sts", "clap", "comfy-table", - "open", - "regex", "rstest", - "semver", "serde", "serde_yaml", "tempdir", @@ -1095,33 +1092,14 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown", ] -[[package]] -name = "is-docker" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" -dependencies = [ - "once_cell", -] - -[[package]] -name = "is-wsl" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" -dependencies = [ - "is-docker", - "once_cell", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ 
-1258,17 +1236,6 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" -[[package]] -name = "open" -version = "5.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2483562e62ea94312f3576a7aca397306df7990b8d89033e18766744377ef95" -dependencies = [ - "is-wsl", - "libc", - "pathdiff", -] - [[package]] name = "openssl-probe" version = "0.1.5" @@ -1304,12 +1271,6 @@ dependencies = [ "windows-targets", ] -[[package]] -name = "pathdiff" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" - [[package]] name = "percent-encoding" version = "2.3.1" @@ -1519,9 +1480,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags", "errno", @@ -1635,9 +1596,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "serde" @@ -2043,9 +2004,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" [[package]] name = "version_check" diff --git a/Cargo.toml b/Cargo.toml index 160c8af..e485945 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,6 @@ serde_yaml = "0.9" tempdir = "0.3" toml = "0.8" comfy-table = "7.1" -regex = "1.11" -open = "5.3" -semver = "1.0" [dependencies.anyhow] version = "1.0" diff --git a/src/main.rs b/src/main.rs index 276e335..3c378dd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,15 @@ +macro_rules! 
not_available_for_byoc { + ($command:literal) => { + anyhow::bail!(concat!( + "The command `", + $command, + "` is not available for byoc configurations" + )) + }; +} + +mod ssh; + use std::{ collections::HashMap, io::{Error, ErrorKind}, @@ -16,19 +28,17 @@ mod tests; use anyhow::bail; use aws_config::{BehaviorVersion, Region}; use aws_sdk_ec2::{types::InstanceStateName, Client}; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use comfy_table::{ modifiers, presets, Attribute, Cell, CellAlignment, Color, ContentArrangement, Table, }; -use semver::{Version, VersionReq}; use serde::{Deserialize, Serialize}; use tempdir::TempDir; use tokio::{ fs, - io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, process::{Child, Command}, - time::timeout, }; +use versions::{Requirement, Versioning}; type StrRef = Arc; type PathRef = Arc; @@ -38,90 +48,80 @@ type PathRef = Arc; struct DaftLauncher { #[command(subcommand)] sub_command: SubCommand, - - /// Enable verbose printing - #[arg(short, long, action = clap::ArgAction::Count)] - verbosity: u8, } #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum SubCommand { /// Manage Daft-provisioned clusters (AWS) - Provisioned(ProvisionedCommands), + #[command(subcommand)] + Provisioned(ProvisionedCommand), + /// Manage existing clusters (Kubernetes) - Byoc(ByocCommands), + #[command(subcommand)] + Byoc(ByocCommand), + /// Manage jobs across all cluster types - Job(JobCommands), - /// Manage configurations - Config(ConfigCommands), -} + #[command(subcommand)] + Job(JobCommand), -#[derive(Debug, Parser, Clone, PartialEq, Eq)] -struct ProvisionedCommands { + /// Manage configurations #[command(subcommand)] - command: ProvisionedCommand, + Config(ConfigCommand), } #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum ProvisionedCommand { /// Create a new cluster Up(ConfigPath), + /// Stop a running cluster Down(ConfigPath), + /// Terminate a cluster Kill(ConfigPath), + /// List all clusters List(List), + /// Connect to cluster dashboard Connect(Connect), + /// SSH into cluster head node Ssh(ConfigPath), } -#[derive(Debug, Parser, Clone, PartialEq, Eq)] -struct ByocCommands { - #[command(subcommand)] - command: ByocCommand, -} - #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum ByocCommand { /// Verify connection to existing cluster Verify(ConfigPath), + /// Show cluster information Info(ConfigPath), } -#[derive(Debug, Parser, Clone, PartialEq, Eq)] -struct JobCommands { - #[command(subcommand)] - command: JobCommand, -} - #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum JobCommand { /// Submit a job to the cluster Submit(Submit), + /// Execute SQL queries Sql(Sql), + /// Check job status Status(ConfigPath), + /// View job logs Logs(ConfigPath), } -#[derive(Debug, Parser, Clone, PartialEq, Eq)] -struct ConfigCommands { - #[command(subcommand)] - command: ConfigCommand, -} - #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum ConfigCommand { /// Initialize a new configuration Init(Init), + /// Validate configuration Check(ConfigPath), + /// Export configuration to Ray format Export(ConfigPath), } @@ -202,11 +202,8 @@ struct DaftConfig { #[serde(rename_all = "kebab-case", deny_unknown_fields)] struct DaftSetup { name: StrRef, - #[serde(deserialize_with = "parse_version_req")] - version: VersionReq, - provider: DaftProvider, - #[serde(default)] - dependencies: Vec, + #[serde(deserialize_with = "parse_requirement")] + version: Requirement, #[serde(flatten)] provider_config: ProviderConfig, } @@ -214,19 +211,10 
@@ struct DaftSetup { #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] enum ProviderConfig { - #[serde(rename = "provisioned")] - Provisioned(AwsConfigWithRun), - #[serde(rename = "byoc")] + Provisioned(AwsConfig), Byoc(K8sConfig), } -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "kebab-case", deny_unknown_fields)] -struct AwsConfigWithRun { - #[serde(flatten)] - config: AwsConfig, -} - #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] struct AwsConfig { @@ -240,7 +228,10 @@ struct AwsConfig { instance_type: StrRef, #[serde(default = "default_image_id")] image_id: StrRef, + #[serde(skip_serializing_if = "Option::is_none")] iam_instance_profile_name: Option, + #[serde(default)] + dependencies: Vec, } #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] @@ -331,49 +322,37 @@ fn default_k8s_namespace() -> StrRef { "default".into() } -fn parse_version_req<'de, D>(deserializer: D) -> Result +fn parse_requirement<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let raw: StrRef = Deserialize::deserialize(deserializer)?; - let version_req = raw - .parse::() + let requirement = raw + .parse::() .map_err(serde::de::Error::custom)?; let current_version = env!("CARGO_PKG_VERSION") - .parse::() + .parse::() .expect("CARGO_PKG_VERSION must exist"); - if version_req.matches(¤t_version) { - Ok(version_req) + if requirement.matches(¤t_version) { + Ok(requirement) } else { - Err(serde::de::Error::custom(format!("You're running daft-launcher version {current_version}, but your configuration file requires version {version_req}"))) + Err(serde::de::Error::custom(format!("You're running daft-launcher version {current_version}, but your configuration file requires version {requirement}"))) } } -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "kebab-case", deny_unknown_fields)] +#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] enum DaftProvider { Provisioned, Byoc, } -impl FromStr for DaftProvider { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "provisioned" => Ok(DaftProvider::Provisioned), - "byoc" => Ok(DaftProvider::Byoc), - _ => anyhow::bail!("Invalid provider '{}'. 
Must be either 'provisioned' or 'byoc'", s), - } - } -} - impl ToString for DaftProvider { fn to_string(&self) -> String { match self { - DaftProvider::Provisioned => "provisioned".to_string(), - DaftProvider::Byoc => "byoc".to_string(), + DaftProvider::Provisioned => "provisioned", + DaftProvider::Byoc => "byoc", } + .to_string() } } @@ -437,11 +416,9 @@ struct RayResources { cpu: usize, } -async fn read_and_convert( - daft_config_path: &Path, - teardown_behaviour: Option, -) -> anyhow::Result<(DaftConfig, Option)> { - let contents = fs::read_to_string(&daft_config_path) +async fn read_daft_config(daft_config_path: impl AsRef) -> anyhow::Result { + let daft_config_path = daft_config_path.as_ref(); + let contents = fs::read_to_string(daft_config_path) .await .map_err(|error| { if let ErrorKind::NotFound = error.kind() { @@ -453,89 +430,103 @@ async fn read_and_convert( error } })?; - let daft_config = toml::from_str::(&contents)?; - - let ray_config = match &daft_config.setup.provider_config { - ProviderConfig::Byoc(_) => None, - ProviderConfig::Provisioned(aws_config) => { - let key_name = aws_config.config.ssh_private_key - .clone() - .file_stem() - .ok_or_else(|| anyhow::anyhow!(r#"Private key doesn't have a name of the format "name.ext""#))? - .to_str() - .ok_or_else(|| anyhow::anyhow!("The file {:?} does not have a valid UTF-8 name", aws_config.config.ssh_private_key))? - .into(); - - let node_config = RayNodeConfig { - key_name, - instance_type: aws_config.config.instance_type.clone(), - image_id: aws_config.config.image_id.clone(), - iam_instance_profile: aws_config.config.iam_instance_profile_name.clone().map(|name| IamInstanceProfile { name }), - }; - - Some(RayConfig { - cluster_name: daft_config.setup.name.clone(), - max_workers: aws_config.config.number_of_workers, - provider: RayProvider { - r#type: "aws".into(), - region: aws_config.config.region.clone(), - cache_stopped_nodes: teardown_behaviour.map(TeardownBehaviour::to_cache_stopped_nodes), - }, - auth: RayAuth { - ssh_user: aws_config.config.ssh_user.clone(), - ssh_private_key: aws_config.config.ssh_private_key.clone(), - }, - available_node_types: vec![ - ( - "ray.head.default".into(), - RayNodeType { - max_workers: aws_config.config.number_of_workers, - node_config: node_config.clone(), - resources: Some(RayResources { cpu: 0 }), - }, - ), - ( - "ray.worker.default".into(), - RayNodeType { - max_workers: aws_config.config.number_of_workers, - node_config, - resources: None, - }, - ), - ] - .into_iter() - .collect(), - setup_commands: { - let mut commands = vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - "uv python install 3.12".into(), - "uv python pin 3.12".into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - "uv pip install boto3 pip ray[default] getdaft py-spy deltalake".into(), - ]; - if !daft_config.setup.dependencies.is_empty() { - let deps = daft_config.setup.dependencies - .iter() - .map(|dep| format!(r#""{dep}""#)) - .collect::>() - .join(" "); - let deps = format!("uv pip install {deps}").into(); - commands.push(deps); - } - commands - }, - }) - } + Ok(daft_config) +} + +fn convert( + daft_config: &DaftConfig, + teardown_behaviour: Option, +) -> anyhow::Result { + let ProviderConfig::Provisioned(aws_config) = &daft_config.setup.provider_config else { + unreachable!("Can only convert to a ray config-file for provisioned configurations; this should be statically determined"); + }; + + let key_name = aws_config + 
.ssh_private_key + .clone() + .file_stem() + .ok_or_else(|| { + anyhow::anyhow!(r#"Private key doesn't have a name of the format "name.ext""#) + })? + .to_str() + .ok_or_else(|| { + anyhow::anyhow!( + "The file {:?} does not have a valid UTF-8 name", + aws_config.ssh_private_key + ) + })? + .into(); + + let node_config = RayNodeConfig { + key_name, + instance_type: aws_config.instance_type.clone(), + image_id: aws_config.image_id.clone(), + iam_instance_profile: aws_config + .iam_instance_profile_name + .clone() + .map(|name| IamInstanceProfile { name }), }; - Ok((daft_config, ray_config)) + Ok(RayConfig { + cluster_name: daft_config.setup.name.clone(), + max_workers: aws_config.number_of_workers, + provider: RayProvider { + r#type: "aws".into(), + region: aws_config.region.clone(), + cache_stopped_nodes: teardown_behaviour.map(TeardownBehaviour::to_cache_stopped_nodes), + }, + auth: RayAuth { + ssh_user: aws_config.ssh_user.clone(), + ssh_private_key: aws_config.ssh_private_key.clone(), + }, + available_node_types: vec![ + ( + "ray.head.default".into(), + RayNodeType { + max_workers: 0, + node_config: node_config.clone(), + resources: Some(RayResources { cpu: 0 }), + }, + ), + ( + "ray.worker.default".into(), + RayNodeType { + max_workers: aws_config.number_of_workers, + node_config, + resources: None, + }, + ), + ] + .into_iter() + .collect(), + setup_commands: { + let mut commands = vec![ + "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), + "uv python install 3.12".into(), + "uv python pin 3.12".into(), + "uv venv".into(), + "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), + "source ~/.bashrc".into(), + "uv pip install boto3 pip py-spy deltalake getdaft ray[default]".into(), + ]; + if !aws_config.dependencies.is_empty() { + let deps = aws_config + .dependencies + .iter() + .map(|dep| format!(r#""{dep}""#)) + .collect::>() + .join(" "); + let deps = format!("uv pip install {deps}").into(); + commands.push(deps); + } + commands + }, + }) } -async fn write_ray_config(ray_config: RayConfig, dest: impl AsRef) -> anyhow::Result<()> { - let ray_config = serde_yaml::to_string(&ray_config)?; +async fn write_ray_config(ray_config: &RayConfig, dest: impl AsRef) -> anyhow::Result<()> { + let ray_config = serde_yaml::to_string(ray_config)?; fs::write(dest, ray_config).await?; Ok(()) } @@ -557,14 +548,14 @@ impl SpinDirection { #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum TeardownBehaviour { - Stop, + Down, Kill, } impl TeardownBehaviour { fn to_cache_stopped_nodes(self) -> bool { match self { - Self::Stop => true, + Self::Down => true, Self::Kill => false, } } @@ -608,7 +599,7 @@ struct AwsInstance { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum NodeType { +enum NodeType { Head, Worker, } @@ -746,161 +737,7 @@ async fn assert_is_logged_in_with_aws() -> anyhow::Result<()> { } } -async fn get_region(region: Option, config: impl AsRef) -> anyhow::Result { - let config = config.as_ref(); - Ok(if let Some(region) = region { - region - } else if config.exists() { - let (daft_config, _) = read_and_convert(&config, None).await?; - match &daft_config.setup.provider_config { - ProviderConfig::Provisioned(aws_config) => aws_config.config.region.clone(), - ProviderConfig::Byoc(_) => "us-west-2".into(), - } - } else { - "us-west-2".into() - }) -} - -async fn get_head_node_ip(ray_path: impl AsRef) -> anyhow::Result { - let mut ray_command = Command::new("ray") - .arg("get-head-ip") - .arg(ray_path.as_ref()) - .stdout(Stdio::piped()) - .spawn()?; - - let mut tail_command = 
Command::new("tail") - .args(["-n", "1"]) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .spawn()?; - - let mut writer = tail_command.stdin.take().expect("stdin must exist"); - - tokio::spawn(async move { - let mut reader = BufReader::new(ray_command.stdout.take().expect("stdout must exist")); - let mut buffer = Vec::new(); - reader.read_to_end(&mut buffer).await?; - writer.write_all(&buffer).await?; - Ok::<_, anyhow::Error>(()) - }); - let output = tail_command.wait_with_output().await?; - if !output.status.success() { - anyhow::bail!("Failed to fetch ip address of head node"); - }; - let addr = String::from_utf8_lossy(&output.stdout) - .trim() - .parse::()?; - Ok(addr) -} - -async fn ssh(ray_path: impl AsRef, aws_config: &AwsConfig) -> anyhow::Result<()> { - let addr = get_head_node_ip(ray_path).await?; - let exit_status = Command::new("ssh") - .arg("-i") - .arg(aws_config.ssh_private_key.as_ref()) - .arg(format!("{}@{}", aws_config.ssh_user, addr)) - .kill_on_drop(true) - .spawn()? - .wait() - .await?; - - if exit_status.success() { - Ok(()) - } else { - Err(anyhow::anyhow!("Failed to ssh into the ray cluster")) - } -} - -async fn establish_ssh_portforward( - ray_path: impl AsRef, - aws_config: &AwsConfig, - port: Option, -) -> anyhow::Result { - let addr = get_head_node_ip(ray_path).await?; - let port = port.unwrap_or(8265); - let mut child = Command::new("ssh") - .arg("-N") - .arg("-i") - .arg(aws_config.ssh_private_key.as_ref()) - .arg("-L") - .arg(format!("{port}:localhost:8265")) - .arg(format!("{}@{}", aws_config.ssh_user, addr)) - .arg("-v") - .stderr(Stdio::piped()) - .kill_on_drop(true) - .spawn()?; - - // We wait for the ssh port-forwarding process to write a specific string to the - // output. - // - // This is a little hacky (and maybe even incorrect across platforms) since we - // are just parsing the output and observing if a specific string has been - // printed. It may be incorrect across platforms because the SSH standard - // does *not* specify a standard "success-message" to printout if the ssh - // port-forward was successful. - timeout(Duration::from_secs(5), { - let stderr = child.stderr.take().expect("stderr must exist"); - async move { - let mut lines = BufReader::new(stderr).lines(); - loop { - let Some(line) = lines.next_line().await? else { - anyhow::bail!("Failed to establish ssh port-forward to {addr}"); - }; - if line.starts_with(format!("Authenticated to {addr}").as_str()) { - break Ok(()); - } - } - } - }) - .await - .map_err(|_| anyhow::anyhow!("Establishing an ssh port-forward to {addr} timed out"))??; - - Ok(child) -} - -struct PortForward { - process: Child, -} - -impl Drop for PortForward { - fn drop(&mut self) { - let _ = self.process.start_kill(); - } -} - -async fn submit_k8s( - working_dir: &Path, - command_segments: impl AsRef<[&str]>, - namespace: &str, -) -> anyhow::Result<()> { - let command_segments = command_segments.as_ref(); - - // Start port forwarding - it will be automatically killed when _port_forward is dropped - let _port_forward = establish_kubernetes_port_forward(namespace).await?; - - // Give the port-forward a moment to fully establish - tokio::time::sleep(Duration::from_secs(1)).await; - - // Submit the job - let exit_status = Command::new("ray") - .env("PYTHONUNBUFFERED", "1") - .args(["job", "submit", "--address", "http://localhost:8265"]) - .arg("--working-dir") - .arg(working_dir) - .arg("--") - .args(command_segments) - .spawn()? 
- .wait() - .await?; - - if exit_status.success() { - Ok(()) - } else { - Err(anyhow::anyhow!("Failed to submit job to the ray cluster")) - } -} - -async fn establish_kubernetes_port_forward(namespace: &str) -> anyhow::Result { +async fn establish_kubernetes_port_forward(namespace: &str) -> anyhow::Result { let output = Command::new("kubectl") .arg("get") .arg("svc") @@ -911,22 +748,29 @@ async fn establish_kubernetes_port_forward(namespace: &str) -> anyhow::Result anyhow::Result anyhow::Result { - return Err(anyhow::anyhow!( + anyhow::bail!( "Port-forward process exited immediately with status: {}", status - )); + ); } None => { println!("Port-forwarding started successfully"); - Ok(PortForward { - process: port_forward, - }) + Ok(port_forward) } } } -async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> { - match daft_launcher.sub_command { - SubCommand::Config(config_cmd) => { - config_cmd.command.run(daft_launcher.verbosity).await - } - SubCommand::Job(job_cmd) => { - job_cmd.command.run(daft_launcher.verbosity).await - } - SubCommand::Provisioned(provisioned_cmd) => { - provisioned_cmd.command.run(daft_launcher.verbosity).await - } - SubCommand::Byoc(byoc_cmd) => { - byoc_cmd.command.run(daft_launcher.verbosity).await - } +async fn submit( + working_dir: impl AsRef, + command_segments: impl AsRef<[&str]>, +) -> anyhow::Result<()> { + let exit_status = Command::new("ray") + .env("PYTHONUNBUFFERED", "1") + .args(["job", "submit", "--address", "http://localhost:8265"]) + .arg("--working-dir") + .arg(working_dir.as_ref()) + .arg("--") + .args(command_segments.as_ref()) + .spawn()? + .wait() + .await?; + + if exit_status.success() { + Ok(()) + } else { + Err(anyhow::anyhow!("Failed to submit job to the ray cluster")) } } +async fn submit_k8s( + working_dir: impl AsRef, + command_segments: impl AsRef<[&str]>, + namespace: &str, +) -> anyhow::Result<()> { + // Start port forwarding - it will be automatically killed when _port_forward is + // dropped + let _port_forward = establish_kubernetes_port_forward(namespace).await?; + + // Give the port-forward a moment to fully establish + tokio::time::sleep(Duration::from_secs(1)).await; + + submit(working_dir, command_segments).await?; + + Ok(()) +} + #[tokio::main] async fn main() -> anyhow::Result<()> { - run(DaftLauncher::parse()).await + DaftLauncher::parse().run().await } -// Helper function to get AWS config -fn get_aws_config(config: &DaftConfig) -> anyhow::Result<&AwsConfig> { - match &config.setup.provider_config { - ProviderConfig::Provisioned(aws_config) => Ok(&aws_config.config), - ProviderConfig::Byoc(_) => anyhow::bail!("Expected provisioned configuration but found Kubernetes configuration"), +impl DaftLauncher { + async fn run(&self) -> anyhow::Result<()> { + match &self.sub_command { + SubCommand::Config(config_cmd) => config_cmd.run().await, + SubCommand::Job(job_cmd) => job_cmd.run().await, + SubCommand::Provisioned(provisioned_cmd) => provisioned_cmd.run().await, + SubCommand::Byoc(byoc_cmd) => byoc_cmd.run().await, + } } } impl ConfigCommand { - async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { match self { ConfigCommand::Init(Init { path, provider }) => { #[cfg(not(test))] @@ -1002,18 +869,15 @@ impl ConfigCommand { DaftProvider::Byoc => include_str!("template_byoc.toml"), DaftProvider::Provisioned => include_str!("template_provisioned.toml"), } - .replace("", env!("CARGO_PKG_VERSION")); + .replace("", concat!("=", env!("CARGO_PKG_VERSION"))); fs::write(path, 
contents).await?; } ConfigCommand::Check(ConfigPath { config }) => { - let _ = read_and_convert(&config, None).await?; + let _ = read_daft_config(config).await?; } ConfigCommand::Export(ConfigPath { config }) => { - let (_, ray_config) = read_and_convert(&config, None).await?; - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + let ray_config = convert(&daft_config, None)?; let ray_config_str = serde_yaml::to_string(&ray_config)?; println!("{ray_config_str}"); } @@ -1023,202 +887,164 @@ impl ConfigCommand { } impl JobCommand { - async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { match self { - JobCommand::Submit(Submit { config_path, job_name }) => { - let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - let daft_job = daft_config - .jobs - .get(job_name) - .ok_or_else(|| anyhow::anyhow!("A job with the name {job_name} was not found"))?; + JobCommand::Submit(Submit { + config_path, + job_name, + }) => { + let daft_config = read_daft_config(&config_path.config).await?; + let daft_job = daft_config.jobs.get(job_name).ok_or_else(|| { + anyhow::anyhow!("A job with the name {job_name} was not found") + })?; + + let working_dir = daft_job.working_dir.as_ref(); + let command_segments = daft_job.command.as_ref().split(' ').collect::>(); match &daft_config.setup.provider_config { - ProviderConfig::Provisioned(_) => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + ProviderConfig::Provisioned(aws_config) => { + assert_is_logged_in_with_aws().await?; + + let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - - let aws_config = get_aws_config(&daft_config)?; - // Start port forwarding - it will be automatically killed when _port_forward is dropped - let _port_forward = establish_ssh_portforward(ray_path, aws_config, Some(8265)).await?; - - // Give the port-forward a moment to fully establish - tokio::time::sleep(Duration::from_secs(1)).await; - - // Submit the job - let exit_status = Command::new("ray") - .env("PYTHONUNBUFFERED", "1") - .args(["job", "submit", "--address", "http://localhost:8265"]) - .arg("--working-dir") - .arg(daft_job.working_dir.as_ref()) - .arg("--") - .args(daft_job.command.as_ref().split(' ').collect::>()) - .spawn()? 
- .wait() - .await?; + write_ray_config(&ray_config, &ray_path).await?; - if !exit_status.success() { - anyhow::bail!("Failed to submit job to the ray cluster"); - } + let _child = ssh::ssh_portforward(ray_path, aws_config, None).await?; + submit(working_dir, command_segments).await?; } ProviderConfig::Byoc(k8s_config) => { - submit_k8s( - daft_job.working_dir.as_ref(), - daft_job.command.as_ref().split(' ').collect::>(), - k8s_config.namespace.as_ref(), - ) - .await?; + submit_k8s(working_dir, command_segments, k8s_config.namespace.as_ref()) + .await?; } } } JobCommand::Sql(Sql { sql, config_path }) => { - let (daft_config, _) = read_and_convert(&config_path.config, None).await?; + let daft_config = read_daft_config(&config_path.config).await?; + let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; + fs::write(sql_path, include_str!("sql.py")).await?; + + let working_dir = temp_sql_dir.path(); + let command_segments = vec!["python", "sql.py", sql.as_ref()]; + match &daft_config.setup.provider_config { - ProviderConfig::Provisioned(_) => { - anyhow::bail!("'sql' command is only available for BYOC configurations"); + ProviderConfig::Provisioned(aws_config) => { + assert_is_logged_in_with_aws().await?; + + let ray_config = convert(&daft_config, None)?; + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(&ray_config, &ray_path).await?; + + let _child = ssh::ssh_portforward(ray_path, aws_config, None).await?; + submit(working_dir, command_segments).await?; } ProviderConfig::Byoc(k8s_config) => { - let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; - fs::write(sql_path, include_str!("sql.py")).await?; - submit_k8s( - temp_sql_dir.path(), - vec!["python", "sql.py", sql.as_ref()], - k8s_config.namespace.as_ref(), - ) - .await?; + submit_k8s(working_dir, command_segments, k8s_config.namespace.as_ref()) + .await?; } } } - JobCommand::Status(_) => { - anyhow::bail!("Job status command not yet implemented"); - } - JobCommand::Logs(_) => { - anyhow::bail!("Job logs command not yet implemented"); - } + JobCommand::Status(..) => todo!(), + JobCommand::Logs(..) => todo!(), } Ok(()) } } impl ProvisionedCommand { - async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { match self { ProvisionedCommand::Up(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, None).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + match daft_config.setup.provider_config { + ProviderConfig::Provisioned(..) => { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; + write_ray_config(&ray_config, &ray_path).await?; run_ray_up_or_down_command(SpinDirection::Up, ray_path).await?; } - DaftProvider::Byoc => { - anyhow::bail!("'up' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) 
=> not_available_for_byoc!("up"), } } ProvisionedCommand::Down(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + match daft_config.setup.provider_config { + ProviderConfig::Provisioned(..) => { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, Some(TeardownBehaviour::Down))?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; + write_ray_config(&ray_config, &ray_path).await?; run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; } - DaftProvider::Byoc => { - anyhow::bail!("'down' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) => not_available_for_byoc!("down"), } } ProvisionedCommand::Kill(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Kill)).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + match daft_config.setup.provider_config { + ProviderConfig::Provisioned(..) => { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, Some(TeardownBehaviour::Kill))?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; + write_ray_config(&ray_config, &ray_path).await?; run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; } - DaftProvider::Byoc => { - anyhow::bail!("'kill' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) => not_available_for_byoc!("kill"), } } - ProvisionedCommand::List(List { config_path, region, head, running }) => { - let (daft_config, _) = read_and_convert(&config_path.config, None).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { + ProvisionedCommand::List(List { + config_path, + region, + head, + running, + }) => { + let daft_config = read_daft_config(&config_path.config).await?; + match &daft_config.setup.provider_config { + ProviderConfig::Provisioned(aws_config) => { assert_is_logged_in_with_aws().await?; - let aws_config = get_aws_config(&daft_config)?; + let region = region.as_ref().unwrap_or_else(|| &aws_config.region); let instances = get_ray_clusters_from_aws(region.clone()).await?; print_instances(&instances, *head, *running); } - DaftProvider::Byoc => { - anyhow::bail!("'list' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) 
=> not_available_for_byoc!("list"), } } - ProvisionedCommand::Connect(Connect { port, config_path }) => { - let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + &ProvisionedCommand::Connect(Connect { + port, + ref config_path, + }) => { + let daft_config = read_daft_config(&config_path.config).await?; + match &daft_config.setup.provider_config { + ProviderConfig::Provisioned(aws_config) => { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let aws_config = get_aws_config(&daft_config)?; - let _ = establish_ssh_portforward(ray_path, aws_config, Some(*port)) + write_ray_config(&ray_config, &ray_path).await?; + let _ = ssh::ssh_portforward(ray_path, aws_config, Some(port)) .await? .wait_with_output() .await?; } - DaftProvider::Byoc => { - anyhow::bail!("'connect' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) => not_available_for_byoc!("connect"), } } ProvisionedCommand::Ssh(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, None).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + match &daft_config.setup.provider_config { + ProviderConfig::Provisioned(aws_config) => { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let aws_config = get_aws_config(&daft_config)?; - ssh(ray_path, aws_config).await?; - } - DaftProvider::Byoc => { - anyhow::bail!("'ssh' command is only available for provisioned configurations"); + write_ray_config(&ray_config, &ray_path).await?; + ssh::ssh(ray_path, aws_config).await?; } + ProviderConfig::Byoc(..) => not_available_for_byoc!("ssh"), } } } @@ -1227,15 +1053,10 @@ impl ProvisionedCommand { } impl ByocCommand { - async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { match self { - ByocCommand::Verify(ConfigPath { config: _ }) => { - anyhow::bail!("Verify command not yet implemented"); - } - ByocCommand::Info(ConfigPath { config: _ }) => { - anyhow::bail!("Info command not yet implemented"); - } + ByocCommand::Verify(..) => todo!(), + ByocCommand::Info(..) 
=> todo!(), } - Ok(()) } } diff --git a/src/ssh.rs b/src/ssh.rs index e02a674..6e64d46 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -6,7 +6,7 @@ use tokio::{ time::timeout, }; -use crate::DaftConfig; +use crate::AwsConfig; async fn get_head_node_ip(ray_path: impl AsRef) -> anyhow::Result { let mut ray_command = Command::new("ray") @@ -42,18 +42,18 @@ async fn get_head_node_ip(ray_path: impl AsRef) -> anyhow::Result, - daft_config: &DaftConfig, + aws_config: &AwsConfig, portforward: Option, verbose: bool, ) -> anyhow::Result<(Ipv4Addr, Command)> { - let user = daft_config.setup.ssh_user.as_ref(); + let user = aws_config.ssh_user.as_ref(); let addr = get_head_node_ip(ray_path).await?; let mut command = Command::new("ssh"); command .arg("-i") - .arg(daft_config.setup.ssh_private_key.as_ref()) + .arg(aws_config.ssh_private_key.as_ref()) .arg("-o") .arg("StrictHostKeyChecking=no"); @@ -73,24 +73,26 @@ async fn generate_ssh_command( Ok((addr, command)) } -pub async fn ssh(ray_path: impl AsRef, daft_config: &DaftConfig) -> anyhow::Result<()> { - let (_, mut command) = generate_ssh_command(ray_path, daft_config, None, false).await?; +pub async fn ssh(ray_path: impl AsRef, aws_config: &AwsConfig) -> anyhow::Result<()> { + let (addr, mut command) = generate_ssh_command(ray_path, aws_config, None, false).await?; let exit_status = command.spawn()?.wait().await?; if exit_status.success() { Ok(()) } else { - Err(anyhow::anyhow!("Failed to ssh into the ray cluster")) + Err(anyhow::anyhow!( + "Failed to ssh into the ray cluster at address {addr}" + )) } } pub async fn ssh_portforward( ray_path: impl AsRef, - daft_config: &DaftConfig, + aws_config: &AwsConfig, portforward: Option, ) -> anyhow::Result { let (addr, mut command) = generate_ssh_command( ray_path, - daft_config, + aws_config, Some(portforward.unwrap_or(8265)), true, ) diff --git a/src/template_byoc.toml b/src/template_byoc.toml index e70adc0..38f1564 100644 --- a/src/template_byoc.toml +++ b/src/template_byoc.toml @@ -1,15 +1,14 @@ -# This is a template configuration file for daft-launcher with BYOC provider +# This is a template configuration file for daft-launcher with a BYOC provider + [setup] name = "my-daft-cluster" version = "" -provider = "byoc" -# TODO: support dependencies [setup.byoc] -namespace = "default" # Optional, defaults to "default" +namespace = "default" # Optional, defaults to "default" # Job definitions [[job]] name = "example-job" command = "python my_script.py" -working-dir = "~/my_project" \ No newline at end of file +working-dir = "~/my_project" diff --git a/src/template_provisioned.toml b/src/template_provisioned.toml index 4299fbf..45cc7cb 100644 --- a/src/template_provisioned.toml +++ b/src/template_provisioned.toml @@ -1,9 +1,8 @@ -# This is a template configuration file for daft-launcher with provisioned provider +# This is a template configuration file for daft-launcher with a provisioned provider + [setup] name = "my-daft-cluster" version = "" -provider = "provisioned" -dependencies = [] # Optional additional Python packages to install # Provisioned (AWS) configuration [setup.provisioned] @@ -13,10 +12,11 @@ ssh-user = "ubuntu" ssh-private-key = "~/.ssh/id_rsa" instance-type = "i3.2xlarge" image-id = "ami-04dd23e62ed049936" -iam-instance-profile-name = "YourInstanceProfileName" # Optional +iam-instance-profile-name = "YourInstanceProfileName" # Optional +dependencies = [] # Optional additional Python packages to install # Job definitions [[job]] name = "example-job" command = "python my_script.py" -working-dir = 
"~/my_project" \ No newline at end of file +working-dir = "~/my_project" diff --git a/src/tests.rs b/src/tests.rs index bcb836d..8aa3e14 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,234 +1,176 @@ -// use std::io::ErrorKind; -// use tempdir::TempDir; -// use tokio::fs; - -// use super::*; - -// fn not_found_okay(result: std::io::Result<()>) -> std::io::Result<()> { -// match result { -// Ok(()) => Ok(()), -// Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), -// Err(err) => Err(err), -// } -// } - -// async fn get_path() -> (TempDir, PathBuf) { -// let (temp_dir, path) = create_temp_file(".test.toml").unwrap(); -// not_found_okay(fs::remove_file(path.as_ref()).await).unwrap(); -// not_found_okay(fs::remove_dir_all(path.as_ref()).await).unwrap(); -// (temp_dir, PathBuf::from(path.as_ref())) -// } - -// /// This tests the creation of a daft-launcher configuration file. -// /// -// /// # Note -// /// This does *not* check the contents of the newly created configuration file. -// /// The reason is because we perform some minor templatization of the -// /// `template.toml` file before writing it. Thus, the outputted configuration -// /// file does not *exactly* match the original `template.toml` file. -// #[tokio::test] -// async fn test_init() { -// let (_temp_dir, path) = get_path().await; - -// run(DaftLauncher { -// sub_command: SubCommand::Config(ConfigCommands { -// command: ConfigCommand::Init(Init { -// path: path.clone(), -// provider: DaftProvider::Provisioned, -// }), -// }), -// verbosity: 0, -// }) -// .await -// .unwrap(); - -// assert!(path.exists()); -// assert!(path.is_file()); -// } - -// /// Tests to make sure that `daft check` properly asserts the schema of the -// /// newly created daft-launcher configuration file. -// #[tokio::test] -// async fn test_check() { -// let (_temp_dir, path) = get_path().await; - -// run(DaftLauncher { -// sub_command: SubCommand::Config(ConfigCommands { -// command: ConfigCommand::Init(Init { -// path: path.clone(), -// provider: DaftProvider::Provisioned, -// }), -// }), -// verbosity: 0, -// }) -// .await -// .unwrap(); - -// run(DaftLauncher { -// sub_command: SubCommand::Config(ConfigCommands { -// command: ConfigCommand::Check(ConfigPath { config: path }), -// }), -// verbosity: 0, -// }) -// .await -// .unwrap(); -// } - -// /// This tests the core conversion functionality, from a `DaftConfig` to a -// /// `RayConfig`. -// /// -// /// # Note -// /// Fields which expect a filesystem path (i.e., "ssh_private_key" and -// /// "job.working_dir") are not checked for existence. Therefore, you can really -// /// put any value in there and this test will pass. -// /// -// /// This is because the point of this test is not to check for existence, but -// /// rather to test the mapping from `DaftConfig` to `RayConfig`. 
-// #[rstest::rstest] -// #[case(simple_config())] -// fn test_conversion( -// #[case] (daft_config, teardown_behaviour, expected): ( -// DaftConfig, -// Option, -// RayConfig, -// ), -// ) { -// let actual = convert(&daft_config, teardown_behaviour).unwrap(); -// assert_eq!(actual, expected); -// } - -// #[rstest::rstest] -// #[case("3.9".parse().unwrap(), "2.34".parse().unwrap(), vec![], vec![ -// "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), -// "uv python install 3.9".into(), -// "uv python pin 3.9".into(), -// "uv venv".into(), -// "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), -// "source ~/.bashrc".into(), -// r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), -// ])] -// #[case("3.9".parse().unwrap(), "2.34".parse().unwrap(), vec!["requests==0.0.0".into()], vec![ -// "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), -// "uv python install 3.9".into(), -// "uv python pin 3.9".into(), -// "uv venv".into(), -// "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), -// "source ~/.bashrc".into(), -// r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), -// r#"uv pip install "requests==0.0.0""#.into(), -// ])] -// fn test_generate_setup_commands( -// #[case] python_version: Versioning, -// #[case] ray_version: Versioning, -// #[case] dependencies: Vec, -// #[case] expected: Vec, -// ) { -// let actual = generate_setup_commands(python_version, ray_version, dependencies.as_slice()); -// assert_eq!(actual, expected); -// } - -// #[rstest::fixture] -// pub fn simple_config() -> (DaftConfig, Option, RayConfig) { -// let test_name: StrRef = "test".into(); -// let ssh_private_key: PathRef = Arc::from(PathBuf::from("testkey.pem")); -// let number_of_workers = 4; -// let daft_config = DaftConfig { -// setup: DaftSetup { -// name: test_name.clone(), -// version: "=1.2.3".parse().unwrap(), -// provider: DaftProvider::Provisioned, -// dependencies: vec![], -// provider_config: ProviderConfig::Provisioned(AwsConfigWithRun { -// config: AwsConfig { -// region: test_name.clone(), -// number_of_workers, -// ssh_user: test_name.clone(), -// ssh_private_key: ssh_private_key.clone(), -// instance_type: test_name.clone(), -// image_id: test_name.clone(), -// iam_instance_profile_name: Some(test_name.clone()), -// }, -// }), -// }, -// jobs: HashMap::default(), -// }; -// let node_config = RayNodeConfig { -// key_name: "testkey".into(), -// instance_type: test_name.clone(), -// image_id: test_name.clone(), -// iam_instance_profile: Some(IamInstanceProfile { -// name: test_name.clone(), -// }), -// }; - -// let ray_config = RayConfig { -// cluster_name: test_name.clone(), -// max_workers: number_of_workers, -// provider: RayProvider { -// r#type: "aws".into(), -// region: test_name.clone(), -// cache_stopped_nodes: None, -// }, -// auth: RayAuth { -// ssh_user: test_name.clone(), -// ssh_private_key, -// }, -// available_node_types: vec![ -// ( -// "ray.head.default".into(), -// RayNodeType { -// max_workers: 0, -// node_config: node_config.clone(), -// resources: Some(RayResources { cpu: 0 }), -// }, -// ), -// ( -// "ray.worker.default".into(), -// RayNodeType { -// max_workers: number_of_workers, -// node_config, -// resources: None, -// }, -// ), -// ] -// .into_iter() -// .collect(), -// setup_commands: vec![ -// "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), -// "uv python install 3.12".into(), -// "uv python pin 3.12".into(), -// "uv venv".into(), -// "echo 'source 
$HOME/.venv/bin/activate' >> ~/.bashrc".into(), -// "source ~/.bashrc".into(), -// r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), -// ], -// }; - -// (daft_config, None, ray_config) -// } - -// #[tokio::test] -// async fn test_init_and_export() { -// run(DaftLauncher { -// sub_command: SubCommand::Config(ConfigCommands { -// command: ConfigCommand::Init(Init { -// path: ".daft.toml".into(), -// provider: DaftProvider::Provisioned, -// }), -// }), -// verbosity: 0, -// }) -// .await -// .unwrap(); - -// run(DaftLauncher { -// sub_command: SubCommand::Config(ConfigCommands { -// command: ConfigCommand::Check(ConfigPath { -// config: ".daft.toml".into(), -// }), -// }), -// verbosity: 0, -// }) -// .await -// .unwrap(); -// } +use std::io::ErrorKind; + +use tempdir::TempDir; +use tokio::fs; + +use super::*; + +fn not_found_okay(result: std::io::Result<()>) -> std::io::Result<()> { + match result { + Ok(()) => Ok(()), + Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), + Err(err) => Err(err), + } +} + +async fn get_path() -> (TempDir, PathBuf) { + let (temp_dir, path) = create_temp_file(".test.toml").unwrap(); + not_found_okay(fs::remove_file(path.as_ref()).await).unwrap(); + not_found_okay(fs::remove_dir_all(path.as_ref()).await).unwrap(); + (temp_dir, PathBuf::from(path.as_ref())) +} + +/// This tests the creation of a daft-launcher configuration file. +/// +/// # Note +/// This does *not* check the contents of the newly created configuration file. +/// The reason is because we perform some minor templatization of the +/// `template.toml` file before writing it. Thus, the outputted configuration +/// file does not *exactly* match the original `template.toml` file. +#[tokio::test] +#[rstest::rstest] +#[case(DaftProvider::Provisioned)] +#[case(DaftProvider::Byoc)] +async fn test_init(#[case] provider: DaftProvider) { + let (_temp_dir, path) = get_path().await; + + DaftLauncher { + sub_command: SubCommand::Config(ConfigCommand::Init(Init { + path: path.clone(), + provider, + })), + } + .run() + .await + .unwrap(); + + assert!(path.exists()); + assert!(path.is_file()); +} + +/// Tests to make sure that `daft check` properly asserts the schema of the +/// newly created daft-launcher configuration file. +#[tokio::test] +#[rstest::rstest] +#[case(DaftProvider::Provisioned)] +#[case(DaftProvider::Byoc)] +async fn test_check(#[case] provider: DaftProvider) { + let (_temp_dir, path) = get_path().await; + + DaftLauncher { + sub_command: SubCommand::Config(ConfigCommand::Init(Init { + path: path.clone(), + provider, + })), + } + .run() + .await + .unwrap(); + + DaftLauncher { + sub_command: SubCommand::Config(ConfigCommand::Check(ConfigPath { config: path })), + } + .run() + .await + .unwrap(); +} + +/// This tests the core conversion functionality, from a `DaftConfig` to a +/// `RayConfig`. +/// +/// # Note +/// Fields which expect a filesystem path (i.e., "ssh_private_key" and +/// "job.working_dir") are not checked for existence. Therefore, you can really +/// put any value in there and this test will pass. +/// +/// This is because the point of this test is not to check for existence, but +/// rather to test the mapping from `DaftConfig` to `RayConfig`. 
+#[rstest::rstest] +#[case(simple_config())] +fn test_conversion( + #[case] (daft_config, teardown_behaviour, expected): ( + DaftConfig, + Option, + RayConfig, + ), +) { + let actual = convert(&daft_config, teardown_behaviour).unwrap(); + assert_eq!(actual, expected); +} + +#[rstest::fixture] +pub fn simple_config() -> (DaftConfig, Option, RayConfig) { + let test_name: StrRef = "test".into(); + let ssh_private_key: PathRef = Arc::from(PathBuf::from("testkey.pem")); + let number_of_workers = 4; + let daft_config = DaftConfig { + setup: DaftSetup { + name: test_name.clone(), + version: "=1.2.3".parse().unwrap(), + provider_config: ProviderConfig::Provisioned(AwsConfig { + region: test_name.clone(), + number_of_workers, + ssh_user: test_name.clone(), + ssh_private_key: ssh_private_key.clone(), + instance_type: test_name.clone(), + image_id: test_name.clone(), + iam_instance_profile_name: Some(test_name.clone()), + dependencies: vec![], + }), + }, + jobs: HashMap::default(), + }; + let node_config = RayNodeConfig { + key_name: "testkey".into(), + instance_type: test_name.clone(), + image_id: test_name.clone(), + iam_instance_profile: Some(IamInstanceProfile { + name: test_name.clone(), + }), + }; + + let ray_config = RayConfig { + cluster_name: test_name.clone(), + max_workers: number_of_workers, + provider: RayProvider { + r#type: "aws".into(), + region: test_name.clone(), + cache_stopped_nodes: None, + }, + auth: RayAuth { + ssh_user: test_name.clone(), + ssh_private_key, + }, + available_node_types: vec![ + ( + "ray.head.default".into(), + RayNodeType { + max_workers: 0, + node_config: node_config.clone(), + resources: Some(RayResources { cpu: 0 }), + }, + ), + ( + "ray.worker.default".into(), + RayNodeType { + max_workers: number_of_workers, + node_config, + resources: None, + }, + ), + ] + .into_iter() + .collect(), + setup_commands: vec![ + "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), + "uv python install 3.12".into(), + "uv python pin 3.12".into(), + "uv venv".into(), + "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), + "source ~/.bashrc".into(), + "uv pip install boto3 pip py-spy deltalake getdaft ray[default]".into(), + ], + }; + + (daft_config, None, ray_config) +}
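
For reference, a minimal sketch of the version check that `parse_requirement` now performs with the `versions` crate (replacing `semver`): the config's `version` field is parsed as a `Requirement`, the running launcher's `CARGO_PKG_VERSION` as a `Versioning`, and the two are compared with `Requirement::matches`. The helper name `version_pin_is_satisfied` is hypothetical; only the parsing and `matches` calls are taken from the code above.

use versions::{Requirement, Versioning};

fn version_pin_is_satisfied(pin: &str, launcher_version: &str) -> bool {
    // `pin` is the `version` value from `.daft.toml`, e.g. "=1.2.3";
    // `launcher_version` is the binary's own version, i.e. env!("CARGO_PKG_VERSION").
    let requirement = pin
        .parse::<Requirement>()
        .expect("the `version` field must be a valid version requirement");
    let current = launcher_version
        .parse::<Versioning>()
        .expect("CARGO_PKG_VERSION must be a valid version");
    requirement.matches(&current)
}

fn main() {
    // `daft config init` now writes an exact pin ("=" + CARGO_PKG_VERSION) into the
    // template, so a freshly generated config always accepts the binary that wrote it.
    assert!(version_pin_is_satisfied(
        concat!("=", env!("CARGO_PKG_VERSION")),
        env!("CARGO_PKG_VERSION"),
    ));
}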