From 7ad1af0cbb1e1343b0e5dc05d12eab8049ee784f Mon Sep 17 00:00:00 2001 From: Jessie Young Date: Fri, 17 Jan 2025 14:28:56 -0800 Subject: [PATCH 01/15] Instructions to install Ray and Daft on an existing Kubernetes cluster, and support BYOC k8s clusters in daft-launcher Added docs for kuberay + daft installation, fixed minor linter issue --- README.md | 241 ++++++-- docs/kubernetes/README.md | 18 + docs/kubernetes/cloud.md | 50 ++ docs/kubernetes/local.md | 127 ++++ docs/kubernetes/on-prem.md | 33 ++ docs/kubernetes/ray-installation.md | 100 ++++ examples/hello_daft.py | 27 + src/main.rs | 875 +++++++++++++++++----------- src/template.toml | 46 +- src/template_k8s.toml | 19 + template.toml | 17 + 11 files changed, 1154 insertions(+), 399 deletions(-) create mode 100644 docs/kubernetes/README.md create mode 100644 docs/kubernetes/cloud.md create mode 100644 docs/kubernetes/local.md create mode 100644 docs/kubernetes/on-prem.md create mode 100644 docs/kubernetes/ray-installation.md create mode 100644 examples/hello_daft.py create mode 100644 src/template_k8s.toml create mode 100644 template.toml diff --git a/README.md b/README.md index aaa7c05..587d60d 100644 --- a/README.md +++ b/README.md @@ -10,45 +10,120 @@ [![Latest](https://img.shields.io/github/v/tag/Eventual-Inc/daft-launcher?label=latest&logo=GitHub)](https://github.com/Eventual-Inc/daft-launcher/tags) [![License](https://img.shields.io/badge/daft_launcher-docs-red.svg)](https://eventual-inc.github.io/daft-launcher) -# Daft Launcher +# Daft Launcher CLI Tool `daft-launcher` is a simple launcher for spinning up and managing Ray clusters for [`daft`](https://github.com/Eventual-Inc/Daft). -It abstracts away all the complexities of dealing with Ray yourself, allowing you to focus on running `daft` in a distributed manner. + +## Goal + +Getting started with Daft in a local environment is easy. +However, getting started with Daft in a cloud environment is substantially more difficult. 
+So much more difficult, in fact, that users end up spending more time setting up their environment than actually playing with our query engine. + +Daft Launcher aims to solve this problem by providing a simple CLI tool to remove all of this unnecessary heavy-lifting. ## Capabilities -1. Spinning up clusters. -2. Listing all available clusters (as well as their statuses). -3. Submitting jobs to a cluster. -4. Connecting to the cluster (to view the Ray dashboard and submit jobs using the Ray protocol). -5. Spinning down clusters. -6. Creating configuration files. -7. Running raw SQL statements using Daft's SQL API. +What Daft Launcher is capable of: +1. Spinning up clusters (AWS only) +2. Listing all available clusters as well as their statuses (AWS only) +3. Submitting jobs to a cluster (AWS and Kubernetes) +4. Connecting to the cluster (AWS only, Kubernetes coming soon) +5. Spinning down clusters (AWS only) +6. Creating configuration files (AWS and Kubernetes) +7. Running raw SQL statements (AWS only, Kubernetes coming soon) ## Currently supported cloud providers -- [x] AWS -- [ ] GCP -- [ ] Azure +Daft Launcher supports two modes of operation: +- **AWS**: For automatically provisioning and managing Ray clusters in AWS +- **Kubernetes**: For connecting to existing Ray clusters in Kubernetes -## Usage +### Command Support Matrix -You'll need a python package manager installed. -We highly recommend using [`uv`](https://astral.sh/blog/uv) for all things python! +| Command | AWS | Kubernetes | +|----------|-----|------------| +| init | ✅ | ✅ | +| up | ✅ | ❌ | +| submit | ✅ | ✅ | +| stop | ✅ | ❌ | +| kill | ✅ | ❌ | +| list | ✅ | ❌ | +| connect | ✅ | ❌ | +| ssh | ✅ | ❌ | +| sql | ✅ | ❌ | -### AWS +## Usage -If you're using AWS, you'll need: -1. A valid AWS account with the necessary IAM role to spin up EC2 instances. - This IAM role can either be created by you (assuming you have the appropriate permissions). 
- Or this IAM role will need to be created by your administrator. -2. The [AWS CLI](https://aws.amazon.com/cli) installed and configured on your machine. -3. To login using the AWS CLI. - For full instructions, please look [here](https://google.com). +### Pre-requisites -## Installation +You'll need some python package manager installed. +We recommend using [`uv`](https://astral.sh/blog/uv) for all things python. -Using `uv` (recommended): +#### For AWS Provider +1. A valid AWS account with the necessary IAM role to spin up EC2 instances. + This IAM role can either be created by you (assuming you have the appropriate permissions) + or will need to be created by your administrator. +2. The [AWS CLI](https://aws.amazon.com/cli/) installed and configured on your machine. +3. Login using the AWS CLI. + +#### For Kubernetes Provider +1. A Kubernetes cluster with Ray already deployed + - Can be local (minikube/kind), cloud-managed (EKS/GKE/AKS), or on-premise. + - See our [Kubernetes setup guides](./docs/kubernetes/README.md) for detailed instructions +2. Ray cluster running in your Kubernetes cluster + - Must be installed and configured using Helm + - See provider-specific guides for installation steps +3. Daft installed on the Ray cluster +4. `kubectl` installed and configured with the correct context +5. Appropriate permissions to access the namespace where Ray is deployed + +### SSH Key Setup for AWS + +To enable SSH access and port forwarding for AWS clusters, you need to: + +1. Create an SSH key pair (if you don't already have one): + ```bash + # Generate a new key pair + ssh-keygen -t rsa -b 2048 -f ~/.ssh/daft-key + + # This will create: + # ~/.ssh/daft-key (private key) + # ~/.ssh/daft-key.pub (public key) + ``` + +2. Import the public key to AWS: + ```bash + # Import the public key to AWS + aws ec2 import-key-pair \ + --key-name "daft-key" \ + --public-key-material fileb://~/.ssh/daft-key.pub + ``` + +3. 
Set proper permissions on your private key: + ```bash + chmod 600 ~/.ssh/daft-key + ``` + +4. Update your daft configuration to use this key: + ```toml + [setup.aws] + # ... other aws config ... + ssh-private-key = "~/.ssh/daft-key" # Path to your private key + ssh-user = "ubuntu" # User depends on the AMI (ubuntu for Ubuntu AMIs) + ``` + +Notes: +- The key name in AWS must match the name of your key file (without the extension) +- The private key must be readable only by you (hence the chmod 600) +- Different AMIs use different default users: + - Ubuntu AMIs: use "ubuntu" + - Amazon Linux AMIs: use "ec2-user" + - Make sure this matches your `ssh-user` configuration + +### Installation + +Using `uv`: ```bash # create project @@ -64,32 +139,110 @@ source .venv/bin/activate uv pip install daft-launcher ``` -## Example +### Example Usage -```sh -# create a new configuration file -daft init -``` -That should create a configuration file for you. -Feel free to modify some of the configuration values. -If you have any confusions on a value, you can always run `daft check` to check the syntax and schema of your configuration file. +All interactions with Daft Launcher are primarily communicated via a configuration file. +By default, Daft Launcher will look inside your `$CWD` for a file named `.daft.toml`. +You can override this behaviour by specifying a custom configuration file. 
+ +#### AWS Provider (Default) -Once you're content with your configuration file, go back to your terminal and run the following: +```bash +# create a new AWS configuration file +daft init +# or explicitly specify AWS provider +daft init --provider aws -```sh -# spin your cluster up +# spin up a cluster (AWS only) daft up +# or optionally, pass in a custom config file +daft up -c my-custom-config.toml -# list all the active clusters +# list all active clusters (AWS only) daft list -# submit a directory and command to run on the cluster -# (where `my-job-name` should be an entry in your .daft.toml file) -daft submit my-job-name +# submit a job defined in your config +daft submit --working-dir <...> example-job + +# execute SQL query +daft sql "SELECT * FROM my_table" + +# connect to the Ray dashboard +daft connect + +# SSH into the head node +daft ssh + +# stop the cluster +daft stop + +# terminate the cluster +daft kill +``` + +#### Kubernetes Provider + +```bash +# create a new Kubernetes configuration file +daft init --provider k8s + +# submit a job defined in your config +daft submit example-job + +# execute SQL query (not yet supported on Kubernetes; coming soon) +daft sql "SELECT * FROM my_table" +``` + +### Configuration Files + +You can specify a custom configuration file path with the `-c` flag: +```bash +daft -c my-config.toml submit example-job +``` + +Example AWS configuration: +```toml +[setup] +name = "my-daft-cluster" +version = "0.1.0" +provider = "aws" + +[setup.aws] +region = "us-west-2" +number-of-workers = 4 +ssh-user = "ubuntu" +ssh-private-key = "~/.ssh/daft-key" +instance-type = "i3.2xlarge" +image-id = "ami-04dd23e62ed049936" +iam-instance-profile-name = "YourInstanceProfileName" # Optional +dependencies = [] # Optional additional Python packages + +[run] +pre-setup-commands = [] +post-setup-commands = [] + +[[job]] +name = "example-job" +command = "python my_script.py" +working-dir = "~/my_project" +``` + +Example Kubernetes configuration: +```toml +[setup] +name = 
"my-daft-cluster" +version = "0.1.0" +provider = "k8s" + +[setup.k8s] +namespace = "default" # Optional, defaults to "default" -# run a direct SQL query on daft -daft sql "SELECT * FROM my_table WHERE column = 'value'" +[run] +pre-setup-commands = [] +post-setup-commands = [] -# finally, once you're done, spin the cluster down -daft down +[[job]] +name = "example-job" +command = "python my_script.py" +working-dir = "~/my_project" ``` diff --git a/docs/kubernetes/README.md b/docs/kubernetes/README.md new file mode 100644 index 0000000..2842d0e --- /dev/null +++ b/docs/kubernetes/README.md @@ -0,0 +1,18 @@ +# Kubernetes Setup for Daft + +> **Note**: This documentation is housed in the `daft-launcher` repository while the Kubernetes approach is being reviewed. Once finalized, these docs will be copied to the separate documentation repository. + +This directory contains guides for setting up Ray and Daft on various Kubernetes environments: + +- [Local Development](./local.md) - Setting up a local Kubernetes cluster for development +- [Cloud Providers](./cloud.md) - Instructions for EKS, GKE, and AKS setups +- [On-Premises](./on-prem.md) - Guide for on-premises Kubernetes deployments + +## Prerequisites + +Before using `daft-launcher` with Kubernetes, you must: +1. Have a running Kubernetes cluster (local, cloud-managed, or on-premise) +2. Install and configure Ray on your Kubernetes cluster +3. Install Daft on your cluster + +Please follow the appropriate guide above for your environment. \ No newline at end of file diff --git a/docs/kubernetes/cloud.md b/docs/kubernetes/cloud.md new file mode 100644 index 0000000..0e34ab6 --- /dev/null +++ b/docs/kubernetes/cloud.md @@ -0,0 +1,50 @@ +# Cloud Provider Kubernetes Setup + +This guide covers using Ray and Daft with managed Kubernetes services from major cloud providers. 
+ +## Prerequisites + +### General Requirements +- `kubectl` installed and configured +- `helm` installed +- A running Kubernetes cluster in one of the following cloud providers: + - Amazon Elastic Kubernetes Service (EKS) + - Google Kubernetes Engine (GKE) + - Azure Kubernetes Service (AKS) + +### Cloud-Specific Requirements + +#### For AWS EKS +- AWS CLI installed and configured +- Access to an existing EKS cluster +- `kubectl` configured for your EKS cluster: + ```bash + aws eks update-kubeconfig --name your-cluster-name --region your-region + ``` + +#### For Google GKE +- Google Cloud SDK installed +- Access to an existing GKE cluster +- `kubectl` configured for your GKE cluster: + ```bash + gcloud container clusters get-credentials your-cluster-name --zone your-zone + ``` + +#### For Azure AKS +- Azure CLI installed +- Access to an existing AKS cluster +- `kubectl` configured for your AKS cluster: + ```bash + az aks get-credentials --resource-group your-resource-group --name your-cluster-name + ``` + +## Installing Ray and Daft + +Once your cloud Kubernetes cluster is running and `kubectl` is configured, follow the [Ray Installation Guide](./ray-installation.md) to: +1. Install KubeRay Operator +2. Deploy Ray cluster +3. Install Daft +4. Set up port forwarding +5. Submit test jobs + +> **Note**: For cloud providers, you'll typically use x86/AMD64 images unless you're specifically using ARM-based instances (like AWS Graviton). \ No newline at end of file diff --git a/docs/kubernetes/local.md b/docs/kubernetes/local.md new file mode 100644 index 0000000..130aeab --- /dev/null +++ b/docs/kubernetes/local.md @@ -0,0 +1,127 @@ +# Local Kubernetes Development Setup + +This guide walks you through setting up a local Kubernetes cluster for Daft development. 
+ +## Prerequisites + +- Docker Desktop installed and running +- `kubectl` CLI tool installed +- `helm` installed +- One of the following local Kubernetes solutions: + - Kind (Recommended) + - Minikube + - Docker Desktop's built-in Kubernetes + +## Option 1: Using Kind (Recommended) + +1. Install Kind: + ```bash + # On macOS with Homebrew + brew install kind + + # On Linux + curl -Lo ./kind https://kind.sigs.k8s.io/dl/v0.20.0/kind-linux-amd64 + chmod +x ./kind + sudo mv ./kind /usr/local/bin/kind + ``` + +2. Create a cluster: + ```bash + # For Apple Silicon (M1, M2, M3): + kind create cluster --name daft-dev --config - < **Note**: For Apple Silicon (M1, M2, M3) machines, make sure to use the ARM64-specific Ray image as specified in the installation guide. + +## Resource Requirements + +Local Kubernetes clusters need sufficient resources to run Ray and Daft effectively: + +- Minimum requirements: + - 4 CPU cores + - 8GB RAM + - 20GB disk space + +- Recommended: + - 8 CPU cores + - 16GB RAM + - 40GB disk space + +You can adjust these in Docker Desktop's settings or when starting Minikube. 
+ +## Troubleshooting + +### Resource Issues +- If pods are stuck in `Pending` state: + - For Docker Desktop: Increase resources in Docker Desktop settings + - For Minikube: Start with more resources: `minikube start --cpus 6 --memory 12288` + +### Architecture Issues +- For Apple Silicon users: + - Ensure you're using ARM64-compatible images + - Check Docker Desktop is running in native ARM64 mode + - Verify Kubernetes components are ARM64-compatible + +## Cleanup + +To delete your local cluster: + +```bash +# For Kind +kind delete cluster --name daft-dev + +# For Minikube +minikube delete +``` \ No newline at end of file diff --git a/docs/kubernetes/on-prem.md b/docs/kubernetes/on-prem.md new file mode 100644 index 0000000..fd9258e --- /dev/null +++ b/docs/kubernetes/on-prem.md @@ -0,0 +1,33 @@ +# On-Premises Kubernetes Setup + +This guide covers setting up Ray and Daft on self-managed Kubernetes clusters. + +## Prerequisites + +Before proceeding with Ray and Daft installation, ensure you have: + +- A running Kubernetes cluster (v1.16+) +- `kubectl` installed and configured with access to your cluster +- `helm` installed +- Load balancer solution configured if needed + +## Verifying Cluster Requirements + +1. Check Kubernetes version: + ```bash + kubectl version --short + ``` + +2. Verify cluster nodes: + ```bash + kubectl get nodes + ``` + +## Installing Ray and Daft + +Once your on-premises Kubernetes cluster is ready, follow the [Ray Installation Guide](./ray-installation.md) for: +- Installing Ray using Helm +- Installing Daft on the Ray cluster +- Configuring and using daft-launcher + +The installation steps are identical regardless of where your Kubernetes cluster is running. 
\ No newline at end of file diff --git a/docs/kubernetes/ray-installation.md b/docs/kubernetes/ray-installation.md new file mode 100644 index 0000000..a78d5cc --- /dev/null +++ b/docs/kubernetes/ray-installation.md @@ -0,0 +1,100 @@ +# Installing Ray on Kubernetes + +This guide covers the common steps for installing Ray on Kubernetes using KubeRay, regardless of where your cluster is running (local, cloud, or on-premise). + +## Prerequisites +- A running Kubernetes cluster +- `kubectl` configured with the correct context +- `helm` installed + +## Installation Steps + +1. Add the KubeRay Helm repository: + ```bash + helm repo add kuberay https://ray-project.github.io/kuberay-helm/ + helm repo update + ``` + +2. Install KubeRay Operator: + ```bash + helm install kuberay-operator kuberay/kuberay-operator + ``` + +3. Create a values file (`values.yaml`): + ```yaml + head: + args: ["sudo apt-get update && sudo apt-get install -y curl; curl -LsSf https://astral.sh/uv/install.sh | sh; export PATH=$HOME/.local/bin:$PATH; uv pip install --system getdaft"] + worker: + args: ["sudo apt-get update && sudo apt-get install -y curl; curl -LsSf https://astral.sh/uv/install.sh | sh; export PATH=$HOME/.local/bin:$PATH; uv pip install --system getdaft"] + + rayCluster: + headGroupSpec: + template: + spec: + containers: + - name: ray-head + image: rayproject/ray:2.40.0-py310 # Use the desired Python version + command: ["ray", "start", "--head"] + workerGroupSpecs: + template: + spec: + containers: + - name: ray-worker + image: rayproject/ray:2.40.0-py310 # Same image to ensure compatibility + ``` + +4. Install Ray Cluster: + + For Apple Silicon (M1, M2, M3, M4) or other ARM64 processors (AWS Graviton, etc.): + ```bash + helm install raycluster kuberay/ray-cluster --version 1.2.2 \ + --set 'image.tag=2.40.0-py310-aarch64' \ + -f values.yaml + ``` + + For x86/AMD64 processors: + ```bash + helm install raycluster kuberay/ray-cluster --version 1.2.2 \ + -f values.yaml + ``` + +5. 
Verify the installation: + ```bash + kubectl get pods + ``` + +## Accessing Ray + +### Port Forwarding +To access the Ray dashboard and submit jobs, set up port forwarding: +```bash +kubectl port-forward service/raycluster-kuberay-head-svc 8265:8265 +``` + +### Ray Dashboard +Once port forwarding is set up, access the dashboard at: +http://localhost:8265 + +### Submitting Jobs +You can submit Ray jobs using the following command: +```bash +ray job submit --address http://localhost:8265 -- python -c "import ray; import daft; ray.init(); print(ray.cluster_resources())" +``` + +## Troubleshooting + +1. Check pod status: + ```bash + kubectl get pods + kubectl describe pod <pod-name> + ``` + +2. View pod logs: + ```bash + kubectl logs <pod-name> + ``` + +3. Common issues: + - If pods are stuck in `Pending` state, check resource availability + - If pods are in `CrashLoopBackOff`, check the logs for errors + - For ARM64 issues, ensure you're using the correct image tag with `-aarch64` suffix \ No newline at end of file diff --git a/examples/hello_daft.py b/examples/hello_daft.py new file mode 100644 index 0000000..8102620 --- /dev/null +++ b/examples/hello_daft.py @@ -0,0 +1,27 @@ +import sys +import daft +from daft import DataType, udf + +print(f"Python version: {sys.version}") + + +import datetime +df = daft.from_pydict( + { + "integers": [1, 2, 3, 4], + "floats": [1.5, 2.5, 3.5, 4.5], + "bools": [True, True, False, False], + "strings": ["a", "b", "c", "d"], + "bytes": [b"a", b"b", b"c", b"d"], + "dates": [ + datetime.date(1994, 1, 1), + datetime.date(1994, 1, 2), + datetime.date(1994, 1, 3), + datetime.date(1994, 1, 4), + ], + "lists": [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]], + "nulls": [None, None, None, None], + } +) + +df.show(2) diff --git a/src/main.rs b/src/main.rs index 86866fa..5f426f7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,3 @@ -mod ssh; -#[cfg(test)] -mod tests; - use std::{ collections::HashMap, io::{Error, ErrorKind}, @@ -10,7 +6,6 @@ use std::{ process::Stdio, 
str::FromStr, sync::Arc, - thread::{sleep, spawn}, time::Duration, }; @@ -22,11 +17,15 @@ use clap::{Parser, Subcommand}; use comfy_table::{ modifiers, presets, Attribute, Cell, CellAlignment, Color, ContentArrangement, Table, }; -use regex::Regex; +use semver::{Version, VersionReq}; use serde::{Deserialize, Serialize}; use tempdir::TempDir; -use tokio::{fs, process::Command}; -use versions::{Requirement, Versioning}; +use tokio::{ + fs, + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + process::{Child, Command}, + time::timeout, +}; type StrRef = Arc; type PathRef = Arc; @@ -100,13 +99,14 @@ struct Init { /// The path at which to create the config file. #[arg(default_value = ".daft.toml")] path: PathBuf, + + /// The provider to use - either 'aws' (default) to auto-generate a cluster or 'k8s' for existing Kubernetes clusters + #[arg(long, default_value = "aws")] + provider: String, } #[derive(Debug, Parser, Clone, PartialEq, Eq)] struct List { - /// A regex to filter for the Ray clusters which match the given name. - regex: Option, - /// The region which to list all the available clusters for. #[arg(long)] region: Option, @@ -138,10 +138,6 @@ struct Connect { #[arg(long, default_value = "8265")] port: u16, - /// Prevent the dashboard from opening automatically. 
- #[arg(long)] - no_dashboard: bool, - #[clap(flatten)] config_path: ConfigPath, } @@ -167,11 +163,55 @@ struct ConfigPath { struct DaftConfig { setup: DaftSetup, #[serde(default)] - run: Vec, - #[serde(default, rename = "job", deserialize_with = "parse_jobs")] + run: DaftRun, + #[serde(rename = "job", deserialize_with = "parse_jobs")] jobs: HashMap, } +#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +struct DaftSetup { + name: StrRef, + #[serde(deserialize_with = "parse_version_req")] + version: VersionReq, + provider: DaftProvider, + #[serde(flatten)] + provider_config: ProviderConfig, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +enum ProviderConfig { + #[serde(rename = "aws")] + Aws(AwsConfig), + #[serde(rename = "k8s")] + K8s(K8sConfig), +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +struct AwsConfig { + region: StrRef, + #[serde(default = "default_number_of_workers")] + number_of_workers: usize, + ssh_user: StrRef, + #[serde(deserialize_with = "parse_ssh_private_key")] + ssh_private_key: PathRef, + #[serde(default = "default_instance_type")] + instance_type: StrRef, + #[serde(default = "default_image_id")] + image_id: StrRef, + iam_instance_profile_name: Option, + #[serde(default)] + dependencies: Vec, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +struct K8sConfig { + namespace: Option, +} + fn parse_jobs<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, @@ -202,31 +242,6 @@ where Ok(jobs) } -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "kebab-case", deny_unknown_fields)] -struct DaftSetup { - name: StrRef, - #[serde(deserialize_with = "parse_daft_launcher_requirement")] - requires: Requirement, - #[serde(deserialize_with = 
"parse_python_version")] - python_version: Versioning, - #[serde(deserialize_with = "parse_ray_version")] - ray_version: Versioning, - region: StrRef, - #[serde(default = "default_number_of_workers")] - number_of_workers: usize, - ssh_user: StrRef, - #[serde(deserialize_with = "parse_ssh_private_key")] - ssh_private_key: PathRef, - #[serde(default = "default_instance_type")] - instance_type: StrRef, - #[serde(default = "default_image_id")] - image_id: StrRef, - iam_instance_profile_name: Option, - #[serde(default)] - dependencies: Vec, -} - fn parse_ssh_private_key<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>, @@ -274,52 +289,52 @@ fn default_image_id() -> StrRef { "ami-04dd23e62ed049936".into() } -fn parse_python_version<'de, D>(deserializer: D) -> Result +fn parse_version_req<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let raw: StrRef = Deserialize::deserialize(deserializer)?; - let requested_py_version = raw - .parse::() + let version_req = raw + .parse::() .map_err(serde::de::Error::custom)?; - let minimum_py_requirement = ">=3.9" - .parse::() - .expect("Parsing a static, constant version should always succeed"); - - if minimum_py_requirement.matches(&requested_py_version) { - Ok(requested_py_version) + let current_version = env!("CARGO_PKG_VERSION") + .parse::() + .expect("CARGO_PKG_VERSION must exist"); + if version_req.matches(¤t_version) { + Ok(version_req) } else { - Err(serde::de::Error::custom(format!("The minimum supported python version is {minimum_py_requirement}, but your configuration file requested python version {requested_py_version}"))) + Err(serde::de::Error::custom(format!("You're running daft-launcher version {current_version}, but your configuration file requires version {version_req}"))) } } -fn parse_ray_version<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - let raw: StrRef = Deserialize::deserialize(deserializer)?; - let version = 
raw.parse().map_err(serde::de::Error::custom)?; - Ok(version) +#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +enum DaftProvider { + Aws, + K8s, } -fn parse_daft_launcher_requirement<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - let raw: StrRef = Deserialize::deserialize(deserializer)?; - let requested_requirement = raw - .parse::() - .map_err(serde::de::Error::custom)?; - let current_version = env!("CARGO_PKG_VERSION") - .parse::() - .expect("CARGO_PKG_VERSION must exist"); - if requested_requirement.matches(¤t_version) { - Ok(requested_requirement) - } else { - Err(serde::de::Error::custom(format!("You're running daft-launcher version {current_version}, but your configuration file requires version {requested_requirement}"))) +impl FromStr for DaftProvider { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "aws" => Ok(DaftProvider::Aws), + "k8s" => Ok(DaftProvider::K8s), + _ => anyhow::bail!("Invalid provider '{}'. 
Must be either 'aws' or 'k8s'", s), + } } } +#[derive(Default, Debug, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +struct DaftRun { + #[serde(default)] + pre_setup_commands: Vec, + #[serde(default)] + post_setup_commands: Vec, +} + #[derive(Debug, Clone, PartialEq, Eq)] struct DaftJob { command: StrRef, @@ -365,12 +380,12 @@ struct RayNodeConfig { instance_type: StrRef, image_id: StrRef, #[serde(skip_serializing_if = "Option::is_none")] - iam_instance_profile: Option, + iam_instance_profile: Option, } #[derive(Default, Debug, Serialize, Clone, PartialEq, Eq)] #[serde(rename_all = "PascalCase")] -struct RayIamInstanceProfile { +struct IamInstanceProfile { name: StrRef, } @@ -380,112 +395,10 @@ struct RayResources { cpu: usize, } -fn generate_setup_commands( - python_version: Versioning, - ray_version: Versioning, - dependencies: &[StrRef], -) -> Vec { - let mut commands = vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - format!("uv python install {python_version}").into(), - format!("uv python pin {python_version}").into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - format!( - r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]=={ray_version}""# - ) - .into(), - ]; - - if !dependencies.is_empty() { - let deps = dependencies - .iter() - .map(|dep| format!(r#""{dep}""#)) - .collect::>() - .join(" "); - let deps = format!("uv pip install {deps}").into(); - commands.push(deps); - } - - commands -} - -fn convert( - daft_config: &DaftConfig, - teardown_behaviour: Option, -) -> anyhow::Result { - let key_name = daft_config - .setup - .ssh_private_key - .clone() - .file_stem() - .ok_or_else(|| { - anyhow::anyhow!(r#"Private key doesn't have a name of the format "name.ext""#) - })? 
- .to_str() - .ok_or_else(|| { - anyhow::anyhow!( - "The file {:?} does not a valid UTF-8 name", - daft_config.setup.ssh_private_key, - ) - })? - .into(); - let iam_instance_profile = daft_config - .setup - .iam_instance_profile_name - .clone() - .map(|name| RayIamInstanceProfile { name }); - let node_config = RayNodeConfig { - key_name, - instance_type: daft_config.setup.instance_type.clone(), - image_id: daft_config.setup.image_id.clone(), - iam_instance_profile, - }; - Ok(RayConfig { - cluster_name: daft_config.setup.name.clone(), - max_workers: daft_config.setup.number_of_workers, - provider: RayProvider { - r#type: "aws".into(), - region: daft_config.setup.region.clone(), - cache_stopped_nodes: teardown_behaviour.map(TeardownBehaviour::to_cache_stopped_nodes), - }, - auth: RayAuth { - ssh_user: daft_config.setup.ssh_user.clone(), - ssh_private_key: daft_config.setup.ssh_private_key.clone(), - }, - available_node_types: vec![ - ( - "ray.head.default".into(), - RayNodeType { - max_workers: 0, - node_config: node_config.clone(), - resources: Some(RayResources { cpu: 0 }), - }, - ), - ( - "ray.worker.default".into(), - RayNodeType { - max_workers: daft_config.setup.number_of_workers, - node_config, - resources: None, - }, - ), - ] - .into_iter() - .collect(), - setup_commands: generate_setup_commands( - daft_config.setup.python_version.clone(), - daft_config.setup.ray_version.clone(), - daft_config.setup.dependencies.as_ref(), - ), - }) -} - async fn read_and_convert( daft_config_path: &Path, teardown_behaviour: Option, -) -> anyhow::Result<(DaftConfig, RayConfig)> { +) -> anyhow::Result<(DaftConfig, Option)> { let contents = fs::read_to_string(&daft_config_path) .await .map_err(|error| { @@ -498,8 +411,84 @@ async fn read_and_convert( error } })?; + let daft_config = toml::from_str::(&contents)?; - let ray_config = convert(&daft_config, teardown_behaviour)?; + + let ray_config = match &daft_config.setup.provider_config { + ProviderConfig::K8s(_) => None, + 
ProviderConfig::Aws(aws_config) => { + let key_name = aws_config.ssh_private_key + .clone() + .file_stem() + .ok_or_else(|| anyhow::anyhow!(r#"Private key doesn't have a name of the format "name.ext""#))? + .to_str() + .ok_or_else(|| anyhow::anyhow!("The file {:?} does not have a valid UTF-8 name", aws_config.ssh_private_key))? + .into(); + + let node_config = RayNodeConfig { + key_name, + instance_type: aws_config.instance_type.clone(), + image_id: aws_config.image_id.clone(), + iam_instance_profile: aws_config.iam_instance_profile_name.clone().map(|name| IamInstanceProfile { name }), + }; + + Some(RayConfig { + cluster_name: daft_config.setup.name.clone(), + max_workers: aws_config.number_of_workers, + provider: RayProvider { + r#type: "aws".into(), + region: aws_config.region.clone(), + cache_stopped_nodes: teardown_behaviour.map(TeardownBehaviour::to_cache_stopped_nodes), + }, + auth: RayAuth { + ssh_user: aws_config.ssh_user.clone(), + ssh_private_key: aws_config.ssh_private_key.clone(), + }, + available_node_types: vec![ + ( + "ray.head.default".into(), + RayNodeType { + max_workers: aws_config.number_of_workers, + node_config: node_config.clone(), + resources: Some(RayResources { cpu: 0 }), + }, + ), + ( + "ray.worker.default".into(), + RayNodeType { + max_workers: aws_config.number_of_workers, + node_config, + resources: None, + }, + ), + ] + .into_iter() + .collect(), + setup_commands: { + let mut commands = vec![ + "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), + "uv python install 3.12".into(), + "uv python pin 3.12".into(), + "uv venv".into(), + "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), + "source ~/.bashrc".into(), + "uv pip install boto3 pip ray[default] getdaft py-spy deltalake".into(), + ]; + if !aws_config.dependencies.is_empty() { + let deps = aws_config + .dependencies + .iter() + .map(|dep| format!(r#""{dep}""#)) + .collect::>() + .join(" "); + let deps = format!("uv pip install {deps}").into(); + 
commands.push(deps); + } + commands + }, + }) + } + }; Ok((daft_config, ray_config)) } @@ -583,15 +572,6 @@ pub enum NodeType { Worker, } -impl NodeType { - pub fn as_str(self) -> &'static str { - match self { - Self::Head => "head", - Self::Worker => "worker", - } - } -} - impl FromStr for NodeType { type Err = anyhow::Error; @@ -661,37 +641,24 @@ async fn get_ray_clusters_from_aws(region: StrRef) -> anyhow::Result, - head: bool, - running: bool, -) -> anyhow::Result { +fn print_instances(instances: &[AwsInstance], head: bool, running: bool) { let mut table = Table::default(); table .load_preset(presets::UTF8_FULL) .apply_modifier(modifiers::UTF8_ROUND_CORNERS) .apply_modifier(modifiers::UTF8_SOLID_INNER_BORDERS) .set_content_arrangement(ContentArrangement::DynamicFullWidth) - .set_header( - ["Name", "Instance ID", "Node Type", "Status", "IPv4"].map(|header| { - Cell::new(header) - .set_alignment(CellAlignment::Center) - .add_attribute(Attribute::Bold) - }), - ); - let regex = regex.as_deref().map(Regex::new).transpose()?; + .set_header(["Name", "Instance ID", "Status", "IPv4"].map(|header| { + Cell::new(header) + .set_alignment(CellAlignment::Center) + .add_attribute(Attribute::Bold) + })); for instance in instances.iter().filter(|instance| { if head && instance.node_type != NodeType::Head { return false; } else if running && instance.state != Some(InstanceStateName::Running) { return false; }; - if let Some(regex) = regex.as_ref() { - if !regex.is_match(&instance.regular_name) { - return false; - }; - }; true }) { let status = instance.state.as_ref().map_or_else( @@ -717,13 +684,12 @@ fn format_table( .map_or("n/a".into(), ToString::to_string); table.add_row(vec![ Cell::new(instance.regular_name.to_string()).fg(Color::Cyan), - Cell::new(instance.instance_id.as_ref()), - Cell::new(instance.node_type.as_str()), + Cell::new(&*instance.instance_id), status, Cell::new(ipv4), ]); } - Ok(table) + println!("{table}"); } async fn assert_is_logged_in_with_aws() -> 
anyhow::Result<()> { @@ -745,81 +711,228 @@ async fn get_region(region: Option, config: impl AsRef) -> anyhow: region } else if config.exists() { let (daft_config, _) = read_and_convert(&config, None).await?; - daft_config.setup.region + match &daft_config.setup.provider_config { + ProviderConfig::Aws(aws_config) => aws_config.region.clone(), + ProviderConfig::K8s(_) => "us-west-2".into(), + } } else { "us-west-2".into() }) } -async fn submit(working_dir: &Path, command_segments: impl AsRef<[&str]>) -> anyhow::Result<()> { - let command_segments = command_segments.as_ref(); +async fn get_head_node_ip(ray_path: impl AsRef) -> anyhow::Result { + let mut ray_command = Command::new("ray") + .arg("get-head-ip") + .arg(ray_path.as_ref()) + .stdout(Stdio::piped()) + .spawn()?; - let exit_status = Command::new("ray") - .env("PYTHONUNBUFFERED", "1") - .args(["job", "submit", "--address", "http://localhost:8265"]) - .arg("--working-dir") - .arg(working_dir) - .arg("--") - .args(command_segments) + let mut tail_command = Command::new("tail") + .args(["-n", "1"]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn()?; + + let mut writer = tail_command.stdin.take().expect("stdin must exist"); + + tokio::spawn(async move { + let mut reader = BufReader::new(ray_command.stdout.take().expect("stdout must exist")); + let mut buffer = Vec::new(); + reader.read_to_end(&mut buffer).await?; + writer.write_all(&buffer).await?; + Ok::<_, anyhow::Error>(()) + }); + let output = tail_command.wait_with_output().await?; + if !output.status.success() { + anyhow::bail!("Failed to fetch ip address of head node"); + }; + let addr = String::from_utf8_lossy(&output.stdout) + .trim() + .parse::()?; + Ok(addr) +} + +async fn ssh(ray_path: impl AsRef, aws_config: &AwsConfig) -> anyhow::Result<()> { + let addr = get_head_node_ip(ray_path).await?; + let exit_status = Command::new("ssh") + .arg("-i") + .arg(aws_config.ssh_private_key.as_ref()) + .arg(format!("{}@{}", aws_config.ssh_user, addr)) + 
.kill_on_drop(true) .spawn()? .wait() .await?; + if exit_status.success() { Ok(()) } else { - Err(anyhow::anyhow!("Failed to submit job to the ray cluster")) + Err(anyhow::anyhow!("Failed to ssh into the ray cluster")) } } -async fn get_version_from_env(bin: &str, prefix: &str) -> anyhow::Result { - let output = Command::new(bin) - .arg("--version") - .stdout(Stdio::piped()) +async fn establish_ssh_portforward( + ray_path: impl AsRef, + aws_config: &AwsConfig, + port: Option, +) -> anyhow::Result { + let addr = get_head_node_ip(ray_path).await?; + let port = port.unwrap_or(8265); + let mut child = Command::new("ssh") + .arg("-N") + .arg("-i") + .arg(aws_config.ssh_private_key.as_ref()) + .arg("-L") + .arg(format!("{port}:localhost:8265")) + .arg(format!("{}@{}", aws_config.ssh_user, addr)) + .arg("-v") .stderr(Stdio::piped()) - .spawn()? - .wait_with_output() - .await?; + .kill_on_drop(true) + .spawn()?; + + // We wait for the ssh port-forwarding process to write a specific string to the + // output. + // + // This is a little hacky (and maybe even incorrect across platforms) since we + // are just parsing the output and observing if a specific string has been + // printed. It may be incorrect across platforms because the SSH standard + // does *not* specify a standard "success-message" to printout if the ssh + // port-forward was successful. + timeout(Duration::from_secs(5), { + let stderr = child.stderr.take().expect("stderr must exist"); + async move { + let mut lines = BufReader::new(stderr).lines(); + loop { + let Some(line) = lines.next_line().await? else { + anyhow::bail!("Failed to establish ssh port-forward to {addr}"); + }; + if line.starts_with(format!("Authenticated to {addr}").as_str()) { + break Ok(()); + } + } + } + }) + .await + .map_err(|_| anyhow::anyhow!("Establishing an ssh port-forward to {addr} timed out"))??; - if output.status.success() { - let version = String::from_utf8(output.stdout)? 
- .strip_prefix(prefix) - .ok_or_else(|| anyhow::anyhow!("Could not parse {bin} version"))? - .trim() - .parse()?; - Ok(version) - } else { - Err(anyhow::anyhow!("Failed to find {bin} executable")) + Ok(child) +} + +struct PortForward { + process: Child, +} + +impl Drop for PortForward { + fn drop(&mut self) { + let _ = self.process.start_kill(); } } -async fn get_python_version_from_env() -> anyhow::Result { - let python_version = get_version_from_env("python", "Python ").await?; - Ok(python_version) +async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::Result { + let namespace = namespace.unwrap_or("default"); + let output = Command::new("kubectl") + .arg("get") + .arg("svc") + .arg("-n") + .arg(namespace) + .arg("-l") + .arg("ray.io/node-type=head") + .arg("--no-headers") + .arg("-o") + .arg("custom-columns=:metadata.name") + .output() + .await?; + if !output.status.success() { + return Err(anyhow::anyhow!("Failed to get Ray head node services with kubectl in namespace {}", namespace)); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + if stdout.trim().is_empty() { + return Err(anyhow::anyhow!("Ray head node service not found in namespace {}", namespace)); + } + + let head_node_service_name = stdout + .lines() + .next() + .ok_or_else(|| anyhow::anyhow!("Failed to get the head node service name"))?; + println!("Found Ray head node service: {} in namespace {}", head_node_service_name, namespace); + + // Start port-forward with stderr piped so we can monitor the process + let mut port_forward = Command::new("kubectl") + .arg("port-forward") + .arg("-n") + .arg(namespace) + .arg(format!("svc/{}", head_node_service_name)) + .arg("8265:8265") + .stderr(Stdio::piped()) + .stdout(Stdio::piped()) // Capture stdout too + .kill_on_drop(true) + .spawn()?; + + // Give the port-forward a moment to start and check for immediate failures + tokio::time::sleep(Duration::from_secs(2)).await; + + // Check if process is still running + match 
port_forward.try_wait()? { + Some(status) => { + return Err(anyhow::anyhow!( + "Port-forward process exited immediately with status: {}", + status + )); + } + None => { + println!("Port-forwarding started successfully"); + Ok(PortForward { + process: port_forward, + }) + } + } } -async fn get_ray_version_from_env() -> anyhow::Result { - let python_version = get_version_from_env("ray", "ray, version ").await?; - Ok(python_version) +async fn submit_k8s( + working_dir: &Path, + command_segments: impl AsRef<[&str]>, + namespace: Option<&str>, +) -> anyhow::Result<()> { + let command_segments = command_segments.as_ref(); + + // Start port forwarding - it will be automatically killed when _port_forward is dropped + let _port_forward = establish_kubernetes_port_forward(namespace).await?; + + // Give the port-forward a moment to fully establish + tokio::time::sleep(Duration::from_secs(1)).await; + + // Submit the job + let exit_status = Command::new("ray") + .env("PYTHONUNBUFFERED", "1") + .args(["job", "submit", "--address", "http://localhost:8265"]) + .arg("--working-dir") + .arg(working_dir) + .arg("--") + .args(command_segments) + .spawn()? 
+ .wait() + .await?; + + if exit_status.success() { + Ok(()) + } else { + Err(anyhow::anyhow!("Failed to submit job to the ray cluster")) + } } async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> { match daft_launcher.sub_command { - SubCommand::Init(Init { path }) => { + SubCommand::Init(Init { path, provider }) => { #[cfg(not(test))] if path.exists() { bail!("The path {path:?} already exists; the path given must point to a new location on your filesystem"); } - let contents = include_str!("template.toml"); - let contents = contents - .replace("", concat!("=", env!("CARGO_PKG_VERSION"))) - .replace( - "", - get_python_version_from_env().await?.to_string().as_str(), - ) - .replace( - "", - get_ray_version_from_env().await?.to_string().as_str(), - ); + let contents = if provider == "k8s" { + include_str!("template_k8s.toml") + } else { + include_str!("template.toml") + } + .replace("", env!("CARGO_PKG_VERSION")); fs::write(path, contents).await?; } SubCommand::Check(ConfigPath { config }) => { @@ -827,116 +940,186 @@ async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> { } SubCommand::Export(ConfigPath { config }) => { let (_, ray_config) = read_and_convert(&config, None).await?; + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); let ray_config_str = serde_yaml::to_string(&ray_config)?; + println!("{ray_config_str}"); } SubCommand::Up(ConfigPath { config }) => { - let (_, ray_config) = read_and_convert(&config, None).await?; - assert_is_logged_in_with_aws().await?; + let (daft_config, ray_config) = read_and_convert(&config, None).await?; + match daft_config.setup.provider { + DaftProvider::Aws => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + assert_is_logged_in_with_aws().await?; - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, 
&ray_path).await?; - run_ray_up_or_down_command(SpinDirection::Up, ray_path).await?; + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + run_ray_up_or_down_command(SpinDirection::Up, ray_path).await?; + } + DaftProvider::K8s => { + anyhow::bail!("'up' command is only available for AWS configurations"); + } + } } SubCommand::List(List { - regex, config_path, region, head, running, }) => { - assert_is_logged_in_with_aws().await?; - - let region = get_region(region, &config_path.config).await?; - let instances = get_ray_clusters_from_aws(region).await?; - let table = format_table(&instances, regex, head, running)?; - println!("{table}"); + let (daft_config, _) = read_and_convert(&config_path.config, None).await?; + match daft_config.setup.provider { + DaftProvider::Aws => { + assert_is_logged_in_with_aws().await?; + let aws_config = get_aws_config(&daft_config)?; + let region = region.unwrap_or_else(|| aws_config.region.clone()); + let instances = get_ray_clusters_from_aws(region).await?; + print_instances(&instances, head, running); + } + DaftProvider::K8s => { + anyhow::bail!("'list' command is only available for AWS configurations"); + } + } } SubCommand::Submit(Submit { config_path, job_name, }) => { let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - assert_is_logged_in_with_aws().await?; let daft_job = daft_config .jobs .get(&job_name) .ok_or_else(|| anyhow::anyhow!("A job with the name {job_name} was not found"))?; - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let _child = ssh::ssh_portforward(ray_path, &daft_config, None).await?; - submit( - daft_job.working_dir.as_ref(), - daft_job.command.as_ref().split(' ').collect::>(), - ) - .await?; + match &daft_config.setup.provider_config { + ProviderConfig::Aws(_aws_config) => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } 
+ let ray_config = ray_config.unwrap(); + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + submit_k8s( + daft_job.working_dir.as_ref(), + daft_job.command.as_ref().split(' ').collect::>(), + None, + ) + .await?; + } + ProviderConfig::K8s(k8s_config) => { + submit_k8s( + daft_job.working_dir.as_ref(), + daft_job.command.as_ref().split(' ').collect::>(), + k8s_config.namespace.as_deref(), + ) + .await?; + } + } } - SubCommand::Connect(Connect { - port, - no_dashboard, - config_path, - }) => { + SubCommand::Connect(Connect { port, config_path }) => { let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - assert_is_logged_in_with_aws().await?; - - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let open_join_handle = if !no_dashboard { - Some(spawn(|| { - sleep(Duration::from_millis(500)); - open::that("http://localhost:8265")?; - Ok::<_, anyhow::Error>(()) - })) - } else { - None - }; - - let _ = ssh::ssh_portforward(ray_path, &daft_config, Some(port)) - .await? - .wait_with_output() - .await?; - - if let Some(open_join_handle) = open_join_handle { - open_join_handle - .join() - .map_err(|_| anyhow::anyhow!("Failed to join browser-opening thread"))??; - }; + match daft_config.setup.provider { + DaftProvider::Aws => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + assert_is_logged_in_with_aws().await?; + + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + let aws_config = get_aws_config(&daft_config)?; + let _ = establish_ssh_portforward(ray_path, aws_config, Some(port)) + .await? 
+ .wait_with_output() + .await?; + } + DaftProvider::K8s => { + anyhow::bail!("'connect' command is only available for AWS configurations"); + } + } } SubCommand::Ssh(ConfigPath { config }) => { let (daft_config, ray_config) = read_and_convert(&config, None).await?; - assert_is_logged_in_with_aws().await?; + match daft_config.setup.provider { + DaftProvider::Aws => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + assert_is_logged_in_with_aws().await?; - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - ssh::ssh(ray_path, &daft_config).await?; + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + let aws_config = get_aws_config(&daft_config)?; + ssh(ray_path, aws_config).await?; + } + DaftProvider::K8s => { + anyhow::bail!("'ssh' command is only available for AWS configurations"); + } + } } SubCommand::Sql(Sql { sql, config_path }) => { - let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - assert_is_logged_in_with_aws().await?; - - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let _child = ssh::ssh_portforward(ray_path, &daft_config, None).await?; - let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; - fs::write(sql_path, include_str!("sql.py")).await?; - submit(temp_sql_dir.path(), vec!["python", "sql.py", sql.as_ref()]).await?; + let (daft_config, _) = read_and_convert(&config_path.config, None).await?; + match &daft_config.setup.provider_config { + ProviderConfig::Aws(_) => { + anyhow::bail!("'sql' command is only available for Kubernetes configurations"); + } + ProviderConfig::K8s(k8s_config) => { + let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; + fs::write(sql_path, include_str!("sql.py")).await?; + submit_k8s( + temp_sql_dir.path(), + vec!["python", 
"sql.py", sql.as_ref()], + k8s_config.namespace.as_deref(), + ) + .await?; + } + } } SubCommand::Stop(ConfigPath { config }) => { - let (_, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?; - assert_is_logged_in_with_aws().await?; + let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?; + match daft_config.setup.provider { + DaftProvider::Aws => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + assert_is_logged_in_with_aws().await?; - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; + } + DaftProvider::K8s => { + anyhow::bail!("'stop' command is only available for AWS configurations"); + } + } } SubCommand::Kill(ConfigPath { config }) => { - let (_, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Kill)).await?; - assert_is_logged_in_with_aws().await?; + let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Kill)).await?; + match daft_config.setup.provider { + DaftProvider::Aws => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + assert_is_logged_in_with_aws().await?; - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; + } + DaftProvider::K8s => { + anyhow::bail!("'kill' command is only available for AWS 
configurations"); + } + } } } @@ -947,3 +1130,37 @@ async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> { run(DaftLauncher::parse()).await } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_init_and_export() { + run(DaftLauncher { + sub_command: SubCommand::Init(Init { + path: ".daft.toml".into(), + provider: "aws".into(), + }), + verbosity: 0, + }) + .await + .unwrap(); + run(DaftLauncher { + sub_command: SubCommand::Check(ConfigPath { + config: ".daft.toml".into(), + }), + verbosity: 0, + }) + .await + .unwrap(); + } +} + +// Helper function to get AWS config +fn get_aws_config(config: &DaftConfig) -> anyhow::Result<&AwsConfig> { + match &config.setup.provider_config { + ProviderConfig::Aws(aws_config) => Ok(aws_config), + ProviderConfig::K8s(_) => anyhow::bail!("Expected AWS configuration but found Kubernetes configuration"), + } +} diff --git a/src/template.toml b/src/template.toml index 69c291d..ac1227a 100644 --- a/src/template.toml +++ b/src/template.toml @@ -1,33 +1,27 @@ -# This is a default configuration file that you can use to spin up a ray-cluster using `daft-launcher`. -# Change up some of the configurations in here, and then run `daft up`. -# -# For more information on the availale commands and configuration options, visit [here](https://eventual-inc.github.io/daft-launcher). -# -# Happy daft-ing 🚀! - +# This is a template configuration file for daft-launcher with AWS provider [setup] -name = "daft-launcher-example" -requires = "" -python-version = "" -ray-version = "" +name = "my-daft-cluster" +version = "" +provider = "aws" + +# AWS-specific configuration +[setup.aws] region = "us-west-2" number-of-workers = 4 - -# The following configurations specify the type of servers in your cluster. -# The machine type below is what we usually use at Eventual, and the image id is Ubuntu based. 
-# If you want a smaller or bigger cluster, change the below two configurations accordingly. +ssh-user = "ubuntu" +ssh-private-key = "~/.ssh/id_rsa" instance-type = "i3.2xlarge" image-id = "ami-04dd23e62ed049936" +iam-instance-profile-name = "YourInstanceProfileName" # Optional +dependencies = [] # Optional additional Python packages to install -# This is the user profile that ssh's into the head machine. -# This value depends upon the `image-id` value up above. -# For Ubuntu AMIs, keep it as 'ubuntu'; for AWS AMIs, change it to 'ec2-user'. -ssh-user = "ubuntu" - -# Fill this out with your custom `.pem` key, or generate a new one by running `ssh-keygen -t rsa -b 2048 -m PEM -f my-key.pem`. -# Make sure the public key is uploaded to AWS. -ssh-private-key = "~/.ssh/my-keypair.pem" +# Run configuration (optional) +[run] +pre-setup-commands = [] +post-setup-commands = [] -# Fill in your python dependencies here. -# They'll be downloaded using `uv`. -dependencies = [] +# Job definitions +[[job]] +name = "example-job" +command = "python my_script.py" +working-dir = "~/my_project" \ No newline at end of file diff --git a/src/template_k8s.toml b/src/template_k8s.toml new file mode 100644 index 0000000..ded06c0 --- /dev/null +++ b/src/template_k8s.toml @@ -0,0 +1,19 @@ +# This is a template configuration file for daft-launcher with Kubernetes provider +[setup] +name = "my-daft-cluster" +version = "" +provider = "k8s" + +[setup.k8s] +namespace = "default" # Optional, defaults to "default" + +# Run configuration (optional) +[run] +pre-setup-commands = [] +post-setup-commands = [] + +# Job definitions +[[job]] +name = "example-job" +command = "python my_script.py" +working-dir = "~/my_project" \ No newline at end of file diff --git a/template.toml b/template.toml new file mode 100644 index 0000000..6d0751e --- /dev/null +++ b/template.toml @@ -0,0 +1,17 @@ +# This is a default configuration file that you can use to connect to an existing Kubernetes cluster running Ray (BYOC). 
+# Change up some of the configurations in here, and then run `daft up`. +# +# For more information on the availale commands and configuration options, visit [here](https://eventual-inc.github.io/daft-launcher). +# +# Happy daft-ing! + +[setup] +provider = "k8s" + +# They'll be downloaded using `uv`. +dependencies = [] + +[[job]] +name = "my-job" +command = "python hello_daft.py" +working-dir = "working-dir" \ No newline at end of file From 926c8236062300fe52ba442cf79542117c13c48f Mon Sep 17 00:00:00 2001 From: Jessie Young Date: Tue, 21 Jan 2025 17:36:13 -0800 Subject: [PATCH 02/15] Addressed PR comments, command groups, provisioned and byoc instead of aws and k8s --- Cargo.lock | 1 + Cargo.toml | 1 + README.md | 152 ++--- docs/{kubernetes => byoc}/README.md | 8 +- docs/{kubernetes => byoc}/cloud.md | 0 docs/{kubernetes => byoc}/local.md | 0 docs/{kubernetes => byoc}/on-prem.md | 0 docs/{kubernetes => byoc}/ray-installation.md | 0 src/main.rs | 632 ++++++++++-------- src/{template_k8s.toml => template_byoc.toml} | 12 +- ...emplate.toml => template_provisioned.toml} | 15 +- src/tests.rs | 26 + template.toml | 17 - 13 files changed, 456 insertions(+), 408 deletions(-) rename docs/{kubernetes => byoc}/README.md (63%) rename docs/{kubernetes => byoc}/cloud.md (100%) rename docs/{kubernetes => byoc}/local.md (100%) rename docs/{kubernetes => byoc}/on-prem.md (100%) rename docs/{kubernetes => byoc}/ray-installation.md (100%) rename src/{template_k8s.toml => template_byoc.toml} (52%) rename src/{template.toml => template_provisioned.toml} (67%) delete mode 100644 template.toml diff --git a/Cargo.lock b/Cargo.lock index 1ae5054..d846af6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -620,6 +620,7 @@ dependencies = [ "open", "regex", "rstest", + "semver", "serde", "serde_yaml", "tempdir", diff --git a/Cargo.toml b/Cargo.toml index a65f1ee..160c8af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ toml = "0.8" comfy-table = "7.1" regex = "1.11" open = "5.3" +semver = 
"1.0" [dependencies.anyhow] version = "1.0" diff --git a/README.md b/README.md index 587d60d..d5f09a2 100644 --- a/README.md +++ b/README.md @@ -25,33 +25,37 @@ Daft Launcher aims to solve this problem by providing a simple CLI tool to remov ## Capabilities What Daft Launcher is capable of: -1. Spinning up clusters (AWS only) -2. Listing all available clusters as well as their statuses (AWS only) -3. Submitting jobs to a cluster (AWS and Kubernetes) -4. Connecting to the cluster (AWS only, Kubernetes coming soon) -5. Spinning down clusters (AWS only) -6. Creating configuration files (AWS and Kubernetes) -7. Running raw SQL statements (AWS only, Kubernetes coming soon) +1. Spinning up clusters (Provisioned mode only) +2. Listing all available clusters as well as their statuses (Provisioned mode only) +3. Submitting jobs to a cluster (Both Provisioned and BYOC modes) +4. Connecting to the cluster (Provisioned mode only) +5. Spinning down clusters (Provisioned mode only) +6. Creating configuration files (Both modes) +7. 
Running raw SQL statements (BYOC mode only) -## Currently supported cloud providers +## Operation Modes Daft Launcher supports two modes of operation: -- **AWS**: For automatically provisioning and managing Ray clusters in AWS -- **Kubernetes**: For connecting to existing Ray clusters in Kubernetes - -### Command Support Matrix - -| Command | AWS | Kubernetes | -|----------|-----|------------| -| init | ✅ | ✅ | -| up | ✅ | ❌ | -| submit | ✅ | ✅ | -| stop | ✅ | ❌ | -| kill | ✅ | ❌ | -| list | ✅ | ❌ | -| connect | ✅ | ❌ | -| ssh | ✅ | ❌ | -| sql | ✅ | ❌ | +- **Provisioned**: Automatically provisions and manages Ray clusters in AWS +- **BYOC (Bring Your Own Cluster)**: Connects to existing Ray clusters in Kubernetes + +### Command Groups and Support Matrix + +| Command Group | Command | Provisioned | BYOC | +|--------------|---------|-------------|------| +| cluster | up | ✅ | ❌ | +| | down | ✅ | ❌ | +| | kill | ✅ | ❌ | +| | list | ✅ | ❌ | +| | connect | ✅ | ❌ | +| | ssh | ✅ | ❌ | +| job | submit | ✅ | ✅ | +| | sql | ✅ | ❌ | +| | status | ✅ | ❌ | +| | logs | ✅ | ❌ | +| config | init | ✅ | ✅ | +| | check | ✅ | ❌ | +| | export | ✅ | ❌ | ## Usage @@ -60,17 +64,17 @@ Daft Launcher supports two modes of operation: You'll need some python package manager installed. We recommend using [`uv`](https://astral.sh/blog/uv) for all things python. -#### For AWS Provider +#### For Provisioned Mode (AWS) 1. A valid AWS account with the necessary IAM role to spin up EC2 instances. This IAM role can either be created by you (assuming you have the appropriate permissions) or will need to be created by your administrator. 2. The [AWS CLI](https://aws.amazon.com/cli/) installed and configured on your machine. 3. Login using the AWS CLI. -#### For Kubernetes Provider +#### For BYOC Mode (Kubernetes) 1. A Kubernetes cluster with Ray already deployed - Can be local (minikube/kind), cloud-managed (EKS/GKE/AKS), or on-premise. 
- - See our [Kubernetes setup guides](./docs/kubernetes/README.md) for detailed instructions + - See our [BYOC setup guides](./docs/byoc/README.md) for detailed instructions 2. Ray cluster running in your Kubernetes cluster - Must be installed and configured using Helm - See provider-specific guides for installation steps @@ -78,9 +82,9 @@ We recommend using [`uv`](https://astral.sh/blog/uv) for all things python. 4. `kubectl` installed and configured with the correct context 5. Appropriate permissions to access the namespace where Ray is deployed -### SSH Key Setup for AWS +### SSH Key Setup for Provisioned Mode -To enable SSH access and port forwarding for AWS clusters, you need to: +To enable SSH access and port forwarding for provisioned clusters, you need to: 1. Create an SSH key pair (if you don't already have one): ```bash @@ -107,8 +111,8 @@ To enable SSH access and port forwarding for AWS clusters, you need to: 4. Update your daft configuration to use this key: ```toml - [setup.aws] - # ... other aws config ... + [setup.provisioned] + # ... other config ... ssh-private-key = "~/.ssh/daft-key" # Path to your private key ssh-user = "ubuntu" # User depends on the AMI (ubuntu for Ubuntu AMIs) ``` @@ -145,69 +149,55 @@ All interactions with Daft Launcher are primarily communicated via a configurati By default, Daft Launcher will look inside your `$CWD` for a file named `.daft.toml`. You can override this behaviour by specifying a custom configuration file. 
-#### AWS Provider (Default) +#### Provisioned Mode (AWS) ```bash -# create a new AWS configuration file -daft init -# or explicitly specify AWS provider -daft init --provider aws - -# spin up a cluster (AWS only) -daft up -# or optionally, pass in a custom config file -daft up -c my-custom-config.toml - -# list all active clusters (AWS only) -daft list - -# submit a job defined in your config -daft submit --working-dir <...> example-job - -# execute SQL query -daft sql "SELECT * FROM my_table" - -# connect to the Ray dashboard -daft connect - -# SSH into the head node -daft ssh - -# stop the cluster -daft stop - -# terminate the cluster -daft kill +# Initialize a new provisioned mode configuration +daft config init --provider provisioned +# or use the default provider (provisioned) +daft config init + +# Cluster management +daft provisioned up +daft provisioned list +daft provisioned connect +daft provisioned ssh +daft provisioned down +daft provisioned kill + +# Job management (works in both modes) +daft job submit example-job +daft job status example-job +daft job logs example-job + +# Configuration management +daft config check +daft config export ``` -#### Kubernetes Provider +#### BYOC Mode (Kubernetes) ```bash -# create a new Kubernetes configuration file -daft init --provider k8s - -# submit a job defined in your config -daft submit example-job - -# execute SQL query (K8s only) -daft sql "SELECT * FROM my_table" +# Initialize a new BYOC mode configuration +daft config init --provider byoc ``` ### Configuration Files You can specify a custom configuration file path with the `-c` flag: ```bash -daft -c my-config.toml submit example-job +daft -c my-config.toml job submit example-job ``` -Example AWS configuration: +Example Provisioned mode configuration: ```toml [setup] name = "my-daft-cluster" version = "0.1.0" -provider = "aws" +provider = "provisioned" +dependencies = [] # Optional additional Python packages to install -[setup.aws] +[setup.provisioned] 
region = "us-west-2" number-of-workers = 4 ssh-user = "ubuntu" @@ -215,7 +205,6 @@ ssh-private-key = "~/.ssh/daft-key" instance-type = "i3.2xlarge" image-id = "ami-04dd23e62ed049936" iam-instance-profile-name = "YourInstanceProfileName" # Optional -dependencies = [] # Optional additional Python packages [run] pre-setup-commands = [] @@ -227,20 +216,17 @@ command = "python my_script.py" working-dir = "~/my_project" ``` -Example Kubernetes configuration: +Example BYOC mode configuration: ```toml [setup] name = "my-daft-cluster" version = "0.1.0" -provider = "k8s" +provider = "byoc" +dependencies = [] # Optional additional Python packages to install -[setup.k8s] +[setup.byoc] namespace = "default" # Optional, defaults to "default" -[run] -pre-setup-commands = [] -post-setup-commands = [] - [[job]] name = "example-job" command = "python my_script.py" diff --git a/docs/kubernetes/README.md b/docs/byoc/README.md similarity index 63% rename from docs/kubernetes/README.md rename to docs/byoc/README.md index 2842d0e..4debdd0 100644 --- a/docs/kubernetes/README.md +++ b/docs/byoc/README.md @@ -1,8 +1,6 @@ -# Kubernetes Setup for Daft +# BYOC (Bring Your Own Cluster) Mode Setup for Daft -> **Note**: This documentation is housed in the `daft-launcher` repository while the Kubernetes approach is being reviewed. Once finalized, these docs will be copied to the separate documentation repository. 
- -This directory contains guides for setting up Ray and Daft on various Kubernetes environments: +This directory contains guides for setting up Ray and Daft on various Kubernetes environments for BYOC mode: - [Local Development](./local.md) - Setting up a local Kubernetes cluster for development - [Cloud Providers](./cloud.md) - Instructions for EKS, GKE, and AKS setups @@ -10,7 +8,7 @@ This directory contains guides for setting up Ray and Daft on various Kubernetes ## Prerequisites -Before using `daft-launcher` with Kubernetes, you must: +Before using `daft-launcher` in BYOC mode with Kubernetes, you must: 1. Have a running Kubernetes cluster (local, cloud-managed, or on-premise) 2. Install and configure Ray on your Kubernetes cluster 3. Install Daft on your cluster diff --git a/docs/kubernetes/cloud.md b/docs/byoc/cloud.md similarity index 100% rename from docs/kubernetes/cloud.md rename to docs/byoc/cloud.md diff --git a/docs/kubernetes/local.md b/docs/byoc/local.md similarity index 100% rename from docs/kubernetes/local.md rename to docs/byoc/local.md diff --git a/docs/kubernetes/on-prem.md b/docs/byoc/on-prem.md similarity index 100% rename from docs/kubernetes/on-prem.md rename to docs/byoc/on-prem.md diff --git a/docs/kubernetes/ray-installation.md b/docs/byoc/ray-installation.md similarity index 100% rename from docs/kubernetes/ray-installation.md rename to docs/byoc/ray-installation.md diff --git a/src/main.rs b/src/main.rs index 5f426f7..6cea310 100644 --- a/src/main.rs +++ b/src/main.rs @@ -43,55 +43,84 @@ struct DaftLauncher { #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum SubCommand { - /// Initialize a daft-launcher configuration file. - /// - /// If no path is provided, this will create a default ".daft.toml" in the - /// current working directory. - Init(Init), - - /// Check to make sure the daft-launcher configuration file is correct. 
- Check(ConfigPath), + /// Manage Daft-provisioned clusters (AWS) + Provisioned(ProvisionedCommands), + /// Manage existing clusters (Kubernetes) + Byoc(ByocCommands), + /// Manage jobs across all cluster types + Job(JobCommands), + /// Manage configurations + Config(ConfigCommands), +} - /// Export the daft-launcher configuration file to a Ray configuration file. - Export(ConfigPath), +#[derive(Debug, Parser, Clone, PartialEq, Eq)] +struct ProvisionedCommands { + #[command(subcommand)] + command: ProvisionedCommand, +} - /// Spin up a new cluster. +#[derive(Debug, Subcommand, Clone, PartialEq, Eq)] +enum ProvisionedCommand { + /// Create a new cluster Up(ConfigPath), - - /// List all Ray clusters in your AWS account. - /// - /// This will *only* list clusters that have been spun up by Ray. + /// Stop a running cluster + Down(ConfigPath), + /// Terminate a cluster + Kill(ConfigPath), + /// List all clusters List(List), + /// Connect to cluster dashboard + Connect(Connect), + /// SSH into cluster head node + Ssh(ConfigPath), +} - /// Submit a job to the Ray cluster. - /// - /// The configurations of the job should be placed inside of your - /// daft-launcher configuration file. - Submit(Submit), +#[derive(Debug, Parser, Clone, PartialEq, Eq)] +struct ByocCommands { + #[command(subcommand)] + command: ByocCommand, +} - /// Establish an ssh port-forward connection from your local machine to the - /// Ray cluster. - Connect(Connect), +#[derive(Debug, Subcommand, Clone, PartialEq, Eq)] +enum ByocCommand { + /// Verify connection to existing cluster + Verify(ConfigPath), + /// Show cluster information + Info(ConfigPath), +} - /// SSH into the head of the remote Ray cluster. - Ssh(ConfigPath), +#[derive(Debug, Parser, Clone, PartialEq, Eq)] +struct JobCommands { + #[command(subcommand)] + command: JobCommand, +} - /// Submit a SQL query string to the Ray cluster. - /// - /// This is executed using Daft's SQL API support. 
+#[derive(Debug, Subcommand, Clone, PartialEq, Eq)] +enum JobCommand { + /// Submit a job to the cluster + Submit(Submit), + /// Execute SQL queries Sql(Sql), + /// Check job status + Status(ConfigPath), + /// View job logs + Logs(ConfigPath), +} - /// Spin down a given cluster and put the nodes to "sleep". - /// - /// This will *not* delete the nodes, only stop them. The nodes can be - /// restarted at a future time. - Stop(ConfigPath), +#[derive(Debug, Parser, Clone, PartialEq, Eq)] +struct ConfigCommands { + #[command(subcommand)] + command: ConfigCommand, +} - /// Spin down a given cluster and fully terminate the nodes. - /// - /// This *will* delete the nodes; they will not be accessible from here on - /// out. - Kill(ConfigPath), +#[derive(Debug, Subcommand, Clone, PartialEq, Eq)] +enum ConfigCommand { + /// Initialize a new configuration + Init(Init), + /// Validate configuration + Check(ConfigPath), + /// Export configuration to Ray format + Export(ConfigPath), } #[derive(Debug, Parser, Clone, PartialEq, Eq)] @@ -100,8 +129,8 @@ struct Init { #[arg(default_value = ".daft.toml")] path: PathBuf, - /// The provider to use - either 'aws' (default) to auto-generate a cluster or 'k8s' for existing Kubernetes clusters - #[arg(long, default_value = "aws")] + /// The provider to use - either 'provisioned' (default) to auto-generate a cluster or 'byoc' for existing Kubernetes clusters + #[arg(long, default_value = "provisioned")] provider: String, } @@ -162,8 +191,6 @@ struct ConfigPath { #[serde(rename_all = "kebab-case", deny_unknown_fields)] struct DaftConfig { setup: DaftSetup, - #[serde(default)] - run: DaftRun, #[serde(rename = "job", deserialize_with = "parse_jobs")] jobs: HashMap, } @@ -175,6 +202,8 @@ struct DaftSetup { #[serde(deserialize_with = "parse_version_req")] version: VersionReq, provider: DaftProvider, + #[serde(default)] + dependencies: Vec, #[serde(flatten)] provider_config: ProviderConfig, } @@ -182,10 +211,17 @@ struct DaftSetup { 
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] enum ProviderConfig { - #[serde(rename = "aws")] - Aws(AwsConfig), - #[serde(rename = "k8s")] - K8s(K8sConfig), + #[serde(rename = "provisioned")] + Provisioned(AwsConfigWithRun), + #[serde(rename = "byoc")] + Byoc(K8sConfig), +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +struct AwsConfigWithRun { + #[serde(flatten)] + config: AwsConfig, } #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] @@ -202,8 +238,6 @@ struct AwsConfig { #[serde(default = "default_image_id")] image_id: StrRef, iam_instance_profile_name: Option, - #[serde(default)] - dependencies: Vec, } #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] @@ -310,8 +344,10 @@ where #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] enum DaftProvider { - Aws, - K8s, + #[serde(rename = "provisioned")] + Provisioned, + #[serde(rename = "byoc")] + Byoc, } impl FromStr for DaftProvider { @@ -319,22 +355,13 @@ impl FromStr for DaftProvider { fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { - "aws" => Ok(DaftProvider::Aws), - "k8s" => Ok(DaftProvider::K8s), - _ => anyhow::bail!("Invalid provider '{}'. Must be either 'aws' or 'k8s'", s), + "provisioned" => Ok(DaftProvider::Provisioned), + "byoc" => Ok(DaftProvider::Byoc), + _ => anyhow::bail!("Invalid provider '{}'. 
Must be either 'provisioned' or 'byoc'", s), } } } -#[derive(Default, Debug, Deserialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "kebab-case", deny_unknown_fields)] -struct DaftRun { - #[serde(default)] - pre_setup_commands: Vec, - #[serde(default)] - post_setup_commands: Vec, -} - #[derive(Debug, Clone, PartialEq, Eq)] struct DaftJob { command: StrRef, @@ -415,40 +442,40 @@ async fn read_and_convert( let daft_config = toml::from_str::(&contents)?; let ray_config = match &daft_config.setup.provider_config { - ProviderConfig::K8s(_) => None, - ProviderConfig::Aws(aws_config) => { - let key_name = aws_config.ssh_private_key + ProviderConfig::Byoc(_) => None, + ProviderConfig::Provisioned(aws_config) => { + let key_name = aws_config.config.ssh_private_key .clone() .file_stem() .ok_or_else(|| anyhow::anyhow!(r#"Private key doesn't have a name of the format "name.ext""#))? .to_str() - .ok_or_else(|| anyhow::anyhow!("The file {:?} does not have a valid UTF-8 name", aws_config.ssh_private_key))? + .ok_or_else(|| anyhow::anyhow!("The file {:?} does not have a valid UTF-8 name", aws_config.config.ssh_private_key))? 
.into(); let node_config = RayNodeConfig { key_name, - instance_type: aws_config.instance_type.clone(), - image_id: aws_config.image_id.clone(), - iam_instance_profile: aws_config.iam_instance_profile_name.clone().map(|name| IamInstanceProfile { name }), + instance_type: aws_config.config.instance_type.clone(), + image_id: aws_config.config.image_id.clone(), + iam_instance_profile: aws_config.config.iam_instance_profile_name.clone().map(|name| IamInstanceProfile { name }), }; Some(RayConfig { cluster_name: daft_config.setup.name.clone(), - max_workers: aws_config.number_of_workers, + max_workers: aws_config.config.number_of_workers, provider: RayProvider { r#type: "aws".into(), - region: aws_config.region.clone(), + region: aws_config.config.region.clone(), cache_stopped_nodes: teardown_behaviour.map(TeardownBehaviour::to_cache_stopped_nodes), }, auth: RayAuth { - ssh_user: aws_config.ssh_user.clone(), - ssh_private_key: aws_config.ssh_private_key.clone(), + ssh_user: aws_config.config.ssh_user.clone(), + ssh_private_key: aws_config.config.ssh_private_key.clone(), }, available_node_types: vec![ ( "ray.head.default".into(), RayNodeType { - max_workers: aws_config.number_of_workers, + max_workers: aws_config.config.number_of_workers, node_config: node_config.clone(), resources: Some(RayResources { cpu: 0 }), }, @@ -456,7 +483,7 @@ async fn read_and_convert( ( "ray.worker.default".into(), RayNodeType { - max_workers: aws_config.number_of_workers, + max_workers: aws_config.config.number_of_workers, node_config, resources: None, }, @@ -474,9 +501,8 @@ async fn read_and_convert( "source ~/.bashrc".into(), "uv pip install boto3 pip ray[default] getdaft py-spy deltalake".into(), ]; - if !aws_config.dependencies.is_empty() { - let deps = aws_config - .dependencies + if !daft_config.setup.dependencies.is_empty() { + let deps = daft_config.setup.dependencies .iter() .map(|dep| format!(r#""{dep}""#)) .collect::>() @@ -712,8 +738,8 @@ async fn get_region(region: Option, config: 
impl AsRef) -> anyhow: } else if config.exists() { let (daft_config, _) = read_and_convert(&config, None).await?; match &daft_config.setup.provider_config { - ProviderConfig::Aws(aws_config) => aws_config.region.clone(), - ProviderConfig::K8s(_) => "us-west-2".into(), + ProviderConfig::Provisioned(aws_config) => aws_config.config.region.clone(), + ProviderConfig::Byoc(_) => "us-west-2".into(), } } else { "us-west-2".into() @@ -922,245 +948,281 @@ async fn submit_k8s( async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> { match daft_launcher.sub_command { - SubCommand::Init(Init { path, provider }) => { - #[cfg(not(test))] - if path.exists() { - bail!("The path {path:?} already exists; the path given must point to a new location on your filesystem"); - } - let contents = if provider == "k8s" { - include_str!("template_k8s.toml") - } else { - include_str!("template.toml") - } - .replace("", env!("CARGO_PKG_VERSION")); - fs::write(path, contents).await?; + SubCommand::Config(config_cmd) => { + config_cmd.command.run(daft_launcher.verbosity).await } - SubCommand::Check(ConfigPath { config }) => { - let _ = read_and_convert(&config, None).await?; + SubCommand::Job(job_cmd) => { + job_cmd.command.run(daft_launcher.verbosity).await } - SubCommand::Export(ConfigPath { config }) => { - let (_, ray_config) = read_and_convert(&config, None).await?; - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); - let ray_config_str = serde_yaml::to_string(&ray_config)?; - - println!("{ray_config_str}"); + SubCommand::Provisioned(provisioned_cmd) => { + provisioned_cmd.command.run(daft_launcher.verbosity).await } - SubCommand::Up(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, None).await?; - match daft_config.setup.provider { - DaftProvider::Aws => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config 
= ray_config.unwrap(); - assert_is_logged_in_with_aws().await?; + SubCommand::Byoc(byoc_cmd) => { + byoc_cmd.command.run(daft_launcher.verbosity).await + } + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + run(DaftLauncher::parse()).await +} + +// Helper function to get AWS config +fn get_aws_config(config: &DaftConfig) -> anyhow::Result<&AwsConfig> { + match &config.setup.provider_config { + ProviderConfig::Provisioned(aws_config) => Ok(&aws_config.config), + ProviderConfig::Byoc(_) => anyhow::bail!("Expected provisioned configuration but found Kubernetes configuration"), + } +} - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - run_ray_up_or_down_command(SpinDirection::Up, ray_path).await?; +impl ConfigCommand { + async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + match self { + ConfigCommand::Init(Init { path, provider }) => { + #[cfg(not(test))] + if path.exists() { + bail!("The path {path:?} already exists; the path given must point to a new location on your filesystem"); } - DaftProvider::K8s => { - anyhow::bail!("'up' command is only available for AWS configurations"); + let contents = if provider == "byoc" { + include_str!("template_byoc.toml") + } else { + include_str!("template_provisioned.toml") } + .replace("", env!("CARGO_PKG_VERSION")); + fs::write(path, contents).await?; } - } - SubCommand::List(List { - config_path, - region, - head, - running, - }) => { - let (daft_config, _) = read_and_convert(&config_path.config, None).await?; - match daft_config.setup.provider { - DaftProvider::Aws => { - assert_is_logged_in_with_aws().await?; - let aws_config = get_aws_config(&daft_config)?; - let region = region.unwrap_or_else(|| aws_config.region.clone()); - let instances = get_ray_clusters_from_aws(region).await?; - print_instances(&instances, head, running); - } - DaftProvider::K8s => { - anyhow::bail!("'list' command is only available for AWS configurations"); + 
ConfigCommand::Check(ConfigPath { config }) => { + let _ = read_and_convert(&config, None).await?; + } + ConfigCommand::Export(ConfigPath { config }) => { + let (_, ray_config) = read_and_convert(&config, None).await?; + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); } + let ray_config = ray_config.unwrap(); + let ray_config_str = serde_yaml::to_string(&ray_config)?; + println!("{ray_config_str}"); } } - SubCommand::Submit(Submit { - config_path, - job_name, - }) => { - let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - let daft_job = daft_config - .jobs - .get(&job_name) - .ok_or_else(|| anyhow::anyhow!("A job with the name {job_name} was not found"))?; - - match &daft_config.setup.provider_config { - ProviderConfig::Aws(_aws_config) => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); + Ok(()) + } +} + +impl JobCommand { + async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + match self { + JobCommand::Submit(Submit { config_path, job_name }) => { + let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; + let daft_job = daft_config + .jobs + .get(job_name) + .ok_or_else(|| anyhow::anyhow!("A job with the name {job_name} was not found"))?; + + match &daft_config.setup.provider_config { + ProviderConfig::Provisioned(_) => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + + let aws_config = get_aws_config(&daft_config)?; + // Start port forwarding - it will be automatically killed when _port_forward is dropped + let _port_forward = establish_ssh_portforward(ray_path, aws_config, Some(8265)).await?; + + // Give the port-forward a moment to fully establish + tokio::time::sleep(Duration::from_secs(1)).await; + + // Submit the job 
+ let exit_status = Command::new("ray") + .env("PYTHONUNBUFFERED", "1") + .args(["job", "submit", "--address", "http://localhost:8265"]) + .arg("--working-dir") + .arg(daft_job.working_dir.as_ref()) + .arg("--") + .args(daft_job.command.as_ref().split(' ').collect::>()) + .spawn()? + .wait() + .await?; + + if !exit_status.success() { + anyhow::bail!("Failed to submit job to the ray cluster"); + } + } + ProviderConfig::Byoc(k8s_config) => { + submit_k8s( + daft_job.working_dir.as_ref(), + daft_job.command.as_ref().split(' ').collect::>(), + k8s_config.namespace.as_deref(), + ) + .await?; } - let ray_config = ray_config.unwrap(); - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - submit_k8s( - daft_job.working_dir.as_ref(), - daft_job.command.as_ref().split(' ').collect::>(), - None, - ) - .await?; - } - ProviderConfig::K8s(k8s_config) => { - submit_k8s( - daft_job.working_dir.as_ref(), - daft_job.command.as_ref().split(' ').collect::>(), - k8s_config.namespace.as_deref(), - ) - .await?; } } - } - SubCommand::Connect(Connect { port, config_path }) => { - let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - match daft_config.setup.provider { - DaftProvider::Aws => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); + JobCommand::Sql(Sql { sql, config_path }) => { + let (daft_config, _) = read_and_convert(&config_path.config, None).await?; + match &daft_config.setup.provider_config { + ProviderConfig::Provisioned(_) => { + anyhow::bail!("'sql' command is only available for BYOC configurations"); } - let ray_config = ray_config.unwrap(); - assert_is_logged_in_with_aws().await?; - - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let aws_config = get_aws_config(&daft_config)?; - let _ = establish_ssh_portforward(ray_path, aws_config, Some(port)) - .await? 
- .wait_with_output() + ProviderConfig::Byoc(k8s_config) => { + let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; + fs::write(sql_path, include_str!("sql.py")).await?; + submit_k8s( + temp_sql_dir.path(), + vec!["python", "sql.py", sql.as_ref()], + k8s_config.namespace.as_deref(), + ) .await?; + } } - DaftProvider::K8s => { - anyhow::bail!("'connect' command is only available for AWS configurations"); - } + } + JobCommand::Status(_) => { + anyhow::bail!("Job status command not yet implemented"); + } + JobCommand::Logs(_) => { + anyhow::bail!("Job logs command not yet implemented"); } } - SubCommand::Ssh(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, None).await?; - match daft_config.setup.provider { - DaftProvider::Aws => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); - assert_is_logged_in_with_aws().await?; + Ok(()) + } +} - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let aws_config = get_aws_config(&daft_config)?; - ssh(ray_path, aws_config).await?; - } - DaftProvider::K8s => { - anyhow::bail!("'ssh' command is only available for AWS configurations"); +impl ProvisionedCommand { + async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + match self { + ProvisionedCommand::Up(ConfigPath { config }) => { + let (daft_config, ray_config) = read_and_convert(&config, None).await?; + match daft_config.setup.provider { + DaftProvider::Provisioned => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + assert_is_logged_in_with_aws().await?; + + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + run_ray_up_or_down_command(SpinDirection::Up, ray_path).await?; + } + DaftProvider::Byoc => { + anyhow::bail!("'up' command is only available for 
provisioned configurations"); + } } } - } - SubCommand::Sql(Sql { sql, config_path }) => { - let (daft_config, _) = read_and_convert(&config_path.config, None).await?; - match &daft_config.setup.provider_config { - ProviderConfig::Aws(_) => { - anyhow::bail!("'sql' command is only available for Kubernetes configurations"); - } - ProviderConfig::K8s(k8s_config) => { - let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; - fs::write(sql_path, include_str!("sql.py")).await?; - submit_k8s( - temp_sql_dir.path(), - vec!["python", "sql.py", sql.as_ref()], - k8s_config.namespace.as_deref(), - ) - .await?; + ProvisionedCommand::Down(ConfigPath { config }) => { + let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?; + match daft_config.setup.provider { + DaftProvider::Provisioned => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + assert_is_logged_in_with_aws().await?; + + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; + } + DaftProvider::Byoc => { + anyhow::bail!("'down' command is only available for provisioned configurations"); + } } } - } - SubCommand::Stop(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?; - match daft_config.setup.provider { - DaftProvider::Aws => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); + ProvisionedCommand::Kill(ConfigPath { config }) => { + let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Kill)).await?; + match daft_config.setup.provider { + DaftProvider::Provisioned => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + 
assert_is_logged_in_with_aws().await?; + + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; + } + DaftProvider::Byoc => { + anyhow::bail!("'kill' command is only available for provisioned configurations"); } - let ray_config = ray_config.unwrap(); - assert_is_logged_in_with_aws().await?; - - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; } - DaftProvider::K8s => { - anyhow::bail!("'stop' command is only available for AWS configurations"); + } + ProvisionedCommand::List(List { config_path, region, head, running }) => { + let (daft_config, _) = read_and_convert(&config_path.config, None).await?; + match daft_config.setup.provider { + DaftProvider::Provisioned => { + assert_is_logged_in_with_aws().await?; + let aws_config = get_aws_config(&daft_config)?; + let region = region.as_ref().unwrap_or_else(|| &aws_config.region); + let instances = get_ray_clusters_from_aws(region.clone()).await?; + print_instances(&instances, *head, *running); + } + DaftProvider::Byoc => { + anyhow::bail!("'list' command is only available for provisioned configurations"); + } } } - } - SubCommand::Kill(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Kill)).await?; - match daft_config.setup.provider { - DaftProvider::Aws => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); + ProvisionedCommand::Connect(Connect { port, config_path }) => { + let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; + match daft_config.setup.provider { + DaftProvider::Provisioned => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + 
assert_is_logged_in_with_aws().await?; + + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + let aws_config = get_aws_config(&daft_config)?; + let _ = establish_ssh_portforward(ray_path, aws_config, Some(*port)) + .await? + .wait_with_output() + .await?; + } + DaftProvider::Byoc => { + anyhow::bail!("'connect' command is only available for provisioned configurations"); } - let ray_config = ray_config.unwrap(); - assert_is_logged_in_with_aws().await?; - - let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; } - DaftProvider::K8s => { - anyhow::bail!("'kill' command is only available for AWS configurations"); + } + ProvisionedCommand::Ssh(ConfigPath { config }) => { + let (daft_config, ray_config) = read_and_convert(&config, None).await?; + match daft_config.setup.provider { + DaftProvider::Provisioned => { + if ray_config.is_none() { + anyhow::bail!("Failed to find Ray config in config file"); + } + let ray_config = ray_config.unwrap(); + assert_is_logged_in_with_aws().await?; + + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(ray_config, &ray_path).await?; + let aws_config = get_aws_config(&daft_config)?; + ssh(ray_path, aws_config).await?; + } + DaftProvider::Byoc => { + anyhow::bail!("'ssh' command is only available for provisioned configurations"); + } } } } - } - - Ok(()) -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - run(DaftLauncher::parse()).await -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_init_and_export() { - run(DaftLauncher { - sub_command: SubCommand::Init(Init { - path: ".daft.toml".into(), - provider: "aws".into(), - }), - verbosity: 0, - }) - .await - .unwrap(); - run(DaftLauncher { - sub_command: SubCommand::Check(ConfigPath { - config: ".daft.toml".into(), - }), - verbosity: 0, - }) - .await - 
.unwrap(); + Ok(()) } } -// Helper function to get AWS config -fn get_aws_config(config: &DaftConfig) -> anyhow::Result<&AwsConfig> { - match &config.setup.provider_config { - ProviderConfig::Aws(aws_config) => Ok(aws_config), - ProviderConfig::K8s(_) => anyhow::bail!("Expected AWS configuration but found Kubernetes configuration"), +impl ByocCommand { + async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + match self { + ByocCommand::Verify(ConfigPath { config: _ }) => { + anyhow::bail!("Verify command not yet implemented"); + } + ByocCommand::Info(ConfigPath { config: _ }) => { + anyhow::bail!("Info command not yet implemented"); + } + } + Ok(()) } } diff --git a/src/template_k8s.toml b/src/template_byoc.toml similarity index 52% rename from src/template_k8s.toml rename to src/template_byoc.toml index ded06c0..e70adc0 100644 --- a/src/template_k8s.toml +++ b/src/template_byoc.toml @@ -1,17 +1,13 @@ -# This is a template configuration file for daft-launcher with Kubernetes provider +# This is a template configuration file for daft-launcher with BYOC provider [setup] name = "my-daft-cluster" version = "" -provider = "k8s" +provider = "byoc" +# TODO: support dependencies -[setup.k8s] +[setup.byoc] namespace = "default" # Optional, defaults to "default" -# Run configuration (optional) -[run] -pre-setup-commands = [] -post-setup-commands = [] - # Job definitions [[job]] name = "example-job" diff --git a/src/template.toml b/src/template_provisioned.toml similarity index 67% rename from src/template.toml rename to src/template_provisioned.toml index ac1227a..4299fbf 100644 --- a/src/template.toml +++ b/src/template_provisioned.toml @@ -1,11 +1,12 @@ -# This is a template configuration file for daft-launcher with AWS provider +# This is a template configuration file for daft-launcher with provisioned provider [setup] name = "my-daft-cluster" version = "" -provider = "aws" +provider = "provisioned" +dependencies = [] # Optional additional Python packages to 
install -# AWS-specific configuration -[setup.aws] +# Provisioned (AWS) configuration +[setup.provisioned] region = "us-west-2" number-of-workers = 4 ssh-user = "ubuntu" @@ -13,12 +14,6 @@ ssh-private-key = "~/.ssh/id_rsa" instance-type = "i3.2xlarge" image-id = "ami-04dd23e62ed049936" iam-instance-profile-name = "YourInstanceProfileName" # Optional -dependencies = [] # Optional additional Python packages to install - -# Run configuration (optional) -[run] -pre-setup-commands = [] -post-setup-commands = [] # Job definitions [[job]] diff --git a/src/tests.rs b/src/tests.rs index 5aad7cc..960bb21 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,6 +1,7 @@ use tokio::fs; use super::*; +use crate::{ConfigCommand, ConfigCommands, ConfigPath, DaftLauncher, Init, SubCommand}; fn not_found_okay(result: std::io::Result<()>) -> std::io::Result<()> { match result { @@ -189,3 +190,28 @@ pub fn simple_config() -> (DaftConfig, Option, RayConfig) { (daft_config, None, ray_config) } + +#[tokio::test] +async fn test_init_and_export() { + crate::run(DaftLauncher { + sub_command: SubCommand::Config(ConfigCommands { + command: ConfigCommand::Init(Init { + path: ".daft.toml".into(), + provider: "provisioned".into(), + }), + }), + verbosity: 0, + }) + .await + .unwrap(); + crate::run(DaftLauncher { + sub_command: SubCommand::Config(ConfigCommands { + command: ConfigCommand::Check(ConfigPath { + config: ".daft.toml".into(), + }), + }), + verbosity: 0, + }) + .await + .unwrap(); +} diff --git a/template.toml b/template.toml deleted file mode 100644 index 6d0751e..0000000 --- a/template.toml +++ /dev/null @@ -1,17 +0,0 @@ -# This is a default configuration file that you can use to connect to an existing Kubernetes cluster running Ray (BYOC). -# Change up some of the configurations in here, and then run `daft up`. -# -# For more information on the availale commands and configuration options, visit [here](https://eventual-inc.github.io/daft-launcher). -# -# Happy daft-ing! 
- -[setup] -provider = "k8s" - -# They'll be downloaded using `uv`. -dependencies = [] - -[[job]] -name = "my-job" -command = "python hello_daft.py" -working-dir = "working-dir" \ No newline at end of file From 9e50b05a8ff7d2fd01f59f1c39de8e076058d96d Mon Sep 17 00:00:00 2001 From: Jessie Young Date: Tue, 21 Jan 2025 17:44:25 -0800 Subject: [PATCH 03/15] switched to DaftProvider enum --- src/main.rs | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/main.rs b/src/main.rs index 6cea310..a9787ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -130,8 +130,8 @@ struct Init { path: PathBuf, /// The provider to use - either 'provisioned' (default) to auto-generate a cluster or 'byoc' for existing Kubernetes clusters - #[arg(long, default_value = "provisioned")] - provider: String, + #[arg(long, default_value_t = DaftProvider::Provisioned)] + provider: DaftProvider, } #[derive(Debug, Parser, Clone, PartialEq, Eq)] @@ -362,6 +362,15 @@ impl FromStr for DaftProvider { } } +impl ToString for DaftProvider { + fn to_string(&self) -> String { + match self { + DaftProvider::Provisioned => "provisioned".to_string(), + DaftProvider::Byoc => "byoc".to_string(), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] struct DaftJob { command: StrRef, @@ -984,10 +993,9 @@ impl ConfigCommand { if path.exists() { bail!("The path {path:?} already exists; the path given must point to a new location on your filesystem"); } - let contents = if provider == "byoc" { - include_str!("template_byoc.toml") - } else { - include_str!("template_provisioned.toml") + let contents = match provider { + DaftProvider::Byoc => include_str!("template_byoc.toml"), + DaftProvider::Provisioned => include_str!("template_provisioned.toml"), } .replace("", env!("CARGO_PKG_VERSION")); fs::write(path, contents).await?; From 4446dea0bce7c6d153c57a4ba0b329a90a517163 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 17:43:27 -0800 Subject: [PATCH 04/15] Address 
some changes with the PR --- Cargo.lock | 148 -------- Cargo.toml | 6 - src/main.rs | 650 +++++++++++++--------------------- src/ssh.rs | 20 +- src/template_byoc.toml | 7 +- src/template_provisioned.toml | 6 +- src/tests.rs | 10 +- 7 files changed, 274 insertions(+), 573 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d846af6..8834141 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,15 +17,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" -[[package]] -name = "aho-corasick" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" -dependencies = [ - "memchr", -] - [[package]] name = "anstream" version = "0.6.18" @@ -617,10 +608,6 @@ dependencies = [ "aws-sdk-sts", "clap", "comfy-table", - "open", - "regex", - "rstest", - "semver", "serde", "serde_yaml", "tempdir", @@ -724,17 +711,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" -[[package]] -name = "futures-macro" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "futures-sink" version = "0.3.31" @@ -747,12 +723,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" - [[package]] name = "futures-util" version = "0.3.31" @@ -760,11 +730,9 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", - "futures-macro", "futures-task", "pin-project-lite", "pin-utils", - "slab", ] [[package]] @@ -794,12 +762,6 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" -[[package]] -name = "glob" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" - [[package]] name = "h2" version = "0.3.26" @@ -1103,25 +1065,6 @@ dependencies = [ "hashbrown", ] -[[package]] -name = "is-docker" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" -dependencies = [ - "once_cell", -] - -[[package]] -name = "is-wsl" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" -dependencies = [ - "is-docker", - "once_cell", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -1258,17 +1201,6 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" -[[package]] -name = "open" -version = "5.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2483562e62ea94312f3576a7aca397306df7990b8d89033e18766744377ef95" -dependencies = [ - "is-wsl", - "libc", - "pathdiff", -] - [[package]] name = "openssl-probe" version = "0.1.5" @@ -1304,12 +1236,6 @@ dependencies = [ "windows-targets", ] -[[package]] -name = "pathdiff" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" - [[package]] name = "percent-encoding" version = "2.3.1" @@ -1334,15 +1260,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" -[[package]] -name = "proc-macro-crate" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" -dependencies = [ - "toml_edit", -] - [[package]] name = "proc-macro2" version = "1.0.93" @@ -1407,47 +1324,12 @@ dependencies = [ "bitflags", ] -[[package]] -name = "regex" -version = "1.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - [[package]] name = "regex-lite" version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" -[[package]] -name = "regex-syntax" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" - -[[package]] -name = "relative-path" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" - [[package]] name = "remove_dir_all" version = "0.5.3" @@ -1472,36 +1354,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "rstest" -version = "0.24.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "03e905296805ab93e13c1ec3a03f4b6c4f35e9498a3d5fa96dc626d22c03cd89" -dependencies = [ - "futures-timer", - "futures-util", - "rstest_macros", - "rustc_version", -] - -[[package]] -name = "rstest_macros" -version = "0.24.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef0053bbffce09062bee4bcc499b0fbe7a57b879f1efe088d6d8d4c7adcdef9b" -dependencies = [ - "cfg-if", - "glob", - "proc-macro-crate", - "proc-macro2", - "quote", - "regex", - "relative-path", - "rustc_version", - "syn", - "unicode-ident", -] - [[package]] name = "rustc-demangle" version = "0.1.24" diff --git a/Cargo.toml b/Cargo.toml index 160c8af..1765999 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,6 @@ serde_yaml = "0.9" tempdir = "0.3" toml = "0.8" comfy-table = "7.1" -regex = "1.11" -open = "5.3" -semver = "1.0" [dependencies.anyhow] version = "1.0" @@ -36,6 +33,3 @@ features = ["derive", "rc"] [dependencies.versions] version = "6.3" features = ["serde"] - -[dev-dependencies] -rstest = "0.24" diff --git a/src/main.rs b/src/main.rs index a9787ad..78730e2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,15 @@ +macro_rules! 
not_available_for_byoc { + ($command:literal) => { + anyhow::bail!(concat!( + "The command `", + $command, + "` is only available for provisioned configurations" + )) + }; +} + +mod ssh; + use std::{ collections::HashMap, io::{Error, ErrorKind}, @@ -17,15 +29,13 @@ use clap::{Parser, Subcommand}; use comfy_table::{ modifiers, presets, Attribute, Cell, CellAlignment, Color, ContentArrangement, Table, }; -use semver::{Version, VersionReq}; use serde::{Deserialize, Serialize}; use tempdir::TempDir; use tokio::{ fs, - io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, process::{Child, Command}, - time::timeout, }; +use versions::{Requirement, Versioning}; type StrRef = Arc; type PathRef = Arc; @@ -35,90 +45,80 @@ type PathRef = Arc; struct DaftLauncher { #[command(subcommand)] sub_command: SubCommand, - - /// Enable verbose printing - #[arg(short, long, action = clap::ArgAction::Count)] - verbosity: u8, } #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum SubCommand { /// Manage Daft-provisioned clusters (AWS) - Provisioned(ProvisionedCommands), + #[command(subcommand)] + Provisioned(ProvisionedCommand), + /// Manage existing clusters (Kubernetes) - Byoc(ByocCommands), + #[command(subcommand)] + Byoc(ByocCommand), + /// Manage jobs across all cluster types - Job(JobCommands), - /// Manage configurations - Config(ConfigCommands), -} + #[command(subcommand)] + Job(JobCommand), -#[derive(Debug, Parser, Clone, PartialEq, Eq)] -struct ProvisionedCommands { + /// Manage configurations #[command(subcommand)] - command: ProvisionedCommand, + Config(ConfigCommand), } #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum ProvisionedCommand { /// Create a new cluster Up(ConfigPath), + /// Stop a running cluster Down(ConfigPath), + /// Terminate a cluster Kill(ConfigPath), + /// List all clusters List(List), + /// Connect to cluster dashboard Connect(Connect), + /// SSH into cluster head node Ssh(ConfigPath), } -#[derive(Debug, Parser, Clone, PartialEq, Eq)] 
-struct ByocCommands { - #[command(subcommand)] - command: ByocCommand, -} - #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum ByocCommand { /// Verify connection to existing cluster Verify(ConfigPath), + /// Show cluster information Info(ConfigPath), } -#[derive(Debug, Parser, Clone, PartialEq, Eq)] -struct JobCommands { - #[command(subcommand)] - command: JobCommand, -} - #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum JobCommand { /// Submit a job to the cluster Submit(Submit), + /// Execute SQL queries Sql(Sql), + /// Check job status Status(ConfigPath), + /// View job logs Logs(ConfigPath), } -#[derive(Debug, Parser, Clone, PartialEq, Eq)] -struct ConfigCommands { - #[command(subcommand)] - command: ConfigCommand, -} - #[derive(Debug, Subcommand, Clone, PartialEq, Eq)] enum ConfigCommand { /// Initialize a new configuration Init(Init), + /// Validate configuration Check(ConfigPath), + /// Export configuration to Ray format Export(ConfigPath), } @@ -199,9 +199,8 @@ struct DaftConfig { #[serde(rename_all = "kebab-case", deny_unknown_fields)] struct DaftSetup { name: StrRef, - #[serde(deserialize_with = "parse_version_req")] - version: VersionReq, - provider: DaftProvider, + #[serde(deserialize_with = "parse_requirement")] + version: Requirement, #[serde(default)] dependencies: Vec, #[serde(flatten)] @@ -212,18 +211,11 @@ struct DaftSetup { #[serde(rename_all = "kebab-case", deny_unknown_fields)] enum ProviderConfig { #[serde(rename = "provisioned")] - Provisioned(AwsConfigWithRun), + Provisioned(AwsConfig), #[serde(rename = "byoc")] Byoc(K8sConfig), } -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "kebab-case", deny_unknown_fields)] -struct AwsConfigWithRun { - #[serde(flatten)] - config: AwsConfig, -} - #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] struct AwsConfig { @@ -323,21 +315,21 @@ fn default_image_id() -> StrRef { "ami-04dd23e62ed049936".into() } -fn 
parse_version_req<'de, D>(deserializer: D) -> Result +fn parse_requirement<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let raw: StrRef = Deserialize::deserialize(deserializer)?; - let version_req = raw - .parse::() + let requirement = raw + .parse::() .map_err(serde::de::Error::custom)?; let current_version = env!("CARGO_PKG_VERSION") - .parse::() + .parse::() .expect("CARGO_PKG_VERSION must exist"); - if version_req.matches(¤t_version) { - Ok(version_req) + if requirement.matches(¤t_version) { + Ok(requirement) } else { - Err(serde::de::Error::custom(format!("You're running daft-launcher version {current_version}, but your configuration file requires version {version_req}"))) + Err(serde::de::Error::custom(format!("You're running daft-launcher version {current_version}, but your configuration file requires version {requirement}"))) } } @@ -357,7 +349,10 @@ impl FromStr for DaftProvider { match s.to_lowercase().as_str() { "provisioned" => Ok(DaftProvider::Provisioned), "byoc" => Ok(DaftProvider::Byoc), - _ => anyhow::bail!("Invalid provider '{}'. Must be either 'provisioned' or 'byoc'", s), + _ => anyhow::bail!( + "Invalid provider '{}'. 
Must be either 'provisioned' or 'byoc'", + s + ), } } } @@ -431,11 +426,9 @@ struct RayResources { cpu: usize, } -async fn read_and_convert( - daft_config_path: &Path, - teardown_behaviour: Option, -) -> anyhow::Result<(DaftConfig, Option)> { - let contents = fs::read_to_string(&daft_config_path) +async fn read_daft_config(daft_config_path: impl AsRef) -> anyhow::Result { + let daft_config_path = daft_config_path.as_ref(); + let contents = fs::read_to_string(daft_config_path) .await .map_err(|error| { if let ErrorKind::NotFound = error.kind() { @@ -447,89 +440,105 @@ async fn read_and_convert( error } })?; - let daft_config = toml::from_str::(&contents)?; - - let ray_config = match &daft_config.setup.provider_config { - ProviderConfig::Byoc(_) => None, - ProviderConfig::Provisioned(aws_config) => { - let key_name = aws_config.config.ssh_private_key - .clone() - .file_stem() - .ok_or_else(|| anyhow::anyhow!(r#"Private key doesn't have a name of the format "name.ext""#))? - .to_str() - .ok_or_else(|| anyhow::anyhow!("The file {:?} does not have a valid UTF-8 name", aws_config.config.ssh_private_key))? 
- .into(); - - let node_config = RayNodeConfig { - key_name, - instance_type: aws_config.config.instance_type.clone(), - image_id: aws_config.config.image_id.clone(), - iam_instance_profile: aws_config.config.iam_instance_profile_name.clone().map(|name| IamInstanceProfile { name }), - }; - - Some(RayConfig { - cluster_name: daft_config.setup.name.clone(), - max_workers: aws_config.config.number_of_workers, - provider: RayProvider { - r#type: "aws".into(), - region: aws_config.config.region.clone(), - cache_stopped_nodes: teardown_behaviour.map(TeardownBehaviour::to_cache_stopped_nodes), - }, - auth: RayAuth { - ssh_user: aws_config.config.ssh_user.clone(), - ssh_private_key: aws_config.config.ssh_private_key.clone(), - }, - available_node_types: vec![ - ( - "ray.head.default".into(), - RayNodeType { - max_workers: aws_config.config.number_of_workers, - node_config: node_config.clone(), - resources: Some(RayResources { cpu: 0 }), - }, - ), - ( - "ray.worker.default".into(), - RayNodeType { - max_workers: aws_config.config.number_of_workers, - node_config, - resources: None, - }, - ), - ] - .into_iter() - .collect(), - setup_commands: { - let mut commands = vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - "uv python install 3.12".into(), - "uv python pin 3.12".into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - "uv pip install boto3 pip ray[default] getdaft py-spy deltalake".into(), - ]; - if !daft_config.setup.dependencies.is_empty() { - let deps = daft_config.setup.dependencies - .iter() - .map(|dep| format!(r#""{dep}""#)) - .collect::>() - .join(" "); - let deps = format!("uv pip install {deps}").into(); - commands.push(deps); - } - commands - }, - }) - } + + Ok(daft_config) +} + +fn convert( + daft_config: &DaftConfig, + teardown_behaviour: Option, +) -> anyhow::Result { + let ProviderConfig::Provisioned(aws_config) = &daft_config.setup.provider_config else { + 
unreachable!("Can only convert to a ray config-file for provisioned configurations; this should be statically determined"); }; - Ok((daft_config, ray_config)) + let key_name = aws_config + .ssh_private_key + .clone() + .file_stem() + .ok_or_else(|| { + anyhow::anyhow!(r#"Private key doesn't have a name of the format "name.ext""#) + })? + .to_str() + .ok_or_else(|| { + anyhow::anyhow!( + "The file {:?} does not have a valid UTF-8 name", + aws_config.ssh_private_key + ) + })? + .into(); + + let node_config = RayNodeConfig { + key_name, + instance_type: aws_config.instance_type.clone(), + image_id: aws_config.image_id.clone(), + iam_instance_profile: aws_config + .iam_instance_profile_name + .clone() + .map(|name| IamInstanceProfile { name }), + }; + + Ok(RayConfig { + cluster_name: daft_config.setup.name.clone(), + max_workers: aws_config.number_of_workers, + provider: RayProvider { + r#type: "aws".into(), + region: aws_config.region.clone(), + cache_stopped_nodes: teardown_behaviour.map(TeardownBehaviour::to_cache_stopped_nodes), + }, + auth: RayAuth { + ssh_user: aws_config.ssh_user.clone(), + ssh_private_key: aws_config.ssh_private_key.clone(), + }, + available_node_types: vec![ + ( + "ray.head.default".into(), + RayNodeType { + max_workers: aws_config.number_of_workers, + node_config: node_config.clone(), + resources: Some(RayResources { cpu: 0 }), + }, + ), + ( + "ray.worker.default".into(), + RayNodeType { + max_workers: aws_config.number_of_workers, + node_config, + resources: None, + }, + ), + ] + .into_iter() + .collect(), + setup_commands: { + let mut commands = vec![ + "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), + "uv python install 3.12".into(), + "uv python pin 3.12".into(), + "uv venv".into(), + "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), + "source ~/.bashrc".into(), + "uv pip install boto3 pip ray[default] getdaft py-spy deltalake".into(), + ]; + if !daft_config.setup.dependencies.is_empty() { + let deps = daft_config + 
.setup + .dependencies + .iter() + .map(|dep| format!(r#""{dep}""#)) + .collect::>() + .join(" "); + let deps = format!("uv pip install {deps}").into(); + commands.push(deps); + } + commands + }, + }) } -async fn write_ray_config(ray_config: RayConfig, dest: impl AsRef) -> anyhow::Result<()> { - let ray_config = serde_yaml::to_string(&ray_config)?; +async fn write_ray_config(ray_config: &RayConfig, dest: impl AsRef) -> anyhow::Result<()> { + let ray_config = serde_yaml::to_string(ray_config)?; fs::write(dest, ray_config).await?; Ok(()) } @@ -551,14 +560,14 @@ impl SpinDirection { #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum TeardownBehaviour { - Stop, + Down, Kill, } impl TeardownBehaviour { fn to_cache_stopped_nodes(self) -> bool { match self { - Self::Stop => true, + Self::Down => true, Self::Kill => false, } } @@ -740,129 +749,7 @@ async fn assert_is_logged_in_with_aws() -> anyhow::Result<()> { } } -async fn get_region(region: Option, config: impl AsRef) -> anyhow::Result { - let config = config.as_ref(); - Ok(if let Some(region) = region { - region - } else if config.exists() { - let (daft_config, _) = read_and_convert(&config, None).await?; - match &daft_config.setup.provider_config { - ProviderConfig::Provisioned(aws_config) => aws_config.config.region.clone(), - ProviderConfig::Byoc(_) => "us-west-2".into(), - } - } else { - "us-west-2".into() - }) -} - -async fn get_head_node_ip(ray_path: impl AsRef) -> anyhow::Result { - let mut ray_command = Command::new("ray") - .arg("get-head-ip") - .arg(ray_path.as_ref()) - .stdout(Stdio::piped()) - .spawn()?; - - let mut tail_command = Command::new("tail") - .args(["-n", "1"]) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .spawn()?; - - let mut writer = tail_command.stdin.take().expect("stdin must exist"); - - tokio::spawn(async move { - let mut reader = BufReader::new(ray_command.stdout.take().expect("stdout must exist")); - let mut buffer = Vec::new(); - reader.read_to_end(&mut buffer).await?; - 
writer.write_all(&buffer).await?; - Ok::<_, anyhow::Error>(()) - }); - let output = tail_command.wait_with_output().await?; - if !output.status.success() { - anyhow::bail!("Failed to fetch ip address of head node"); - }; - let addr = String::from_utf8_lossy(&output.stdout) - .trim() - .parse::()?; - Ok(addr) -} - -async fn ssh(ray_path: impl AsRef, aws_config: &AwsConfig) -> anyhow::Result<()> { - let addr = get_head_node_ip(ray_path).await?; - let exit_status = Command::new("ssh") - .arg("-i") - .arg(aws_config.ssh_private_key.as_ref()) - .arg(format!("{}@{}", aws_config.ssh_user, addr)) - .kill_on_drop(true) - .spawn()? - .wait() - .await?; - - if exit_status.success() { - Ok(()) - } else { - Err(anyhow::anyhow!("Failed to ssh into the ray cluster")) - } -} - -async fn establish_ssh_portforward( - ray_path: impl AsRef, - aws_config: &AwsConfig, - port: Option, -) -> anyhow::Result { - let addr = get_head_node_ip(ray_path).await?; - let port = port.unwrap_or(8265); - let mut child = Command::new("ssh") - .arg("-N") - .arg("-i") - .arg(aws_config.ssh_private_key.as_ref()) - .arg("-L") - .arg(format!("{port}:localhost:8265")) - .arg(format!("{}@{}", aws_config.ssh_user, addr)) - .arg("-v") - .stderr(Stdio::piped()) - .kill_on_drop(true) - .spawn()?; - - // We wait for the ssh port-forwarding process to write a specific string to the - // output. - // - // This is a little hacky (and maybe even incorrect across platforms) since we - // are just parsing the output and observing if a specific string has been - // printed. It may be incorrect across platforms because the SSH standard - // does *not* specify a standard "success-message" to printout if the ssh - // port-forward was successful. - timeout(Duration::from_secs(5), { - let stderr = child.stderr.take().expect("stderr must exist"); - async move { - let mut lines = BufReader::new(stderr).lines(); - loop { - let Some(line) = lines.next_line().await? 
else { - anyhow::bail!("Failed to establish ssh port-forward to {addr}"); - }; - if line.starts_with(format!("Authenticated to {addr}").as_str()) { - break Ok(()); - } - } - } - }) - .await - .map_err(|_| anyhow::anyhow!("Establishing an ssh port-forward to {addr} timed out"))??; - - Ok(child) -} - -struct PortForward { - process: Child, -} - -impl Drop for PortForward { - fn drop(&mut self) { - let _ = self.process.start_kill(); - } -} - -async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::Result { +async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::Result { let namespace = namespace.unwrap_or("default"); let output = Command::new("kubectl") .arg("get") @@ -877,19 +764,28 @@ async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::R .output() .await?; if !output.status.success() { - return Err(anyhow::anyhow!("Failed to get Ray head node services with kubectl in namespace {}", namespace)); + return Err(anyhow::anyhow!( + "Failed to get Ray head node services with kubectl in namespace {}", + namespace + )); } let stdout = String::from_utf8_lossy(&output.stdout); if stdout.trim().is_empty() { - return Err(anyhow::anyhow!("Ray head node service not found in namespace {}", namespace)); + return Err(anyhow::anyhow!( + "Ray head node service not found in namespace {}", + namespace + )); } - + let head_node_service_name = stdout .lines() .next() .ok_or_else(|| anyhow::anyhow!("Failed to get the head node service name"))?; - println!("Found Ray head node service: {} in namespace {}", head_node_service_name, namespace); + println!( + "Found Ray head node service: {} in namespace {}", + head_node_service_name, namespace + ); // Start port-forward with stderr piped so we can monitor the process let mut port_forward = Command::new("kubectl") @@ -899,7 +795,7 @@ async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::R .arg(format!("svc/{}", head_node_service_name)) 
.arg("8265:8265") .stderr(Stdio::piped()) - .stdout(Stdio::piped()) // Capture stdout too + .stdout(Stdio::piped()) // Capture stdout too .kill_on_drop(true) .spawn()?; @@ -909,16 +805,14 @@ async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::R // Check if process is still running match port_forward.try_wait()? { Some(status) => { - return Err(anyhow::anyhow!( + anyhow::bail!( "Port-forward process exited immediately with status: {}", status - )); + ); } None => { println!("Port-forwarding started successfully"); - Ok(PortForward { - process: port_forward, - }) + Ok(port_forward) } } } @@ -955,38 +849,24 @@ async fn submit_k8s( } } -async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> { - match daft_launcher.sub_command { - SubCommand::Config(config_cmd) => { - config_cmd.command.run(daft_launcher.verbosity).await - } - SubCommand::Job(job_cmd) => { - job_cmd.command.run(daft_launcher.verbosity).await - } - SubCommand::Provisioned(provisioned_cmd) => { - provisioned_cmd.command.run(daft_launcher.verbosity).await - } - SubCommand::Byoc(byoc_cmd) => { - byoc_cmd.command.run(daft_launcher.verbosity).await - } - } -} - #[tokio::main] async fn main() -> anyhow::Result<()> { - run(DaftLauncher::parse()).await + DaftLauncher::parse().run().await } -// Helper function to get AWS config -fn get_aws_config(config: &DaftConfig) -> anyhow::Result<&AwsConfig> { - match &config.setup.provider_config { - ProviderConfig::Provisioned(aws_config) => Ok(&aws_config.config), - ProviderConfig::Byoc(_) => anyhow::bail!("Expected provisioned configuration but found Kubernetes configuration"), +impl DaftLauncher { + async fn run(&self) -> anyhow::Result<()> { + match &self.sub_command { + SubCommand::Config(config_cmd) => config_cmd.run().await, + SubCommand::Job(job_cmd) => job_cmd.run().await, + SubCommand::Provisioned(provisioned_cmd) => provisioned_cmd.run().await, + SubCommand::Byoc(byoc_cmd) => byoc_cmd.run().await, + } } } impl ConfigCommand { 
- async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { match self { ConfigCommand::Init(Init { path, provider }) => { #[cfg(not(test))] @@ -1001,14 +881,11 @@ impl ConfigCommand { fs::write(path, contents).await?; } ConfigCommand::Check(ConfigPath { config }) => { - let _ = read_and_convert(&config, None).await?; + let _ = read_daft_config(config).await?; } ConfigCommand::Export(ConfigPath { config }) => { - let (_, ray_config) = read_and_convert(&config, None).await?; - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + let ray_config = convert(&daft_config, None)?; let ray_config_str = serde_yaml::to_string(&ray_config)?; println!("{ray_config_str}"); } @@ -1018,27 +895,28 @@ impl ConfigCommand { } impl JobCommand { - async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { match self { - JobCommand::Submit(Submit { config_path, job_name }) => { - let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - let daft_job = daft_config - .jobs - .get(job_name) - .ok_or_else(|| anyhow::anyhow!("A job with the name {job_name} was not found"))?; + JobCommand::Submit(Submit { + config_path, + job_name, + }) => { + let daft_config = read_daft_config(&config_path.config).await?; + let daft_job = daft_config.jobs.get(job_name).ok_or_else(|| { + anyhow::anyhow!("A job with the name {job_name} was not found") + })?; match &daft_config.setup.provider_config { - ProviderConfig::Provisioned(_) => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + ProviderConfig::Provisioned(aws_config) => { + assert_is_logged_in_with_aws().await?; + + let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - 
write_ray_config(ray_config, &ray_path).await?; - - let aws_config = get_aws_config(&daft_config)?; + write_ray_config(&ray_config, &ray_path).await?; + // Start port forwarding - it will be automatically killed when _port_forward is dropped - let _port_forward = establish_ssh_portforward(ray_path, aws_config, Some(8265)).await?; + let _port_forward = + ssh::ssh_portforward(ray_path, aws_config, None).await?; // Give the port-forward a moment to fully establish tokio::time::sleep(Duration::from_secs(1)).await; @@ -1070,7 +948,7 @@ impl JobCommand { } } JobCommand::Sql(Sql { sql, config_path }) => { - let (daft_config, _) = read_and_convert(&config_path.config, None).await?; + let daft_config = read_daft_config(&config_path.config).await?; match &daft_config.setup.provider_config { ProviderConfig::Provisioned(_) => { anyhow::bail!("'sql' command is only available for BYOC configurations"); @@ -1087,133 +965,108 @@ impl JobCommand { } } } - JobCommand::Status(_) => { - anyhow::bail!("Job status command not yet implemented"); - } - JobCommand::Logs(_) => { - anyhow::bail!("Job logs command not yet implemented"); - } + JobCommand::Status(..) => todo!(), + JobCommand::Logs(..) => todo!(), } Ok(()) } } impl ProvisionedCommand { - async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { match self { ProvisionedCommand::Up(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, None).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + match daft_config.setup.provider_config { + ProviderConfig::Provisioned(..) 
=> { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; + write_ray_config(&ray_config, &ray_path).await?; run_ray_up_or_down_command(SpinDirection::Up, ray_path).await?; } - DaftProvider::Byoc => { - anyhow::bail!("'up' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) => not_available_for_byoc!("up"), } } ProvisionedCommand::Down(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + match daft_config.setup.provider_config { + ProviderConfig::Provisioned(..) => { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, Some(TeardownBehaviour::Down))?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; + write_ray_config(&ray_config, &ray_path).await?; run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; } - DaftProvider::Byoc => { - anyhow::bail!("'down' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) => not_available_for_byoc!("down"), } } ProvisionedCommand::Kill(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Kill)).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + match daft_config.setup.provider_config { + ProviderConfig::Provisioned(..) 
=> { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, Some(TeardownBehaviour::Kill))?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; + write_ray_config(&ray_config, &ray_path).await?; run_ray_up_or_down_command(SpinDirection::Down, ray_path).await?; } - DaftProvider::Byoc => { - anyhow::bail!("'kill' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) => not_available_for_byoc!("kill"), } } - ProvisionedCommand::List(List { config_path, region, head, running }) => { - let (daft_config, _) = read_and_convert(&config_path.config, None).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { + ProvisionedCommand::List(List { + config_path, + region, + head, + running, + }) => { + let daft_config = read_daft_config(&config_path.config).await?; + match &daft_config.setup.provider_config { + ProviderConfig::Provisioned(aws_config) => { assert_is_logged_in_with_aws().await?; - let aws_config = get_aws_config(&daft_config)?; + let region = region.as_ref().unwrap_or_else(|| &aws_config.region); let instances = get_ray_clusters_from_aws(region.clone()).await?; print_instances(&instances, *head, *running); } - DaftProvider::Byoc => { - anyhow::bail!("'list' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) 
=> not_available_for_byoc!("list"), } } - ProvisionedCommand::Connect(Connect { port, config_path }) => { - let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + &ProvisionedCommand::Connect(Connect { + port, + ref config_path, + }) => { + let daft_config = read_daft_config(&config_path.config).await?; + match &daft_config.setup.provider_config { + ProviderConfig::Provisioned(aws_config) => { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let aws_config = get_aws_config(&daft_config)?; - let _ = establish_ssh_portforward(ray_path, aws_config, Some(*port)) + write_ray_config(&ray_config, &ray_path).await?; + let _ = ssh::ssh_portforward(ray_path, aws_config, Some(port)) .await? .wait_with_output() .await?; } - DaftProvider::Byoc => { - anyhow::bail!("'connect' command is only available for provisioned configurations"); - } + ProviderConfig::Byoc(..) 
=> not_available_for_byoc!("connect"), } } ProvisionedCommand::Ssh(ConfigPath { config }) => { - let (daft_config, ray_config) = read_and_convert(&config, None).await?; - match daft_config.setup.provider { - DaftProvider::Provisioned => { - if ray_config.is_none() { - anyhow::bail!("Failed to find Ray config in config file"); - } - let ray_config = ray_config.unwrap(); + let daft_config = read_daft_config(config).await?; + match &daft_config.setup.provider_config { + ProviderConfig::Provisioned(aws_config) => { assert_is_logged_in_with_aws().await?; + let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; - write_ray_config(ray_config, &ray_path).await?; - let aws_config = get_aws_config(&daft_config)?; - ssh(ray_path, aws_config).await?; - } - DaftProvider::Byoc => { - anyhow::bail!("'ssh' command is only available for provisioned configurations"); + write_ray_config(&ray_config, &ray_path).await?; + ssh::ssh(ray_path, aws_config).await?; } + ProviderConfig::Byoc(..) => not_available_for_byoc!("ssh"), } } } @@ -1222,15 +1075,10 @@ impl ProvisionedCommand { } impl ByocCommand { - async fn run(&self, _verbosity: u8) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { match self { - ByocCommand::Verify(ConfigPath { config: _ }) => { - anyhow::bail!("Verify command not yet implemented"); - } - ByocCommand::Info(ConfigPath { config: _ }) => { - anyhow::bail!("Info command not yet implemented"); - } + ByocCommand::Verify(..) => todo!(), + ByocCommand::Info(..) 
=> todo!(), } - Ok(()) } } diff --git a/src/ssh.rs b/src/ssh.rs index e02a674..6e64d46 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -6,7 +6,7 @@ use tokio::{ time::timeout, }; -use crate::DaftConfig; +use crate::AwsConfig; async fn get_head_node_ip(ray_path: impl AsRef) -> anyhow::Result { let mut ray_command = Command::new("ray") @@ -42,18 +42,18 @@ async fn get_head_node_ip(ray_path: impl AsRef) -> anyhow::Result, - daft_config: &DaftConfig, + aws_config: &AwsConfig, portforward: Option, verbose: bool, ) -> anyhow::Result<(Ipv4Addr, Command)> { - let user = daft_config.setup.ssh_user.as_ref(); + let user = aws_config.ssh_user.as_ref(); let addr = get_head_node_ip(ray_path).await?; let mut command = Command::new("ssh"); command .arg("-i") - .arg(daft_config.setup.ssh_private_key.as_ref()) + .arg(aws_config.ssh_private_key.as_ref()) .arg("-o") .arg("StrictHostKeyChecking=no"); @@ -73,24 +73,26 @@ async fn generate_ssh_command( Ok((addr, command)) } -pub async fn ssh(ray_path: impl AsRef, daft_config: &DaftConfig) -> anyhow::Result<()> { - let (_, mut command) = generate_ssh_command(ray_path, daft_config, None, false).await?; +pub async fn ssh(ray_path: impl AsRef, aws_config: &AwsConfig) -> anyhow::Result<()> { + let (addr, mut command) = generate_ssh_command(ray_path, aws_config, None, false).await?; let exit_status = command.spawn()?.wait().await?; if exit_status.success() { Ok(()) } else { - Err(anyhow::anyhow!("Failed to ssh into the ray cluster")) + Err(anyhow::anyhow!( + "Failed to ssh into the ray cluster at address {addr}" + )) } } pub async fn ssh_portforward( ray_path: impl AsRef, - daft_config: &DaftConfig, + aws_config: &AwsConfig, portforward: Option, ) -> anyhow::Result { let (addr, mut command) = generate_ssh_command( ray_path, - daft_config, + aws_config, Some(portforward.unwrap_or(8265)), true, ) diff --git a/src/template_byoc.toml b/src/template_byoc.toml index e70adc0..ee81e6c 100644 --- a/src/template_byoc.toml +++ b/src/template_byoc.toml @@ 
-1,9 +1,8 @@ -# This is a template configuration file for daft-launcher with BYOC provider +# This is a template configuration file for daft-launcher with a BYOC provider + [setup] name = "my-daft-cluster" version = "" -provider = "byoc" -# TODO: support dependencies [setup.byoc] namespace = "default" # Optional, defaults to "default" @@ -12,4 +11,4 @@ namespace = "default" # Optional, defaults to "default" [[job]] name = "example-job" command = "python my_script.py" -working-dir = "~/my_project" \ No newline at end of file +working-dir = "~/my_project" diff --git a/src/template_provisioned.toml b/src/template_provisioned.toml index 4299fbf..f6c1f84 100644 --- a/src/template_provisioned.toml +++ b/src/template_provisioned.toml @@ -1,8 +1,8 @@ -# This is a template configuration file for daft-launcher with provisioned provider +# This is a template configuration file for daft-launcher with a provisioned provider + [setup] name = "my-daft-cluster" version = "" -provider = "provisioned" dependencies = [] # Optional additional Python packages to install # Provisioned (AWS) configuration @@ -19,4 +19,4 @@ iam-instance-profile-name = "YourInstanceProfileName" # Optional [[job]] name = "example-job" command = "python my_script.py" -working-dir = "~/my_project" \ No newline at end of file +working-dir = "~/my_project" diff --git a/src/tests.rs b/src/tests.rs index 960bb21..b4f4c27 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -26,11 +26,17 @@ async fn get_path() -> (TempDir, PathBuf) { /// `template.toml` file before writing it. Thus, the outputted configuration /// file does not *exactly* match the original `template.toml` file. 
#[tokio::test] -async fn test_init() { +#[rstest::rstest] +#[case(DaftProvider::Byoc)] +#[case(DaftProvider::Provisioned)] +async fn test_init(#[case] provider: DaftProvider) { let (_temp_dir, path) = get_path().await; run(DaftLauncher { - sub_command: SubCommand::Init(Init { path: path.clone() }), + sub_command: SubCommand::Init(Init { + path: path.clone(), + provider, + }), verbosity: 0, }) .await From 9f31ab5fa4261983150003a921199597c020262e Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 17:47:35 -0800 Subject: [PATCH 05/15] Edit derive macros --- src/main.rs | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/src/main.rs b/src/main.rs index 78730e2..2688997 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,7 @@ use std::{ use anyhow::bail; use aws_config::{BehaviorVersion, Region}; use aws_sdk_ec2::{types::InstanceStateName, Client}; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use comfy_table::{ modifiers, presets, Attribute, Cell, CellAlignment, Color, ContentArrangement, Table, }; @@ -333,36 +333,19 @@ where } } -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "kebab-case", deny_unknown_fields)] +#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] enum DaftProvider { - #[serde(rename = "provisioned")] Provisioned, - #[serde(rename = "byoc")] Byoc, } -impl FromStr for DaftProvider { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "provisioned" => Ok(DaftProvider::Provisioned), - "byoc" => Ok(DaftProvider::Byoc), - _ => anyhow::bail!( - "Invalid provider '{}'. 
Must be either 'provisioned' or 'byoc'", - s - ), - } - } -} - impl ToString for DaftProvider { fn to_string(&self) -> String { match self { - DaftProvider::Provisioned => "provisioned".to_string(), - DaftProvider::Byoc => "byoc".to_string(), + DaftProvider::Provisioned => "provisioned", + DaftProvider::Byoc => "byoc", } + .to_string() } } From e1dbb859cb1bfef9df9c7fa4c321c99d9f2b565d Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 17:53:20 -0800 Subject: [PATCH 06/15] Move dependencies key-value around --- src/main.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/main.rs b/src/main.rs index 2688997..ed2d160 100644 --- a/src/main.rs +++ b/src/main.rs @@ -201,8 +201,6 @@ struct DaftSetup { name: StrRef, #[serde(deserialize_with = "parse_requirement")] version: Requirement, - #[serde(default)] - dependencies: Vec, #[serde(flatten)] provider_config: ProviderConfig, } @@ -210,9 +208,7 @@ struct DaftSetup { #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] enum ProviderConfig { - #[serde(rename = "provisioned")] Provisioned(AwsConfig), - #[serde(rename = "byoc")] Byoc(K8sConfig), } @@ -229,7 +225,10 @@ struct AwsConfig { instance_type: StrRef, #[serde(default = "default_image_id")] image_id: StrRef, + #[serde(skip_serializing_if = "Option::is_none")] iam_instance_profile_name: Option, + #[serde(default)] + dependencies: Vec, } #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] @@ -504,9 +503,8 @@ fn convert( "source ~/.bashrc".into(), "uv pip install boto3 pip ray[default] getdaft py-spy deltalake".into(), ]; - if !daft_config.setup.dependencies.is_empty() { - let deps = daft_config - .setup + if !aws_config.dependencies.is_empty() { + let deps = aws_config .dependencies .iter() .map(|dep| format!(r#""{dep}""#)) From ea98371d538a0c66a1f34dc64ae879cba888e24b Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 17:57:57 -0800 Subject: [PATCH 
07/15] Remove unnecessary pub visibility modifier --- src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index ed2d160..2b7c4b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -592,7 +592,7 @@ struct AwsInstance { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum NodeType { +enum NodeType { Head, Worker, } From dbec79f8a7d142dc0b953599d2d268ea3be3917f Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 18:24:02 -0800 Subject: [PATCH 08/15] Fix template and add proper requirement version --- src/main.rs | 2 +- src/template_provisioned.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index 2b7c4b0..44a1915 100644 --- a/src/main.rs +++ b/src/main.rs @@ -858,7 +858,7 @@ impl ConfigCommand { DaftProvider::Byoc => include_str!("template_byoc.toml"), DaftProvider::Provisioned => include_str!("template_provisioned.toml"), } - .replace("", env!("CARGO_PKG_VERSION")); + .replace("", concat!("=", env!("CARGO_PKG_VERSION"))); fs::write(path, contents).await?; } ConfigCommand::Check(ConfigPath { config }) => { diff --git a/src/template_provisioned.toml b/src/template_provisioned.toml index f6c1f84..268c3d7 100644 --- a/src/template_provisioned.toml +++ b/src/template_provisioned.toml @@ -3,7 +3,6 @@ [setup] name = "my-daft-cluster" version = "" -dependencies = [] # Optional additional Python packages to install # Provisioned (AWS) configuration [setup.provisioned] @@ -14,6 +13,7 @@ ssh-private-key = "~/.ssh/id_rsa" instance-type = "i3.2xlarge" image-id = "ami-04dd23e62ed049936" iam-instance-profile-name = "YourInstanceProfileName" # Optional +dependencies = [] # Optional additional Python packages to install # Job definitions [[job]] From 071bf017e9d61394e3d34beaa1410542677bbe11 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 18:44:02 -0800 Subject: [PATCH 09/15] Re-implement sql submission for provisioned configs --- src/main.rs | 87 
+++++++++++++++++++++++++++-------------------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/src/main.rs b/src/main.rs index 44a1915..7d95381 100644 --- a/src/main.rs +++ b/src/main.rs @@ -798,27 +798,18 @@ async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::R } } -async fn submit_k8s( - working_dir: &Path, +async fn submit( + working_dir: impl AsRef, command_segments: impl AsRef<[&str]>, - namespace: Option<&str>, ) -> anyhow::Result<()> { - let command_segments = command_segments.as_ref(); - - // Start port forwarding - it will be automatically killed when _port_forward is dropped - let _port_forward = establish_kubernetes_port_forward(namespace).await?; - - // Give the port-forward a moment to fully establish - tokio::time::sleep(Duration::from_secs(1)).await; - // Submit the job let exit_status = Command::new("ray") .env("PYTHONUNBUFFERED", "1") .args(["job", "submit", "--address", "http://localhost:8265"]) .arg("--working-dir") - .arg(working_dir) + .arg(working_dir.as_ref()) .arg("--") - .args(command_segments) + .args(command_segments.as_ref()) .spawn()? 
.wait() .await?; @@ -830,6 +821,22 @@ async fn submit_k8s( } } +async fn submit_k8s( + working_dir: impl AsRef, + command_segments: impl AsRef<[&str]>, + namespace: Option<&str>, +) -> anyhow::Result<()> { + // Start port forwarding - it will be automatically killed when _port_forward is dropped + let _port_forward = establish_kubernetes_port_forward(namespace).await?; + + // Give the port-forward a moment to fully establish + tokio::time::sleep(Duration::from_secs(1)).await; + + submit(working_dir, command_segments).await?; + + Ok(()) +} + #[tokio::main] async fn main() -> anyhow::Result<()> { DaftLauncher::parse().run().await @@ -887,6 +894,9 @@ impl JobCommand { anyhow::anyhow!("A job with the name {job_name} was not found") })?; + let working_dir = daft_job.working_dir.as_ref(); + let command_segments = daft_job.command.as_ref().split(' ').collect::>(); + match &daft_config.setup.provider_config { ProviderConfig::Provisioned(aws_config) => { assert_is_logged_in_with_aws().await?; @@ -895,33 +905,13 @@ impl JobCommand { let (_temp_dir, ray_path) = create_temp_ray_file()?; write_ray_config(&ray_config, &ray_path).await?; - // Start port forwarding - it will be automatically killed when _port_forward is dropped - let _port_forward = - ssh::ssh_portforward(ray_path, aws_config, None).await?; - - // Give the port-forward a moment to fully establish - tokio::time::sleep(Duration::from_secs(1)).await; - - // Submit the job - let exit_status = Command::new("ray") - .env("PYTHONUNBUFFERED", "1") - .args(["job", "submit", "--address", "http://localhost:8265"]) - .arg("--working-dir") - .arg(daft_job.working_dir.as_ref()) - .arg("--") - .args(daft_job.command.as_ref().split(' ').collect::>()) - .spawn()? 
- .wait() - .await?; - - if !exit_status.success() { - anyhow::bail!("Failed to submit job to the ray cluster"); - } + let _child = ssh::ssh_portforward(ray_path, aws_config, None).await?; + submit(working_dir, command_segments).await?; } ProviderConfig::Byoc(k8s_config) => { submit_k8s( - daft_job.working_dir.as_ref(), - daft_job.command.as_ref().split(' ').collect::>(), + working_dir, + command_segments, k8s_config.namespace.as_deref(), ) .await?; @@ -930,16 +920,27 @@ impl JobCommand { } JobCommand::Sql(Sql { sql, config_path }) => { let daft_config = read_daft_config(&config_path.config).await?; + let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; + fs::write(sql_path, include_str!("sql.py")).await?; + + let working_dir = temp_sql_dir.path(); + let command_segments = vec!["python", "sql.py", sql.as_ref()]; + match &daft_config.setup.provider_config { - ProviderConfig::Provisioned(_) => { - anyhow::bail!("'sql' command is only available for BYOC configurations"); + ProviderConfig::Provisioned(aws_config) => { + assert_is_logged_in_with_aws().await?; + + let ray_config = convert(&daft_config, None)?; + let (_temp_dir, ray_path) = create_temp_ray_file()?; + write_ray_config(&ray_config, &ray_path).await?; + + let _child = ssh::ssh_portforward(ray_path, aws_config, None).await?; + submit(working_dir, command_segments).await?; } ProviderConfig::Byoc(k8s_config) => { - let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; - fs::write(sql_path, include_str!("sql.py")).await?; submit_k8s( - temp_sql_dir.path(), - vec!["python", "sql.py", sql.as_ref()], + working_dir, + command_segments, k8s_config.namespace.as_deref(), ) .await?; From d8274d273d57d7e9b4c1a0562bf9a3b209575e6d Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 18:44:28 -0800 Subject: [PATCH 10/15] Remove tests (since they're not being used right now) --- src/tests.rs | 223 --------------------------------------------------- 1 file changed, 223 deletions(-) delete mode 
100644 src/tests.rs diff --git a/src/tests.rs b/src/tests.rs deleted file mode 100644 index b4f4c27..0000000 --- a/src/tests.rs +++ /dev/null @@ -1,223 +0,0 @@ -use tokio::fs; - -use super::*; -use crate::{ConfigCommand, ConfigCommands, ConfigPath, DaftLauncher, Init, SubCommand}; - -fn not_found_okay(result: std::io::Result<()>) -> std::io::Result<()> { - match result { - Ok(()) => Ok(()), - Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), - Err(err) => Err(err), - } -} - -async fn get_path() -> (TempDir, PathBuf) { - let (temp_dir, path) = create_temp_file(".test.toml").unwrap(); - not_found_okay(fs::remove_file(path.as_ref()).await).unwrap(); - not_found_okay(fs::remove_dir_all(path.as_ref()).await).unwrap(); - (temp_dir, PathBuf::from(path.as_ref())) -} - -/// This tests the creation of a daft-launcher configuration file. -/// -/// # Note -/// This does *not* check the contents of the newly created configuration file. -/// The reason is because we perform some minor templatization of the -/// `template.toml` file before writing it. Thus, the outputted configuration -/// file does not *exactly* match the original `template.toml` file. -#[tokio::test] -#[rstest::rstest] -#[case(DaftProvider::Byoc)] -#[case(DaftProvider::Provisioned)] -async fn test_init(#[case] provider: DaftProvider) { - let (_temp_dir, path) = get_path().await; - - run(DaftLauncher { - sub_command: SubCommand::Init(Init { - path: path.clone(), - provider, - }), - verbosity: 0, - }) - .await - .unwrap(); - - assert!(path.exists()); - assert!(path.is_file()); -} - -/// Tests to make sure that `daft check` properly asserts the schema of the -/// newly created daft-launcher configuration file. 
-#[tokio::test] -async fn test_check() { - let (_temp_dir, path) = get_path().await; - - run(DaftLauncher { - sub_command: SubCommand::Init(Init { path: path.clone() }), - verbosity: 0, - }) - .await - .unwrap(); - run(DaftLauncher { - sub_command: SubCommand::Check(ConfigPath { config: path }), - verbosity: 0, - }) - .await - .unwrap(); -} - -/// This tests the core conversion functionality, from a `DaftConfig` to a -/// `RayConfig`. -/// -/// # Note -/// Fields which expect a filesystem path (i.e., "ssh_private_key" and -/// "job.working_dir") are not checked for existence. Therefore, you can really -/// put any value in there and this test will pass. -/// -/// This is because the point of this test is not to check for existence, but -/// rather to test the mapping from `DaftConfig` to `RayConfig`. -#[rstest::rstest] -#[case(simple_config())] -fn test_conversion( - #[case] (daft_config, teardown_behaviour, expected): ( - DaftConfig, - Option, - RayConfig, - ), -) { - let actual = convert(&daft_config, teardown_behaviour).unwrap(); - assert_eq!(actual, expected); -} - -#[rstest::rstest] -#[case("3.9".parse().unwrap(), "2.34".parse().unwrap(), vec![], vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - "uv python install 3.9".into(), - "uv python pin 3.9".into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), -])] -#[case("3.9".parse().unwrap(), "2.34".parse().unwrap(), vec!["requests==0.0.0".into()], vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - "uv python install 3.9".into(), - "uv python pin 3.9".into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), - r#"uv pip install "requests==0.0.0""#.into(), -])] -fn 
test_generate_setup_commands( - #[case] python_version: Versioning, - #[case] ray_version: Versioning, - #[case] dependencies: Vec, - #[case] expected: Vec, -) { - let actual = generate_setup_commands(python_version, ray_version, dependencies.as_slice()); - assert_eq!(actual, expected); -} - -#[rstest::fixture] -pub fn simple_config() -> (DaftConfig, Option, RayConfig) { - let test_name: StrRef = "test".into(); - let ssh_private_key: PathRef = Arc::from(PathBuf::from("testkey.pem")); - let number_of_workers = 4; - let daft_config = DaftConfig { - setup: DaftSetup { - name: test_name.clone(), - requires: "=1.2.3".parse().unwrap(), - python_version: "3.12".parse().unwrap(), - ray_version: "2.34".parse().unwrap(), - region: test_name.clone(), - number_of_workers, - ssh_user: test_name.clone(), - ssh_private_key: ssh_private_key.clone(), - instance_type: test_name.clone(), - image_id: test_name.clone(), - iam_instance_profile_name: Some(test_name.clone()), - dependencies: vec![], - }, - run: vec![], - jobs: HashMap::default(), - }; - let node_config = RayNodeConfig { - key_name: "testkey".into(), - instance_type: test_name.clone(), - image_id: test_name.clone(), - iam_instance_profile: Some(RayIamInstanceProfile { - name: test_name.clone(), - }), - }; - - let ray_config = RayConfig { - cluster_name: test_name.clone(), - max_workers: number_of_workers, - provider: RayProvider { - r#type: "aws".into(), - region: test_name.clone(), - cache_stopped_nodes: None, - }, - auth: RayAuth { - ssh_user: test_name.clone(), - ssh_private_key, - }, - available_node_types: vec![ - ( - "ray.head.default".into(), - RayNodeType { - max_workers: 0, - node_config: node_config.clone(), - resources: Some(RayResources { cpu: 0 }), - }, - ), - ( - "ray.worker.default".into(), - RayNodeType { - max_workers: number_of_workers, - node_config, - resources: None, - }, - ), - ] - .into_iter() - .collect(), - setup_commands: vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - "uv 
python install 3.12".into(), - "uv python pin 3.12".into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), - ], - }; - - (daft_config, None, ray_config) -} - -#[tokio::test] -async fn test_init_and_export() { - crate::run(DaftLauncher { - sub_command: SubCommand::Config(ConfigCommands { - command: ConfigCommand::Init(Init { - path: ".daft.toml".into(), - provider: "provisioned".into(), - }), - }), - verbosity: 0, - }) - .await - .unwrap(); - crate::run(DaftLauncher { - sub_command: SubCommand::Config(ConfigCommands { - command: ConfigCommand::Check(ConfigPath { - config: ".daft.toml".into(), - }), - }), - verbosity: 0, - }) - .await - .unwrap(); -} From 7be3e82602c86a7749622e0a0f08d352ab3c7291 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 18:53:44 -0800 Subject: [PATCH 11/15] Change explicit returns to anyhow::bails --- src/main.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/main.rs b/src/main.rs index 7d95381..7d7c09e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -745,18 +745,15 @@ async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::R .output() .await?; if !output.status.success() { - return Err(anyhow::anyhow!( + anyhow::bail!( "Failed to get Ray head node services with kubectl in namespace {}", namespace - )); + ); } let stdout = String::from_utf8_lossy(&output.stdout); if stdout.trim().is_empty() { - return Err(anyhow::anyhow!( - "Ray head node service not found in namespace {}", - namespace - )); + anyhow::bail!("Ray head node service not found in namespace {}", namespace); } let head_node_service_name = stdout From 40cd032fec95a9b043834a7ce3494683502e491d Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 18:55:31 -0800 Subject: [PATCH 12/15] Add logic to kill process on drop --- src/main.rs | 1 + 1 file 
changed, 1 insertion(+) diff --git a/src/main.rs b/src/main.rs index 7d7c09e..8e85602 100644 --- a/src/main.rs +++ b/src/main.rs @@ -742,6 +742,7 @@ async fn establish_kubernetes_port_forward(namespace: Option<&str>) -> anyhow::R .arg("--no-headers") .arg("-o") .arg("custom-columns=:metadata.name") + .kill_on_drop(true) .output() .await?; if !output.status.success() { From c3a8b42ab27d27f40ff48c134268b54d377de32b Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 22:20:10 -0800 Subject: [PATCH 13/15] Add tests back in --- Cargo.lock | 109 ++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 3 ++ src/main.rs | 4 +- src/tests.rs | 126 +++++++++++++++------------------------------------ 4 files changed, 150 insertions(+), 92 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8d106c9..2d97e92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "anstream" version = "0.6.18" @@ -608,6 +617,7 @@ dependencies = [ "aws-sdk-sts", "clap", "comfy-table", + "rstest", "serde", "serde_yaml", "tempdir", @@ -711,6 +721,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -723,6 +744,12 @@ version = "0.3.31" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -730,9 +757,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", + "futures-macro", "futures-task", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -762,6 +791,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "h2" version = "0.3.26" @@ -1260,6 +1295,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "proc-macro-crate" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.93" @@ -1324,12 +1368,47 @@ dependencies = [ "bitflags", ] +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + [[package]] name = "regex-lite" version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "remove_dir_all" version = "0.5.3" @@ -1354,6 +1433,36 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rstest" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03e905296805ab93e13c1ec3a03f4b6c4f35e9498a3d5fa96dc626d22c03cd89" +dependencies = [ + "futures-timer", + "futures-util", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef0053bbffce09062bee4bcc499b0fbe7a57b879f1efe088d6d8d4c7adcdef9b" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + [[package]] name = "rustc-demangle" version = "0.1.24" diff --git a/Cargo.toml b/Cargo.toml index 1765999..203845e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,3 +33,6 @@ features = ["derive", "rc"] [dependencies.versions] version = "6.3" features = ["serde"] + +[dev-dependencies] +rstest = "0.24.0" diff --git a/src/main.rs b/src/main.rs index f7a7ba5..6504c9d 100644 --- a/src/main.rs +++ b/src/main.rs 
@@ -484,7 +484,7 @@ fn convert( ( "ray.head.default".into(), RayNodeType { - max_workers: aws_config.number_of_workers, + max_workers: 0, node_config: node_config.clone(), resources: Some(RayResources { cpu: 0 }), }, @@ -508,7 +508,7 @@ fn convert( "uv venv".into(), "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), "source ~/.bashrc".into(), - "uv pip install boto3 pip ray[default] getdaft py-spy deltalake".into(), + "uv pip install boto3 pip py-spy deltalake getdaft ray[default]".into(), ]; if !aws_config.dependencies.is_empty() { let deps = aws_config diff --git a/src/tests.rs b/src/tests.rs index 04a062f..d936129 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -27,17 +27,19 @@ async fn get_path() -> (TempDir, PathBuf) { /// `template.toml` file before writing it. Thus, the outputted configuration /// file does not *exactly* match the original `template.toml` file. #[tokio::test] -async fn test_init() { +#[rstest::rstest] +#[case(DaftProvider::Provisioned)] +#[case(DaftProvider::Byoc)] +async fn test_init(#[case] provider: DaftProvider) { let (_temp_dir, path) = get_path().await; - run(DaftLauncher { - sub_command: SubCommand::Config(ConfigCommands { - command: ConfigCommand::Init(Init { - path: path.clone(), - provider: DaftProvider::Provisioned, - }), - }), - }) + DaftLauncher { + sub_command: SubCommand::Config(ConfigCommand::Init(Init { + path: path.clone(), + provider, + })), + } + .run() .await .unwrap(); @@ -48,25 +50,26 @@ async fn test_init() { /// Tests to make sure that `daft check` properly asserts the schema of the /// newly created daft-launcher configuration file. 
#[tokio::test] -async fn test_check() { +#[rstest::rstest] +#[case(DaftProvider::Provisioned)] +#[case(DaftProvider::Byoc)] +async fn test_check(#[case] provider: DaftProvider) { let (_temp_dir, path) = get_path().await; - run(DaftLauncher { - sub_command: SubCommand::Config(ConfigCommands { - command: ConfigCommand::Init(Init { - path: path.clone(), - provider: DaftProvider::Provisioned, - }), - }), - }) + DaftLauncher { + sub_command: SubCommand::Config(ConfigCommand::Init(Init { + path: path.clone(), + provider, + })), + } + .run() .await .unwrap(); - run(DaftLauncher { - sub_command: SubCommand::Config(ConfigCommands { - command: ConfigCommand::Check(ConfigPath { config: path }), - }), - }) + DaftLauncher { + sub_command: SubCommand::Config(ConfigCommand::Check(ConfigPath { config: path })), + } + .run() .await .unwrap(); } @@ -94,36 +97,6 @@ fn test_conversion( assert_eq!(actual, expected); } -#[rstest::rstest] -#[case("3.9".parse().unwrap(), "2.34".parse().unwrap(), vec![], vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - "uv python install 3.9".into(), - "uv python pin 3.9".into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), -])] -#[case("3.9".parse().unwrap(), "2.34".parse().unwrap(), vec!["requests==0.0.0".into()], vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - "uv python install 3.9".into(), - "uv python pin 3.9".into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), - r#"uv pip install "requests==0.0.0""#.into(), -])] -fn test_generate_setup_commands( - #[case] python_version: Versioning, - #[case] ray_version: Versioning, - #[case] dependencies: Vec, - #[case] expected: Vec, -) { - let actual = 
generate_setup_commands(python_version, ray_version, dependencies.as_slice()); - assert_eq!(actual, expected); -} - #[rstest::fixture] pub fn simple_config() -> (DaftConfig, Option, RayConfig) { let test_name: StrRef = "test".into(); @@ -133,18 +106,15 @@ pub fn simple_config() -> (DaftConfig, Option, RayConfig) { setup: DaftSetup { name: test_name.clone(), version: "=1.2.3".parse().unwrap(), - provider: DaftProvider::Provisioned, - dependencies: vec![], - provider_config: ProviderConfig::Provisioned(AwsConfigWithRun { - config: AwsConfig { - region: test_name.clone(), - number_of_workers, - ssh_user: test_name.clone(), - ssh_private_key: ssh_private_key.clone(), - instance_type: test_name.clone(), - image_id: test_name.clone(), - iam_instance_profile_name: Some(test_name.clone()), - }, + provider_config: ProviderConfig::Provisioned(AwsConfig { + region: test_name.clone(), + number_of_workers, + ssh_user: test_name.clone(), + ssh_private_key: ssh_private_key.clone(), + instance_type: test_name.clone(), + image_id: test_name.clone(), + iam_instance_profile_name: Some(test_name.clone()), + dependencies: vec![], }), }, jobs: HashMap::default(), @@ -197,33 +167,9 @@ pub fn simple_config() -> (DaftConfig, Option, RayConfig) { "uv venv".into(), "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), "source ~/.bashrc".into(), - r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), + "uv pip install boto3 pip py-spy deltalake getdaft ray[default]".into(), ], }; (daft_config, None, ray_config) } - -#[tokio::test] -async fn test_init_and_export() { - run(DaftLauncher { - sub_command: SubCommand::Config(ConfigCommands { - command: ConfigCommand::Init(Init { - path: ".daft.toml".into(), - provider: DaftProvider::Provisioned, - }), - }), - }) - .await - .unwrap(); - - run(DaftLauncher { - sub_command: SubCommand::Config(ConfigCommands { - command: ConfigCommand::Check(ConfigPath { - config: ".daft.toml".into(), - }), - }), - }) - .await - 
.unwrap(); -} From 789c3f5832e3e96cde895fd4e43c2e427f9ed8e9 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Wed, 22 Jan 2025 22:20:59 -0800 Subject: [PATCH 14/15] Remove rstest version pinning --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 203845e..e485945 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,4 +35,4 @@ version = "6.3" features = ["serde"] [dev-dependencies] -rstest = "0.24.0" +rstest = "0.24" From 20dba47f55b8bf529e7ed4578007a794a386e32e Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Fri, 24 Jan 2025 12:54:53 -0800 Subject: [PATCH 15/15] Change error message; run formatter --- src/main.rs | 5 +++-- src/tests.rs | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index 6504c9d..3c378dd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ macro_rules! not_available_for_byoc { anyhow::bail!(concat!( "The command `", $command, - "` is only available for provisioned configurations" + "` is not available for byoc configurations" )) }; } @@ -829,7 +829,8 @@ async fn submit_k8s( command_segments: impl AsRef<[&str]>, namespace: &str, ) -> anyhow::Result<()> { - // Start port forwarding - it will be automatically killed when _port_forward is dropped + // Start port forwarding - it will be automatically killed when _port_forward is + // dropped let _port_forward = establish_kubernetes_port_forward(namespace).await?; // Give the port-forward a moment to fully establish diff --git a/src/tests.rs b/src/tests.rs index d936129..8aa3e14 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,4 +1,5 @@ use std::io::ErrorKind; + use tempdir::TempDir; use tokio::fs;