diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..976a1910 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,305 @@ +name: Release + +on: + push: + branches: [br_release] + +permissions: + contents: write + +env: + # Set to 'true' to also publish to crates.io + PUBLISH_CRATES: ${{ vars.PUBLISH_CRATES || 'false' }} + +jobs: + release: + name: Create Release with Native Binaries + runs-on: ubuntu-latest + outputs: + version: ${{ steps.version.outputs.version }} + version_tag: ${{ steps.version.outputs.version_tag }} + release_created: ${{ steps.check_release.outputs.exists != 'true' }} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Get version from latest tag + id: version + run: | + # Get the latest tag on this branch + VERSION_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "") + + if [ -z "$VERSION_TAG" ]; then + echo "Error: No tags found. Run dump-version.sh first." + exit 1 + fi + + # Remove 'v' prefix for version number + VERSION="${VERSION_TAG#v}" + + echo "version_tag=${VERSION_TAG}" >> $GITHUB_OUTPUT + echo "version=${VERSION}" >> $GITHUB_OUTPUT + echo "Version Tag: ${VERSION_TAG}" + echo "Version: ${VERSION}" + + - name: Check if release already exists + id: check_release + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + if gh release view ${{ steps.version.outputs.version_tag }} &>/dev/null; then + echo "Release ${{ steps.version.outputs.version_tag }} already exists" + echo "exists=true" >> $GITHUB_OUTPUT + else + echo "Release ${{ steps.version.outputs.version_tag }} does not exist" + echo "exists=false" >> $GITHUB_OUTPUT + fi + + - name: Download native binaries from gopher-orch + if: steps.check_release.outputs.exists != 'true' + env: + GH_TOKEN: ${{ secrets.GOPHER_ORCH_TOKEN }} + run: | + echo "Downloading native binaries for ${{ steps.version.outputs.version_tag }}..." 
+ + mkdir -p downloads + + # Download all platform binaries from gopher-orch release + gh release download ${{ steps.version.outputs.version_tag }} \ + -R GopherSecurity/gopher-orch \ + -D downloads \ + -p "libgopher-orch-*.tar.gz" \ + -p "libgopher-orch-*.zip" || { + echo "Warning: Could not download some binaries" + echo "Available assets:" + gh release view ${{ steps.version.outputs.version_tag }} -R GopherSecurity/gopher-orch --json assets -q '.assets[].name' + } + + echo "Downloaded files:" + ls -la downloads/ + + - name: Prepare release assets + if: steps.check_release.outputs.exists != 'true' + run: | + mkdir -p release-assets + + # Copy binaries to release-assets + for file in downloads/*; do + if [ -f "$file" ]; then + cp "$file" "release-assets/" + fi + done + + echo "Release assets:" + ls -la release-assets/ + + - name: Generate release notes + if: steps.check_release.outputs.exists != 'true' + run: | + VERSION="${{ steps.version.outputs.version }}" + VERSION_TAG="${{ steps.version.outputs.version_tag }}" + + cat > RELEASE_NOTES.md << EOF + ## gopher-mcp-rust ${VERSION_TAG} + + Rust SDK for gopher-orch orchestration framework. 
+ + ### Installation + + #### From crates.io + + \`\`\`toml + # Add to Cargo.toml + [dependencies] + gopher-orch = "${VERSION}" + \`\`\` + + Or via cargo: + + \`\`\`bash + cargo add gopher-orch@${VERSION} + \`\`\` + + #### From GitHub + + \`\`\`toml + [dependencies] + gopher-orch = { git = "https://github.com/GopherSecurity/gopher-mcp-rust.git", tag = "${VERSION_TAG}" } + \`\`\` + + ### Native Library Installation + + \`\`\`bash + # macOS (Apple Silicon) + gh release download ${VERSION_TAG} -R GopherSecurity/gopher-mcp-rust -p "libgopher-orch-macos-arm64.tar.gz" + tar -xzf libgopher-orch-macos-arm64.tar.gz -C ./native + + # macOS (Intel) + gh release download ${VERSION_TAG} -R GopherSecurity/gopher-mcp-rust -p "libgopher-orch-macos-x64.tar.gz" + tar -xzf libgopher-orch-macos-x64.tar.gz -C ./native + + # Linux (x64) + gh release download ${VERSION_TAG} -R GopherSecurity/gopher-mcp-rust -p "libgopher-orch-linux-x64.tar.gz" + tar -xzf libgopher-orch-linux-x64.tar.gz -C ./native + + # Linux (arm64) + gh release download ${VERSION_TAG} -R GopherSecurity/gopher-mcp-rust -p "libgopher-orch-linux-arm64.tar.gz" + tar -xzf libgopher-orch-linux-arm64.tar.gz -C ./native + \`\`\` + + ### Environment Setup + + \`\`\`bash + # macOS + export DYLD_LIBRARY_PATH="./native/lib:\$DYLD_LIBRARY_PATH" + + # Linux + export LD_LIBRARY_PATH="./native/lib:\$LD_LIBRARY_PATH" + \`\`\` + + ### Build Information + + - **Version:** ${VERSION} + - **gopher-orch:** ${VERSION_TAG} + - **Commit:** ${{ github.sha }} + - **Date:** $(date -u +"%Y-%m-%d %H:%M:%S UTC") + + EOF + + # Extract changelog content + if [ -f "CHANGELOG.md" ]; then + echo "### What's Changed" >> RELEASE_NOTES.md + echo "" >> RELEASE_NOTES.md + + # Get content from the version section + sed -n "/^## \[${VERSION}\]/,/^## \[/p" CHANGELOG.md | \ + grep -v "^## \[" | \ + head -30 >> RELEASE_NOTES.md || true + fi + + # Add comparison link + PREV_TAG=$(git tag --sort=-creatordate | grep -v "^${VERSION_TAG}$" | head -1) + if [ -n 
"$PREV_TAG" ]; then + echo "" >> RELEASE_NOTES.md + echo "**Full Changelog**: https://github.com/${{ github.repository }}/compare/${PREV_TAG}...${VERSION_TAG}" >> RELEASE_NOTES.md + fi + + echo "=== Release Notes ===" + cat RELEASE_NOTES.md + + - name: Create GitHub Release + if: steps.check_release.outputs.exists != 'true' + uses: softprops/action-gh-release@v1 + with: + tag_name: ${{ steps.version.outputs.version_tag }} + name: gopher-mcp-rust ${{ steps.version.outputs.version_tag }} + body_path: RELEASE_NOTES.md + draft: false + prerelease: ${{ contains(steps.version.outputs.version, '-') }} + files: release-assets/* + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Summary + run: | + echo "## Release Summary" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "- **Version:** ${{ steps.version.outputs.version_tag }}" >> $GITHUB_STEP_SUMMARY + echo "- **Release URL:** https://github.com/${{ github.repository }}/releases/tag/${{ steps.version.outputs.version_tag }}" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "### Native Libraries" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + if [ -d "release-assets" ]; then + ls release-assets/ | while read file; do + echo "- \`${file}\`" >> $GITHUB_STEP_SUMMARY + done + fi + + publish-crates: + name: Publish to crates.io + needs: release + runs-on: ubuntu-latest + if: | + needs.release.outputs.release_created == 'true' && + (vars.PUBLISH_CRATES == 'true' || github.event.head_commit.message contains '[publish]') + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-cargo-publish-${{ hashFiles('**/Cargo.lock') }} + + - name: Verify package + run: | + echo "Verifying package before publish..." 
+ cargo package --list + cargo publish --dry-run + + - name: Publish to crates.io + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + run: | + echo "Publishing version ${{ needs.release.outputs.version }} to crates.io..." + cargo publish + + - name: Summary + run: | + echo "## crates.io Publish Summary" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "- **Version:** ${{ needs.release.outputs.version }}" >> $GITHUB_STEP_SUMMARY + echo "- **crates.io:** https://crates.io/crates/gopher-mcp-rust/${{ needs.release.outputs.version }}" >> $GITHUB_STEP_SUMMARY + echo "- **docs.rs:** https://docs.rs/gopher-mcp-rust/${{ needs.release.outputs.version }}" >> $GITHUB_STEP_SUMMARY + + test: + name: Run Tests + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + + - name: Check formatting + run: cargo fmt --check + + - name: Run clippy + run: cargo clippy -- -D warnings || echo "Clippy warnings found" + + - name: Build (without native library) + run: cargo build || echo "Build requires native library" + + notify: + name: Notify on Failure + needs: [release, test] + runs-on: ubuntu-latest + if: failure() + steps: + - name: Report failure + run: | + echo "Release workflow failed!" + echo "Check the logs for details." 
diff --git a/.gitignore b/.gitignore index 8a4dfe5b..709c0fc3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # Rust -/target/ +**/target/ **/*.rs.bk Cargo.lock diff --git a/.gitmodules b/.gitmodules index d5cd4211..18a3014b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "third_party/gopher-orch"] path = third_party/gopher-orch url = https://github.com/GopherSecurity/gopher-orch.git + branch = br_release diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..68e2da72 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,26 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added +- Initial release of gopher-mcp-rust SDK +- Rust bindings for gopher-orch native library via FFI +- Runtime library loading using `libloading` crate +- OAuth 2.0 authentication support (feature-gated with `auth` feature) +- MCP (Model Context Protocol) client implementation +- GopherAgent for AI agent orchestration +- ConfigBuilder for client configuration +- Auth example server with Axum web framework + +### Features +- `default` - Core functionality without auth +- `auth` - OAuth 2.0 token validation via native library + +--- + +[Unreleased]: https://github.com/GopherSecurity/gopher-mcp-rust/compare/HEAD diff --git a/Cargo.toml b/Cargo.toml index 2352cddf..fe5901e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,16 @@ [package] -name = "gopher-orch" -version = "0.1.0" +name = "gopher-mcp-rust" +version = "0.1.2-9" edition = "2021" authors = ["GopherSecurity"] description = "Rust SDK for Gopher Orch - AI Agent orchestration framework" -license = "MIT" +license = "Apache-2.0" repository = "https://github.com/GopherSecurity/gopher-mcp-rust" keywords = ["ai", "agent", "mcp", "llm", "orchestration"] categories = 
["api-bindings", "development-tools"] [lib] -name = "gopher_orch" +name = "gopher_mcp_rust" path = "src/lib.rs" [[example]] @@ -28,3 +28,4 @@ once_cell = "=1.17.0" [features] default = [] +auth = [] diff --git a/build.sh b/build.sh index c36aa4a5..34ddbf71 100755 --- a/build.sh +++ b/build.sh @@ -183,7 +183,7 @@ echo -e "${YELLOW} Compiling Rust SDK...${NC}" LIBRARY_PATH="${NATIVE_LIB_DIR}" \ LD_LIBRARY_PATH="${NATIVE_LIB_DIR}" \ DYLD_LIBRARY_PATH="${NATIVE_LIB_DIR}" \ -cargo build --release +cargo build --release --features auth echo -e "${GREEN}✓ Rust SDK built successfully${NC}" echo "" @@ -193,7 +193,7 @@ echo -e "${YELLOW}Step 5: Running tests...${NC}" LIBRARY_PATH="${NATIVE_LIB_DIR}" \ LD_LIBRARY_PATH="${NATIVE_LIB_DIR}" \ DYLD_LIBRARY_PATH="${NATIVE_LIB_DIR}" \ -cargo test && echo -e "${GREEN}✓ Tests passed${NC}" || echo -e "${YELLOW}⚠ Some tests may have failed (native library required)${NC}" +cargo test --features auth && echo -e "${GREEN}✓ Tests passed${NC}" || echo -e "${YELLOW}⚠ Some tests may have failed (native library required)${NC}" echo "" echo -e "${GREEN}======================================${NC}" @@ -204,7 +204,7 @@ echo -e "Native libraries: ${YELLOW}${NATIVE_LIB_DIR}${NC}" echo -e "Native headers: ${YELLOW}${NATIVE_INCLUDE_DIR}${NC}" echo "" echo -e "To run tests manually:" -echo -e " ${YELLOW}DYLD_LIBRARY_PATH=\$(pwd)/native/lib cargo test${NC}" +echo -e " ${YELLOW}DYLD_LIBRARY_PATH=\$(pwd)/native/lib cargo test --features auth${NC}" echo "" echo -e "To build:" -echo -e " ${YELLOW}cargo build --release${NC}" +echo -e " ${YELLOW}cargo build --release --features auth${NC}" diff --git a/dump-version.sh b/dump-version.sh new file mode 100755 index 00000000..15eb7f07 --- /dev/null +++ b/dump-version.sh @@ -0,0 +1,485 @@ +#!/bin/bash +# +# dump-version.sh - Prepare a new release version for gopher-mcp-rust +# +# Usage: +# ./dump-version.sh [OPTIONS] [VERSION] +# +# Options: +# --skip-crates Skip publishing to crates.io (default: publish) +# 
--dry-run Show what would be done without making changes +# --help Show this help message +# +# Arguments: +# VERSION - Optional. Format: X.Y.Z or X.Y.Z.E +# If not provided, uses latest gopher-orch release version (X.Y.Z) +# If provided as X.Y.Z.E, X.Y.Z must match gopher-orch version +# +# This script will: +# 1. Fetch latest version from gopher-orch releases +# 2. Validate and determine the target version +# 3. Update Cargo.toml version +# 4. Update CHANGELOG.md ([Unreleased] -> [X.Y.Z] - date) +# 5. Create git tag vX.Y.Z +# 6. Commit the changes +# 7. Publish to crates.io (unless --skip-crates is specified) +# +# After running this script: +# 1. Review the changes: git show HEAD +# 2. Push to release: git push origin br_release vX.Y.Z +# +# Environment variables: +# CARGO_REGISTRY_TOKEN - crates.io API token (required for publishing) +# + +set -e + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +CYAN='\033[0;36m' +NC='\033[0m' + +# Get script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Files +CHANGELOG_FILE="CHANGELOG.md" +CARGO_TOML="Cargo.toml" + +# Options +PUBLISH_CRATES=true +DRY_RUN=false +INPUT_VERSION="" + +# Parse options +while [[ $# -gt 0 ]]; do + case $1 in + --skip-crates|--no-crates) + PUBLISH_CRATES=false + shift + ;; + --dry-run) + DRY_RUN=true + shift + ;; + --help|-h) + echo "Usage: $0 [OPTIONS] [VERSION]" + echo "" + echo "Options:" + echo " --skip-crates Skip publishing to crates.io (default: publish)" + echo " --dry-run Show what would be done without making changes" + echo " --help Show this help message" + echo "" + echo "Arguments:" + echo " VERSION Version to release (default: latest gopher-orch version)" + echo " Format: X.Y.Z or X.Y.Z.E" + echo "" + echo "Examples:" + echo " $0 # Release to GitHub and crates.io" + echo " $0 0.1.2 # Release specific version" + echo " $0 --skip-crates # Release to GitHub only" + echo " $0 --dry-run # Preview changes without 
executing" + echo "" + echo "Environment variables:" + echo " CARGO_REGISTRY_TOKEN crates.io API token (required for publishing)" + echo "" + exit 0 + ;; + -*) + echo -e "${RED}Unknown option: $1${NC}" + echo "Use --help for usage information" + exit 1 + ;; + *) + INPUT_VERSION="$1" + shift + ;; + esac +done + +echo -e "${CYAN}========================================${NC}" +echo -e "${CYAN} gopher-mcp-rust Release Version Dump${NC}" +echo -e "${CYAN}========================================${NC}" +echo "" + +if [ "$DRY_RUN" = true ]; then + echo -e "${YELLOW}DRY RUN MODE - No changes will be made${NC}" + echo "" +fi + +if [ "$PUBLISH_CRATES" = true ]; then + echo -e "${CYAN}Publishing to: GitHub + crates.io${NC}" +else + echo -e "${YELLOW}Publishing to: GitHub only (--skip-crates)${NC}" +fi +echo "" + +# ----------------------------------------------------------------------------- +# Step 1: Fetch latest gopher-orch version from GitHub releases +# ----------------------------------------------------------------------------- +echo -e "${YELLOW}Step 1: Fetching latest gopher-orch version...${NC}" + +# Check if gh CLI is available +if ! command -v gh &> /dev/null; then + echo -e "${RED}Error: GitHub CLI (gh) is not installed${NC}" + echo "Install it with: brew install gh" + echo "Then authenticate: gh auth login" + exit 1 +fi + +# Fetch latest release from gopher-orch using gh CLI (handles private repo auth) +GOPHER_ORCH_TAG=$(gh release view --repo GopherSecurity/gopher-orch --json tagName -q '.tagName' 2>/dev/null) + +if [ -z "$GOPHER_ORCH_TAG" ]; then + echo -e "${RED}Error: Could not fetch latest gopher-orch release${NC}" + echo "Make sure you have access to GopherSecurity/gopher-orch repository." + echo "Run 'gh auth login' to authenticate if needed." 
+ exit 1 +fi + +# Remove 'v' prefix if present (e.g., v0.1.1 -> 0.1.1) +GOPHER_ORCH_VERSION="${GOPHER_ORCH_TAG#v}" + +if [ -z "$GOPHER_ORCH_VERSION" ]; then + echo -e "${RED}Error: Could not parse gopher-orch version from release${NC}" + exit 1 +fi + +echo -e " Latest gopher-orch version: ${GREEN}$GOPHER_ORCH_VERSION${NC}" + +# Validate gopher-orch version format (X.Y.Z) +if ! echo "$GOPHER_ORCH_VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+$'; then + echo -e "${RED}Error: gopher-orch version '$GOPHER_ORCH_VERSION' is not in X.Y.Z format${NC}" + exit 1 +fi + +# ----------------------------------------------------------------------------- +# Step 2: Determine target version +# ----------------------------------------------------------------------------- +echo "" +echo -e "${YELLOW}Step 2: Determining target version...${NC}" + +if [ -z "$INPUT_VERSION" ]; then + # No argument provided, use gopher-orch version directly + TARGET_VERSION="$GOPHER_ORCH_VERSION" + echo -e " No version argument provided" + echo -e " Using gopher-orch version: ${GREEN}$TARGET_VERSION${NC}" +else + # Version argument provided, validate it + # Format should be X.Y.Z or X.Y.Z.E + if echo "$INPUT_VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+$'; then + # X.Y.Z format - must match gopher-orch exactly + if [ "$INPUT_VERSION" != "$GOPHER_ORCH_VERSION" ]; then + echo -e "${RED}Error: Version $INPUT_VERSION does not match gopher-orch version $GOPHER_ORCH_VERSION${NC}" + exit 1 + fi + TARGET_VERSION="$INPUT_VERSION" + elif echo "$INPUT_VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'; then + # X.Y.Z.E format - first 3 parts must match gopher-orch + INPUT_BASE=$(echo "$INPUT_VERSION" | sed -E 's/^([0-9]+\.[0-9]+\.[0-9]+)\.[0-9]+$/\1/') + if [ "$INPUT_BASE" != "$GOPHER_ORCH_VERSION" ]; then + echo -e "${RED}Error: Version base $INPUT_BASE does not match gopher-orch version $GOPHER_ORCH_VERSION${NC}" + echo "Extended version X.Y.Z.E must have X.Y.Z matching gopher-orch." 
+ exit 1 + fi + TARGET_VERSION="$INPUT_VERSION" + else + echo -e "${RED}Error: Invalid version format '$INPUT_VERSION'${NC}" + echo "Expected format: X.Y.Z or X.Y.Z.E" + exit 1 + fi + echo -e " Using provided version: ${GREEN}$TARGET_VERSION${NC}" +fi + +TAG_VERSION="v$TARGET_VERSION" + +# ----------------------------------------------------------------------------- +# Step 3: Check if tag already exists +# ----------------------------------------------------------------------------- +echo "" +echo -e "${YELLOW}Step 3: Checking existing tags...${NC}" + +if git tag -l | grep -q "^$TAG_VERSION$"; then + echo -e "${RED}Error: Tag $TAG_VERSION already exists${NC}" + echo "If you want to re-release, delete the tag first:" + echo " git tag -d $TAG_VERSION" + echo " git push origin :refs/tags/$TAG_VERSION" + exit 1 +fi + +echo -e " Tag ${GREEN}$TAG_VERSION${NC} is available" + +# ----------------------------------------------------------------------------- +# Step 4: Update Cargo.toml version +# ----------------------------------------------------------------------------- +echo "" +echo -e "${YELLOW}Step 4: Updating Cargo.toml...${NC}" + +if [ ! 
-f "$CARGO_TOML" ]; then + echo -e "${RED}Error: Cargo.toml not found${NC}" + exit 1 +fi + +# Get current version +CURRENT_VERSION=$(grep -E '^version = "[0-9]+\.[0-9]+\.[0-9]+"' "$CARGO_TOML" | head -1 | sed -E 's/version = "([^"]+)"/\1/') +echo -e " Current version: ${YELLOW}$CURRENT_VERSION${NC}" + +if [ "$DRY_RUN" = false ]; then + # Update version in Cargo.toml + sed -i.bak -E "s/^version = \"[0-9]+\.[0-9]+\.[0-9]+.*\"/version = \"$TARGET_VERSION\"/" "$CARGO_TOML" + rm -f "${CARGO_TOML}.bak" +fi + +echo -e " Updated to: ${GREEN}$TARGET_VERSION${NC}" + +# ----------------------------------------------------------------------------- +# Step 5: Check [Unreleased] section has content +# ----------------------------------------------------------------------------- +echo "" +echo -e "${YELLOW}Step 5: Checking [Unreleased] section...${NC}" + +if [ ! -f "$CHANGELOG_FILE" ]; then + echo -e "${YELLOW}Warning: $CHANGELOG_FILE not found, creating one...${NC}" + if [ "$DRY_RUN" = false ]; then + cat > "$CHANGELOG_FILE" << EOF +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+ +## [Unreleased] + +### Added +- Initial release of gopher-mcp-rust SDK +- Rust bindings for gopher-orch native library +- OAuth 2.0 authentication support (feature-gated) +- MCP (Model Context Protocol) client implementation +- Runtime library loading via libloading + +--- + +[Unreleased]: https://github.com/GopherSecurity/gopher-mcp-rust/compare/HEAD +EOF + fi +fi + +# Extract content between [Unreleased] and next ## section +UNRELEASED_CONTENT=$(sed -n '/^## \[Unreleased\]/,/^## \[/p' "$CHANGELOG_FILE" | \ + grep -v "^## \[" | grep -v "^$" | head -20) + +if [ -z "$UNRELEASED_CONTENT" ]; then + echo -e "${YELLOW}Warning: [Unreleased] section in CHANGELOG.md appears empty${NC}" + echo "You may want to add release notes before continuing." + if [ "$DRY_RUN" = false ]; then + read -p "Continue anyway? (y/N) " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + exit 1 + fi + fi +else + echo -e " ${GREEN}[Unreleased] section has content${NC}" + echo " Preview:" + echo "$UNRELEASED_CONTENT" | head -5 | sed 's/^/ /' +fi + +# ----------------------------------------------------------------------------- +# Step 6: Update CHANGELOG.md +# ----------------------------------------------------------------------------- +echo "" +echo -e "${YELLOW}Step 6: Updating CHANGELOG.md...${NC}" + +TODAY=$(date +%Y-%m-%d) +REPO_URL="https://github.com/GopherSecurity/gopher-mcp-rust" + +if [ "$DRY_RUN" = false ]; then + # Create backup + cp "$CHANGELOG_FILE" "${CHANGELOG_FILE}.bak" + + # Find the line number of [Unreleased] header + UNRELEASED_LINE=$(grep -n "^## \[Unreleased\]" "$CHANGELOG_FILE" | head -1 | cut -d: -f1) + + if [ -z "$UNRELEASED_LINE" ]; then + echo -e "${RED}Error: Could not find [Unreleased] section in CHANGELOG.md${NC}" + rm -f "${CHANGELOG_FILE}.bak" + exit 1 + fi + + # Find the previous version for link generation + PREV_VERSION=$(grep -E "^## \[[0-9]+\.[0-9]+\.[0-9]+" "$CHANGELOG_FILE" | head -1 | sed -E 's/^## \[([^]]+)\].*/\1/') + + # Check if there's a links 
section at the bottom (starts with --- or [Unreleased]:) + HAS_LINKS_SECTION=$(grep -c "^\[Unreleased\]:" "$CHANGELOG_FILE" || true) + + # Find where links section starts (look for --- separator or [Unreleased]: link) + if [ "$HAS_LINKS_SECTION" -gt 0 ]; then + # Find the --- line before [Unreleased]: link, or the [Unreleased]: line itself + LINKS_LINE=$(grep -n "^\[Unreleased\]:" "$CHANGELOG_FILE" | head -1 | cut -d: -f1) + # Check if there's a --- separator before it + SEPARATOR_LINE=$(grep -n "^---$" "$CHANGELOG_FILE" | tail -1 | cut -d: -f1) + if [ -n "$SEPARATOR_LINE" ] && [ "$SEPARATOR_LINE" -lt "$LINKS_LINE" ]; then + LINKS_LINE=$SEPARATOR_LINE + fi + else + LINKS_LINE="" + fi + + # Build new CHANGELOG content + { + # 1. Header section (everything before [Unreleased]) + head -n $((UNRELEASED_LINE - 1)) "$CHANGELOG_FILE" + + # 2. New [Unreleased] section (empty) + echo "## [Unreleased]" + echo "" + + # 3. New version section with today's date + echo "## [$TARGET_VERSION] - $TODAY" + + # 4. Content after old [Unreleased] header until links section or EOF + if [ -n "$LINKS_LINE" ]; then + # Get content between [Unreleased] header and links section + tail -n +$((UNRELEASED_LINE + 1)) "$CHANGELOG_FILE" | head -n $((LINKS_LINE - UNRELEASED_LINE - 1)) + else + # No links section, get everything after [Unreleased] header + tail -n +$((UNRELEASED_LINE + 1)) "$CHANGELOG_FILE" + fi + + # 5. 
Add/Update links section + echo "" + echo "---" + echo "" + # [Unreleased] link pointing to compare from new version to HEAD + echo "[Unreleased]: ${REPO_URL}/compare/v${TARGET_VERSION}...HEAD" + # Add new version link + if [ -n "$PREV_VERSION" ]; then + echo "[$TARGET_VERSION]: ${REPO_URL}/compare/v${PREV_VERSION}...v${TARGET_VERSION}" + else + echo "[$TARGET_VERSION]: ${REPO_URL}/releases/tag/v${TARGET_VERSION}" + fi + # Keep existing version links (skip old [Unreleased] link and current version) + if [ "$HAS_LINKS_SECTION" -gt 0 ]; then + grep -E "^\[[0-9]+\.[0-9]+\.[0-9]+" "$CHANGELOG_FILE" | grep -v "^\[$TARGET_VERSION\]" || true + fi + } > "${CHANGELOG_FILE}.new" + + mv "${CHANGELOG_FILE}.new" "$CHANGELOG_FILE" + rm -f "${CHANGELOG_FILE}.bak" +fi + +echo -e " ${GREEN}CHANGELOG.md updated${NC}" +echo -e " [Unreleased] -> [$TARGET_VERSION] - $TODAY" + +# ----------------------------------------------------------------------------- +# Step 7: Commit changes and create tag +# ----------------------------------------------------------------------------- +echo "" +echo -e "${YELLOW}Step 7: Committing changes and creating tag...${NC}" + +if [ "$DRY_RUN" = false ]; then + # Show what changed + echo "" + echo -e "${CYAN}Changes to be committed:${NC}" + git diff --stat "$CARGO_TOML" "$CHANGELOG_FILE" + + echo "" + echo -e "${CYAN}Committing...${NC}" + + git add "$CARGO_TOML" "$CHANGELOG_FILE" + git commit -m "Release version $TARGET_VERSION + +Prepare release v$TARGET_VERSION: +- Update Cargo.toml: version = \"$TARGET_VERSION\" +- Update CHANGELOG.md: [Unreleased] -> [$TARGET_VERSION] - $TODAY + +gopher-orch version: $GOPHER_ORCH_VERSION + +Changes in this release: +$(echo "$UNRELEASED_CONTENT" | head -10) +" + + # Create annotated tag + echo "" + echo -e "${CYAN}Creating tag $TAG_VERSION...${NC}" + git tag -a "$TAG_VERSION" -m "Release $TARGET_VERSION + +gopher-orch version: $GOPHER_ORCH_VERSION + +Changes: +$(echo "$UNRELEASED_CONTENT" | head -15) +" +else + echo -e 
" ${YELLOW}[DRY RUN] Would commit Cargo.toml and CHANGELOG.md${NC}" + echo -e " ${YELLOW}[DRY RUN] Would create tag $TAG_VERSION${NC}" +fi + +# ----------------------------------------------------------------------------- +# Step 8: Publish to crates.io +# ----------------------------------------------------------------------------- +echo "" +if [ "$PUBLISH_CRATES" = true ]; then + echo -e "${YELLOW}Step 8: Publishing to crates.io...${NC}" + + if [ "$DRY_RUN" = true ]; then + echo -e " ${YELLOW}[DRY RUN] Would run: cargo publish${NC}" + echo -e " ${CYAN}Verifying package...${NC}" + cargo publish --dry-run 2>&1 | head -20 || true + else + echo -e " ${CYAN}Running cargo publish...${NC}" + + # Publish + if cargo publish; then + echo -e " ${GREEN}Successfully published to crates.io${NC}" + else + echo -e "${RED}Error: Failed to publish to crates.io${NC}" + echo "The git commit and tag were created. You can manually publish later with:" + echo " cargo publish" + exit 1 + fi + fi +else + echo -e "${YELLOW}Step 8: Skipping crates.io publish (--skip-crates)${NC}" +fi + +echo "" +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN} Release preparation complete!${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo -e "Version: ${CYAN}$TARGET_VERSION${NC}" +echo -e "Tag: ${CYAN}$TAG_VERSION${NC}" +echo -e "gopher-orch: ${CYAN}$GOPHER_ORCH_VERSION${NC}" +if [ "$PUBLISH_CRATES" = true ]; then + echo -e "crates.io: ${GREEN}Published${NC}" +else + echo -e "crates.io: ${YELLOW}Skipped${NC}" +fi +echo "" +echo -e "${YELLOW}Next steps:${NC}" +echo " 1. Review the commit: git show HEAD" +echo " 2. 
Push to release: git push origin br_release $TAG_VERSION" +echo "" +echo -e "${CYAN}After pushing:${NC}" +echo " - CI workflow will create GitHub Release" +echo " - Native libraries will be attached to release" +if [ "$PUBLISH_CRATES" = false ]; then + echo " - Publish to crates.io manually: cargo publish" +fi +echo "" +echo -e "${CYAN}Users can install with:${NC}" +if [ "$PUBLISH_CRATES" = true ]; then + echo "" + echo " # From crates.io" + echo " [dependencies]" + echo " gopher-orch = \"$TARGET_VERSION\"" +fi +echo "" +echo " # From GitHub" +echo " [dependencies]" +echo " gopher-orch = { git = \"https://github.com/GopherSecurity/gopher-mcp-rust.git\", tag = \"$TAG_VERSION\" }" +echo "" diff --git a/examples/auth/.gitignore b/examples/auth/.gitignore new file mode 100644 index 00000000..da2a8cd8 --- /dev/null +++ b/examples/auth/.gitignore @@ -0,0 +1,17 @@ +# Build output +target/ + +# Downloaded native libraries +native/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# OS +.DS_Store + +# Cargo lock is committed for reproducible builds +# Cargo.lock diff --git a/examples/auth/Cargo.toml b/examples/auth/Cargo.toml new file mode 100644 index 00000000..eb144d10 --- /dev/null +++ b/examples/auth/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "auth-mcp-server" +version = "0.1.0" +edition = "2021" +description = "OAuth-protected MCP server example using gopher-mcp-rust SDK" +license = "MIT" + +[dependencies] +# Web framework +axum = { version = "0.7", features = ["macros"] } +tower = { version = "0.4", features = ["util"] } +tower-http = { version = "0.5", features = ["cors", "trace"] } + +# Async runtime +tokio = { version = "1", features = ["full"] } + +# Serialization +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# HTTP types +http = "1" + +# Gopher MCP Rust SDK with auth FFI +# Use path dependency for local development +gopher-mcp-rust = { path = "../..", features = ["auth"] } + +# Logging +tracing = "0.1" +tracing-subscriber = { version = 
"0.3", features = ["env-filter"] } + +# Utilities +once_cell = "1" +thiserror = "1" +chrono = { version = "0.4", features = ["serde"] } +url = "2" + +[dev-dependencies] +http-body-util = "0.1" diff --git a/examples/auth/README.md b/examples/auth/README.md new file mode 100644 index 00000000..57064fba --- /dev/null +++ b/examples/auth/README.md @@ -0,0 +1,310 @@ +# Gopher Auth MCP Server - Rust Example + +This example demonstrates an MCP (Model Context Protocol) server with OAuth 2.0 authentication using the gopher-mcp-rust SDK. + +## Overview + +The auth example server provides: +- OAuth 2.0 / OIDC discovery endpoints (RFC 8414, RFC 9728) +- JWT token validation via native library +- Scope-based authorization for MCP tools +- Example weather tools with different scope requirements + +## Prerequisites + +- Rust 1.70 or later +- GitHub CLI (`gh`) for downloading native libraries + +## Installation + +### 1. Clone or Copy This Example + +```bash +# Option A: Clone the repository +git clone https://github.com/GopherSecurity/gopher-mcp-rust.git +cd gopher-mcp-rust/examples/auth + +# Option B: Copy the example files to your project +# Copy the examples/auth directory contents +``` + +### 2. Install the Rust SDK + +The SDK is specified in `Cargo.toml` as a git dependency: + +```toml +[dependencies] +gopher-orch = { git = "https://github.com/GopherSecurity/gopher-mcp-rust.git", features = ["auth"] } +``` + +### 3. Download Native Libraries + +The SDK requires native libraries for OAuth token validation. 
The `run_example.sh` script downloads these automatically, or you can install them manually: + +```bash +# Using the run script (downloads automatically) +./run_example.sh --no-auth + +# Or download manually using the install script +curl -sSL https://raw.githubusercontent.com/GopherSecurity/gopher-mcp-rust/main/install-native.sh | bash -s -- latest ./native +``` + +## Quick Start + +### Development Mode (No Auth) + +```bash +# Run with auth disabled (all requests bypass authentication) +./run_example.sh --no-auth + +# Or build and run manually +cargo build --release +./target/release/auth-mcp-server +``` + +### With Full OAuth Support + +```bash +# Run with OAuth authentication enabled +./run_example.sh + +# Or build manually with environment set +export DYLD_LIBRARY_PATH="./native/lib:$DYLD_LIBRARY_PATH" +cargo build --release +./target/release/auth-mcp-server server.config +``` + +### Using Environment Variables + +```bash +# Use a specific SDK version +SDK_VERSION=v0.1.3 ./run_example.sh + +# Use custom native library location +NATIVE_LIB_DIR=/usr/local/lib ./run_example.sh --skip-download +``` + +## Configuration + +Create a `server.config` file with INI-style key-value pairs: + +```ini +# Server settings +host=0.0.0.0 +port=3001 +server_url=http://localhost:3001 + +# OAuth/IDP settings +client_id=my-client +client_secret=my-secret +auth_server_url=https://keycloak.example.com/realms/mcp + +# Scopes +allowed_scopes=openid profile email mcp:read mcp:admin + +# Cache settings +jwks_cache_duration=3600 +jwks_auto_refresh=true +request_timeout=30 + +# Auth bypass mode (for development) +auth_disabled=true +``` + +### Configuration Options + +| Option | Description | Default | +|--------|-------------|---------| +| `host` | Bind address | `0.0.0.0` | +| `port` | Port number | `3001` | +| `server_url` | Public URL of this server | `http://localhost:3001` | +| `auth_server_url` | Keycloak/IDP base URL | - | +| `client_id` | OAuth client ID | - | +| `client_secret` | 
OAuth client secret | - | +| `allowed_scopes` | Space-separated allowed scopes | - | +| `jwks_cache_duration` | JWKS cache TTL in seconds | `3600` | +| `jwks_auto_refresh` | Auto-refresh JWKS | `true` | +| `request_timeout` | HTTP request timeout in seconds | `30` | +| `auth_disabled` | Disable authentication | `false` | + +## Available Endpoints + +### Health Check + +```bash +curl http://localhost:3001/health +``` + +### OAuth Discovery + +```bash +# Protected Resource Metadata (RFC 9728) +curl http://localhost:3001/.well-known/oauth-protected-resource + +# Authorization Server Metadata (RFC 8414) +curl http://localhost:3001/.well-known/oauth-authorization-server + +# OpenID Connect Discovery +curl http://localhost:3001/.well-known/openid-configuration +``` + +### MCP Endpoints + +```bash +# Initialize +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":1,"method":"initialize"}' + +# List Tools +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":2,"method":"tools/list"}' + +# Call Tool (with auth) +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -d '{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"get-weather","arguments":{"city":"NYC"}}}' +``` + +## Available Tools + +| Tool | Scope Required | Description | +|------|----------------|-------------| +| `get-weather` | None | Get current weather for a city | +| `get-forecast` | `mcp:read` | Get 5-day weather forecast | +| `get-weather-alerts` | `mcp:admin` | Get weather alerts for a region | + +### Tool Examples + +```bash +# get-weather (no auth required) +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"get-weather","arguments":{"city":"Tokyo"}}}' + +# get-forecast (requires mcp:read scope) +curl -X POST 
http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer TOKEN_WITH_MCP_READ" \ + -d '{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"get-forecast","arguments":{"city":"Paris"}}}' + +# get-weather-alerts (requires mcp:admin scope) +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer TOKEN_WITH_MCP_ADMIN" \ + -d '{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"get-weather-alerts","arguments":{"region":"California"}}}' +``` + +## Troubleshooting + +### "Native library not found" at runtime + +The native gopher-orch library is required for JWT validation: + +```bash +# Download using the run script +./run_example.sh + +# Or manually download +curl -sSL https://raw.githubusercontent.com/GopherSecurity/gopher-mcp-rust/main/install-native.sh | bash -s -- latest ./native +``` + +Verify the library is installed: +```bash +ls -la ./native/lib/libgopher-orch* +``` + +### Library path not set + +```bash +# macOS +export DYLD_LIBRARY_PATH="./native/lib:$DYLD_LIBRARY_PATH" + +# Linux +export LD_LIBRARY_PATH="./native/lib:$LD_LIBRARY_PATH" +``` + +### "Auth client creation failed" + +Check that: +1. `jwks_uri` points to a valid JWKS endpoint +2. `issuer` matches the token issuer +3. Network can reach the auth server + +### "Token validation failed" + +Ensure: +1. Token is not expired +2. Token issuer matches config +3. Token was signed by a key in JWKS +4. 
Required scopes are present in token + +## Project Structure + +``` +auth/ +├── Cargo.toml # Dependencies (uses gopher-orch SDK) +├── Cargo.lock # Dependency lock file +├── server.config # Example configuration +├── run_example.sh # Build and run script +├── README.md # This file +├── native/ # Downloaded native libraries +│ ├── lib/ # .dylib/.so files +│ └── include/ # Header files +└── src/ + ├── main.rs # Entry point and router setup + ├── config.rs # Configuration parsing + ├── cors.rs # CORS utilities + ├── error.rs # Error types + ├── ffi/ + │ ├── mod.rs # FFI module + │ └── auth.rs # gopher-auth bindings + ├── middleware/ + │ ├── mod.rs # Middleware module + │ └── oauth_auth.rs # Auth middleware + ├── routes/ + │ ├── mod.rs # Routes module + │ ├── health.rs # Health endpoint + │ ├── mcp_handler.rs # MCP JSON-RPC handler + │ └── oauth_endpoints.rs # OAuth discovery + └── tools/ + ├── mod.rs # Tools module + └── weather_tools.rs # Weather tool implementations +``` + +## Testing + +Run the test suite: + +```bash +cargo test +``` + +Run with verbose output: + +```bash +cargo test -- --nocapture +``` + +## Environment Variables + +| Variable | Description | +|----------|-------------| +| `RUST_LOG` | Log level (e.g., `info`, `debug`, `trace`) | +| `SDK_VERSION` | Version of gopher-mcp-rust SDK (default: v0.1.2) | +| `NATIVE_LIB_DIR` | Directory for native libraries (default: ./native/lib) | +| `DYLD_LIBRARY_PATH` | macOS library search path | +| `LD_LIBRARY_PATH` | Linux library search path | + +## SDK Documentation + +For more information about the gopher-mcp-rust SDK: + +- Repository: https://github.com/GopherSecurity/gopher-mcp-rust +- Documentation: https://docs.rs/gopher-orch (after crates.io publish) + +## License + +MIT License - see LICENSE file for details. 
diff --git a/examples/auth/run_example.sh b/examples/auth/run_example.sh new file mode 100755 index 00000000..9d8cf022 --- /dev/null +++ b/examples/auth/run_example.sh @@ -0,0 +1,270 @@ +#!/bin/bash +# +# Gopher Auth MCP Server - Run Script (Rust) +# This script downloads dependencies and runs the Rust auth example server +# Works as a standalone third-party example +# +# Usage: +# ./run_example.sh # Run with default config +# ./run_example.sh --no-auth # Run with auth disabled +# ./run_example.sh --config # Run with custom config file +# ./run_example.sh --skip-download # Skip native library download +# + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +# Configuration +SDK_VERSION="${SDK_VERSION:-v0.1.2}" +NATIVE_LIB_DIR="${NATIVE_LIB_DIR:-$SCRIPT_DIR/native/lib}" +NATIVE_INCLUDE_DIR="${NATIVE_INCLUDE_DIR:-$SCRIPT_DIR/native/include}" +GITHUB_REPO="GopherSecurity/gopher-mcp-rust" + +# Print usage +usage() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --no-auth Run with authentication disabled" + echo " --config Use custom configuration file" + echo " --skip-download Skip native library download (use existing)" + echo " --help, -h Show this help message" + echo "" + echo "Environment Variables:" + echo " SDK_VERSION Version of gopher-mcp-rust SDK (default: $SDK_VERSION)" + echo " NATIVE_LIB_DIR Directory for native libraries (default: ./native/lib)" + echo "" + echo "Examples:" + echo " $0 # Run with default settings" + echo " $0 --no-auth # Run with auth disabled" + echo " SDK_VERSION=v0.1.3 $0 # Use specific SDK version" + echo "" +} + +# Check for cargo +check_cargo() { + if ! 
command -v cargo &> /dev/null; then + echo -e "${RED}Error: cargo is not installed${NC}" + echo "Please install Rust from https://rustup.rs/" + exit 1 + fi + + RUST_VERSION=$(rustc --version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+') + echo -e "${GREEN}Rust version: $RUST_VERSION${NC}" +} + +# Check for gh CLI +check_gh_cli() { + if ! command -v gh &> /dev/null; then + echo -e "${RED}Error: GitHub CLI (gh) is not installed${NC}" + echo "Install it with: brew install gh" + echo "Then authenticate: gh auth login" + exit 1 + fi +} + +# Download native library +download_native_library() { + echo -e "${YELLOW}Downloading native library ($SDK_VERSION)...${NC}" + + # Detect platform + OS=$(uname -s | tr '[:upper:]' '[:lower:]') + ARCH=$(uname -m) + + case "$OS" in + darwin) OS_NAME="macos" ;; + linux) OS_NAME="linux" ;; + mingw*|msys*|cygwin*) OS_NAME="windows" ;; + *) echo -e "${RED}Error: Unsupported OS: $OS${NC}"; exit 1 ;; + esac + + case "$ARCH" in + x86_64|amd64) ARCH_NAME="x64" ;; + arm64|aarch64) ARCH_NAME="arm64" ;; + *) echo -e "${RED}Error: Unsupported architecture: $ARCH${NC}"; exit 1 ;; + esac + + PLATFORM="${OS_NAME}-${ARCH_NAME}" + + # Determine file extension + if [ "$OS_NAME" = "windows" ]; then + ARCHIVE_EXT="zip" + else + ARCHIVE_EXT="tar.gz" + fi + + ARCHIVE_NAME="libgopher-orch-${PLATFORM}.${ARCHIVE_EXT}" + + echo -e " Platform: ${GREEN}${PLATFORM}${NC}" + echo -e " Archive: ${GREEN}${ARCHIVE_NAME}${NC}" + + # Create temp directory + TEMP_DIR=$(mktemp -d) + trap "rm -rf $TEMP_DIR" EXIT + + cd "$TEMP_DIR" + + # Download + gh release download "$SDK_VERSION" \ + -R "$GITHUB_REPO" \ + -p "$ARCHIVE_NAME" || { + echo -e "${RED}Error: Could not download $ARCHIVE_NAME${NC}" + echo "" + echo "Available assets for $SDK_VERSION:" + gh release view "$SDK_VERSION" -R "$GITHUB_REPO" --json assets -q '.assets[].name' 2>/dev/null || echo " (could not list assets)" + exit 1 + } + + echo -e "${GREEN}Downloaded${NC}" + + # Extract + echo -e "${YELLOW}Extracting...${NC}" + + 
if [ "$ARCHIVE_EXT" = "zip" ]; then + unzip -o "$ARCHIVE_NAME" + else + tar -xzf "$ARCHIVE_NAME" + fi + + # Create directories + mkdir -p "$NATIVE_LIB_DIR" + mkdir -p "$NATIVE_INCLUDE_DIR" + + # Copy libraries + if [ -d "lib" ]; then + cp -P lib/* "$NATIVE_LIB_DIR/" 2>/dev/null || true + fi + + # Copy headers + if [ -d "include" ]; then + cp -r include/* "$NATIVE_INCLUDE_DIR/" 2>/dev/null || true + fi + + # Handle flat structure (files directly in archive) + cp -P *.dylib "$NATIVE_LIB_DIR/" 2>/dev/null || true + cp -P *.so* "$NATIVE_LIB_DIR/" 2>/dev/null || true + cp -P *.dll "$NATIVE_LIB_DIR/" 2>/dev/null || true + cp -P *.h "$NATIVE_INCLUDE_DIR/" 2>/dev/null || true + + cd "$SCRIPT_DIR" + + echo -e "${GREEN}Native library installed to $NATIVE_LIB_DIR${NC}" +} + +# Check if native library exists +check_native_library() { + if [ -f "$NATIVE_LIB_DIR/libgopher-orch.dylib" ] || [ -f "$NATIVE_LIB_DIR/libgopher-orch.so" ] || \ + [ -f "$NATIVE_LIB_DIR/libgopher-orch.0.dylib" ] || [ -f "$NATIVE_LIB_DIR/libgopher-orch.0.so" ] || \ + ls "$NATIVE_LIB_DIR"/libgopher-orch*.dylib 1> /dev/null 2>&1 || \ + ls "$NATIVE_LIB_DIR"/libgopher-orch*.so 1> /dev/null 2>&1; then + echo -e "${GREEN}Native library found at $NATIVE_LIB_DIR${NC}" + return 0 + fi + return 1 +} + +# Parse arguments +CONFIG_FILE="server.config" +AUTH_DISABLED=false +SKIP_DOWNLOAD=false + +while [[ $# -gt 0 ]]; do + case $1 in + --no-auth) + AUTH_DISABLED=true + shift + ;; + --config) + CONFIG_FILE="$2" + shift 2 + ;; + --skip-download) + SKIP_DOWNLOAD=true + shift + ;; + --help|-h) + usage + exit 0 + ;; + *) + echo -e "${RED}Unknown option: $1${NC}" + usage + exit 1 + ;; + esac +done + +echo "=========================================" +echo " Gopher Auth MCP Server (Rust)" +echo "=========================================" +echo "" + +# Check Rust/Cargo +check_cargo + +# Download native library if needed +if [ "$SKIP_DOWNLOAD" = false ]; then + if check_native_library; then + echo -e "${YELLOW}Using existing 
native library. Delete ${NATIVE_LIB_DIR} to force a re-download.${NC}"
+    else
+        check_gh_cli
+        download_native_library
+    fi
+else
+    if ! check_native_library; then
+        echo -e "${RED}Error: Native library not found and --skip-download specified${NC}"
+        exit 1
+    fi
+fi
+
+echo ""
+
+# Set environment for native library loading
+export DYLD_LIBRARY_PATH="${NATIVE_LIB_DIR}:${DYLD_LIBRARY_PATH}"
+export LD_LIBRARY_PATH="${NATIVE_LIB_DIR}:${LD_LIBRARY_PATH}"
+export LIBRARY_PATH="${NATIVE_LIB_DIR}:${LIBRARY_PATH}"
+
+# Set log level
+export RUST_LOG="${RUST_LOG:-info}"
+
+# Build the project
+echo "Building auth-mcp-server..."
+cargo build --release
+echo -e "${GREEN}Build successful${NC}"
+echo ""
+
+# Create temporary config if --no-auth was specified
+if [ "$AUTH_DISABLED" = true ]; then
+    echo "Running with authentication disabled..."
+    TEMP_CONFIG=$(mktemp)
+    cat > "$TEMP_CONFIG" << EOF
+# Temporary config with auth disabled
+host=0.0.0.0
+port=3001
+server_url=http://localhost:3001
+auth_disabled=true
+allowed_scopes=openid profile email mcp:read mcp:admin
+EOF
+    CONFIG_FILE="$TEMP_CONFIG"
+    trap "rm -f $TEMP_CONFIG; rm -rf ${TEMP_DIR:-}" EXIT
+fi
+
+# Check if config file exists
+if [ ! -f "$CONFIG_FILE" ]; then
+    echo -e "${YELLOW}Warning: Config file '$CONFIG_FILE' not found${NC}"
+    echo "Server will use default configuration with auth disabled"
+fi
+
+# Run the server
+echo "Starting Rust Auth MCP Server..."
+echo ""
+./target/release/auth-mcp-server "$CONFIG_FILE"
diff --git a/examples/auth/server.config b/examples/auth/server.config
new file mode 100644
index 00000000..9184d2ee
--- /dev/null
+++ b/examples/auth/server.config
@@ -0,0 +1,33 @@
+# Rust Auth MCP Server Configuration
+# ====================================
+
+# Server settings
+host=0.0.0.0
+port=3001
+server_url=https://marni-nightcapped-nonmeditatively.ngrok-free.dev
+
+# OAuth/IDP settings
+# Configure for Keycloak or other OAuth provider (never commit real credentials)
+client_id=YOUR_CLIENT_ID
+client_secret=YOUR_CLIENT_SECRET
+auth_server_url=https://auth-test.gopher.security/realms/gopher-mcp
+oauth_authorize_url=https://api-test.gopher.security/oauth/authorize
+# oauth_token_url derived from auth_server_url: https://auth-test.gopher.security/realms/gopher-mcp-auth/protocol/openid-connect/token
+
+# Direct OAuth endpoint URLs (optional, derived from auth_server_url if not set)
+# jwks_uri=https://keycloak.example.com/realms/mcp/protocol/openid-connect/certs
+# issuer=https://keycloak.example.com/realms/mcp
+# oauth_authorize_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/auth
+# oauth_token_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/token
+
+# Scopes
+exchange_idps=oauth-idp-714982830194556929-google
+allowed_scopes=openid profile email scope-001
+
+# Cache settings
+jwks_cache_duration=3600
+jwks_auto_refresh=true
+request_timeout=30
+
+# Auth bypass mode (for development)
+auth_disabled=false
diff --git a/examples/auth/src/config.rs b/examples/auth/src/config.rs
new file mode 100644
index 00000000..d66dfcea
--- /dev/null
+++ b/examples/auth/src/config.rs
@@ -0,0 +1,628 @@
+//! Configuration module for the auth MCP server.
+//!
+//! Provides INI-style configuration file parsing and server configuration.
+ +use std::collections::HashMap; +use std::path::Path; + +use crate::error::AppError; + +/// Server configuration for the OAuth-protected MCP server. +#[derive(Debug, Clone)] +pub struct AuthServerConfig { + // Server settings + /// Server bind address (e.g., "0.0.0.0") + pub host: String, + /// Server port + pub port: u16, + /// Public server URL for metadata endpoints + pub server_url: String, + + // OAuth/IDP settings + /// Base URL of the authorization server (e.g., Keycloak realm URL) + pub auth_server_url: String, + /// JWKS endpoint URL for token validation + pub jwks_uri: String, + /// Expected token issuer + pub issuer: String, + /// OAuth client ID + pub client_id: String, + /// OAuth client secret + pub client_secret: String, + /// Token endpoint URL + pub token_endpoint: String, + + // Direct OAuth endpoint URLs + /// Authorization endpoint URL + pub oauth_authorize_url: String, + /// Token endpoint URL (alternative to token_endpoint) + pub oauth_token_url: String, + + // Scopes + /// Space-separated list of allowed scopes + pub allowed_scopes: String, + + // Cache settings + /// JWKS cache duration in seconds + pub jwks_cache_duration: u32, + /// Whether to auto-refresh JWKS cache + pub jwks_auto_refresh: bool, + /// Request timeout in milliseconds + pub request_timeout: u32, + + // Auth bypass mode + /// When true, authentication is disabled + pub auth_disabled: bool, +} + +impl Default for AuthServerConfig { + fn default() -> Self { + Self { + host: "0.0.0.0".to_string(), + port: 3001, + server_url: "http://localhost:3001".to_string(), + auth_server_url: String::new(), + jwks_uri: String::new(), + issuer: String::new(), + client_id: String::new(), + client_secret: String::new(), + token_endpoint: String::new(), + oauth_authorize_url: String::new(), + oauth_token_url: String::new(), + allowed_scopes: "mcp:read mcp:admin".to_string(), + jwks_cache_duration: 3600, + jwks_auto_refresh: true, + request_timeout: 5000, + auth_disabled: false, + } + } +} + 
+impl AuthServerConfig { + /// Create a default config with authentication disabled. + /// + /// Useful for testing and development. + pub fn default_disabled() -> Self { + Self { + auth_disabled: true, + ..Default::default() + } + } + + /// Build configuration from a parsed key-value map. + /// + /// When `auth_server_url` is provided, missing OAuth endpoints are + /// automatically derived using standard OpenID Connect paths: + /// - `jwks_uri` → `{auth_server_url}/protocol/openid-connect/certs` + /// - `issuer` → `{auth_server_url}` + /// - `oauth_authorize_url` → `{auth_server_url}/protocol/openid-connect/auth` + /// - `oauth_token_url` → `{auth_server_url}/protocol/openid-connect/token` + /// - `token_endpoint` → `{auth_server_url}/protocol/openid-connect/token` + pub fn build_from_map(map: HashMap) -> Result { + let defaults = Self::default(); + + // Parse basic fields with defaults + let host = map.get("host").cloned().unwrap_or(defaults.host); + let port = map + .get("port") + .and_then(|s| s.parse().ok()) + .unwrap_or(defaults.port); + let server_url = map.get("server_url").cloned().unwrap_or_else(|| { + // Use localhost for display when binding to all interfaces + let display_host = if host == "0.0.0.0" { "localhost" } else { &host }; + format!("http://{}:{}", display_host, port) + }); + + // Get auth server URL for endpoint derivation + let auth_server_url = map + .get("auth_server_url") + .cloned() + .unwrap_or_default(); + + // Derive endpoints from auth_server_url if not explicitly set + let jwks_uri = map.get("jwks_uri").cloned().unwrap_or_else(|| { + if auth_server_url.is_empty() { + String::new() + } else { + format!("{}/protocol/openid-connect/certs", auth_server_url) + } + }); + + let issuer = map.get("issuer").cloned().unwrap_or_else(|| { + auth_server_url.clone() + }); + + let oauth_authorize_url = map.get("oauth_authorize_url").cloned().unwrap_or_else(|| { + if auth_server_url.is_empty() { + String::new() + } else { + 
format!("{}/protocol/openid-connect/auth", auth_server_url) + } + }); + + let oauth_token_url = map.get("oauth_token_url").cloned().unwrap_or_else(|| { + if auth_server_url.is_empty() { + String::new() + } else { + format!("{}/protocol/openid-connect/token", auth_server_url) + } + }); + + let token_endpoint = map.get("token_endpoint").cloned().unwrap_or_else(|| { + if auth_server_url.is_empty() { + String::new() + } else { + format!("{}/protocol/openid-connect/token", auth_server_url) + } + }); + + // Parse other OAuth fields + let client_id = map.get("client_id").cloned().unwrap_or_default(); + let client_secret = map.get("client_secret").cloned().unwrap_or_default(); + let allowed_scopes = map + .get("allowed_scopes") + .cloned() + .unwrap_or(defaults.allowed_scopes); + + // Parse cache settings + let jwks_cache_duration = map + .get("jwks_cache_duration") + .and_then(|s| s.parse().ok()) + .unwrap_or(defaults.jwks_cache_duration); + + let jwks_auto_refresh = map + .get("jwks_auto_refresh") + .map(|s| s == "true" || s == "1") + .unwrap_or(defaults.jwks_auto_refresh); + + let request_timeout = map + .get("request_timeout") + .and_then(|s| s.parse().ok()) + .unwrap_or(defaults.request_timeout); + + // Parse auth disabled flag + let auth_disabled = map + .get("auth_disabled") + .map(|s| s == "true" || s == "1") + .unwrap_or(defaults.auth_disabled); + + Ok(Self { + host, + port, + server_url, + auth_server_url, + jwks_uri, + issuer, + client_id, + client_secret, + token_endpoint, + oauth_authorize_url, + oauth_token_url, + allowed_scopes, + jwks_cache_duration, + jwks_auto_refresh, + request_timeout, + auth_disabled, + }) + } + + /// Validate configuration. + /// + /// When authentication is enabled, validates that required fields + /// are present: + /// - `client_id` is not empty + /// - `client_secret` is not empty + /// - `jwks_uri` is not empty + /// + /// Validation is skipped when `auth_disabled` is true. 
+    pub fn validate(&self) -> Result<(), AppError> {
+        // Skip validation when auth is disabled
+        if self.auth_disabled {
+            return Ok(());
+        }
+
+        if self.client_id.is_empty() {
+            return Err(AppError::Config(
+                "client_id is required when authentication is enabled".to_string(),
+            ));
+        }
+
+        if self.client_secret.is_empty() {
+            return Err(AppError::Config(
+                "client_secret is required when authentication is enabled".to_string(),
+            ));
+        }
+
+        if self.jwks_uri.is_empty() {
+            return Err(AppError::Config(
+                "jwks_uri is required when authentication is enabled (provide jwks_uri or auth_server_url)".to_string(),
+            ));
+        }
+
+        Ok(())
+    }
+
+    /// Load configuration from an INI-style file.
+    ///
+    /// Reads the file, parses it, builds the config, and validates it.
+    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, AppError> {
+        let path = path.as_ref();
+
+        let content = std::fs::read_to_string(path).map_err(|e| {
+            AppError::Config(format!(
+                "Failed to read config file '{}': {}",
+                path.display(),
+                e
+            ))
+        })?;
+
+        let map = parse_config_file(&content);
+        let config = Self::build_from_map(map)?;
+        config.validate()?;
+
+        Ok(config)
+    }
+}
+
+/// Parse INI-style configuration file content.
+/// +/// Handles: +/// - Comments (lines starting with `#`) +/// - Empty lines +/// - Values containing `=` characters (splits only on first `=`) +/// - Whitespace trimming for keys and values +pub fn parse_config_file(content: &str) -> HashMap { + let mut map = HashMap::new(); + + for line in content.lines() { + let trimmed = line.trim(); + + // Skip empty lines and comments + if trimmed.is_empty() || trimmed.starts_with('#') { + continue; + } + + // Split on first '=' only to handle values containing '=' + if let Some(pos) = trimmed.find('=') { + let key = trimmed[..pos].trim(); + let value = trimmed[pos + 1..].trim(); + + if !key.is_empty() { + map.insert(key.to_string(), value.to_string()); + } + } + } + + map +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = AuthServerConfig::default(); + + assert_eq!(config.host, "0.0.0.0"); + assert_eq!(config.port, 3001); + assert_eq!(config.server_url, "http://localhost:3001"); + assert!(!config.auth_disabled); + assert_eq!(config.jwks_cache_duration, 3600); + assert!(config.jwks_auto_refresh); + assert_eq!(config.request_timeout, 5000); + } + + #[test] + fn test_default_disabled() { + let config = AuthServerConfig::default_disabled(); + + assert!(config.auth_disabled); + assert_eq!(config.host, "0.0.0.0"); + assert_eq!(config.port, 3001); + } + + #[test] + fn test_parse_basic_key_value() { + let content = "host=localhost\nport=3001"; + let map = parse_config_file(content); + + assert_eq!(map.get("host"), Some(&"localhost".to_string())); + assert_eq!(map.get("port"), Some(&"3001".to_string())); + } + + #[test] + fn test_parse_comments_skipped() { + let content = "# This is a comment\nhost=localhost\n# Another comment\nport=3001"; + let map = parse_config_file(content); + + assert_eq!(map.len(), 2); + assert_eq!(map.get("host"), Some(&"localhost".to_string())); + assert_eq!(map.get("port"), Some(&"3001".to_string())); + } + + #[test] + fn 
test_parse_empty_lines_skipped() { + let content = "host=localhost\n\n\nport=3001\n\n"; + let map = parse_config_file(content); + + assert_eq!(map.len(), 2); + assert_eq!(map.get("host"), Some(&"localhost".to_string())); + assert_eq!(map.get("port"), Some(&"3001".to_string())); + } + + #[test] + fn test_parse_values_with_equals() { + let content = "auth_url=https://auth.example.com?param=value&other=123"; + let map = parse_config_file(content); + + assert_eq!( + map.get("auth_url"), + Some(&"https://auth.example.com?param=value&other=123".to_string()) + ); + } + + #[test] + fn test_parse_whitespace_trimmed() { + let content = " host = localhost \n port= 3001"; + let map = parse_config_file(content); + + assert_eq!(map.get("host"), Some(&"localhost".to_string())); + assert_eq!(map.get("port"), Some(&"3001".to_string())); + } + + #[test] + fn test_parse_empty_value() { + let content = "empty_key="; + let map = parse_config_file(content); + + assert_eq!(map.get("empty_key"), Some(&"".to_string())); + } + + #[test] + fn test_build_from_map_defaults() { + let map = HashMap::new(); + let config = AuthServerConfig::build_from_map(map).unwrap(); + + assert_eq!(config.host, "0.0.0.0"); + assert_eq!(config.port, 3001); + // server_url uses localhost for display when host is 0.0.0.0 + assert_eq!(config.server_url, "http://localhost:3001"); + assert!(!config.auth_disabled); + } + + #[test] + fn test_build_from_map_custom_values() { + let mut map = HashMap::new(); + map.insert("host".to_string(), "127.0.0.1".to_string()); + map.insert("port".to_string(), "8080".to_string()); + map.insert("client_id".to_string(), "my-client".to_string()); + map.insert("auth_disabled".to_string(), "true".to_string()); + + let config = AuthServerConfig::build_from_map(map).unwrap(); + + assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.port, 8080); + assert_eq!(config.server_url, "http://127.0.0.1:8080"); + assert_eq!(config.client_id, "my-client"); + assert!(config.auth_disabled); + } + + 
#[test] + fn test_build_from_map_endpoint_derivation() { + let mut map = HashMap::new(); + map.insert( + "auth_server_url".to_string(), + "https://auth.example.com/realms/test".to_string(), + ); + + let config = AuthServerConfig::build_from_map(map).unwrap(); + + assert_eq!( + config.jwks_uri, + "https://auth.example.com/realms/test/protocol/openid-connect/certs" + ); + assert_eq!( + config.issuer, + "https://auth.example.com/realms/test" + ); + assert_eq!( + config.oauth_authorize_url, + "https://auth.example.com/realms/test/protocol/openid-connect/auth" + ); + assert_eq!( + config.oauth_token_url, + "https://auth.example.com/realms/test/protocol/openid-connect/token" + ); + assert_eq!( + config.token_endpoint, + "https://auth.example.com/realms/test/protocol/openid-connect/token" + ); + } + + #[test] + fn test_build_from_map_explicit_endpoints_override() { + let mut map = HashMap::new(); + map.insert( + "auth_server_url".to_string(), + "https://auth.example.com/realms/test".to_string(), + ); + map.insert( + "jwks_uri".to_string(), + "https://custom.example.com/jwks".to_string(), + ); + + let config = AuthServerConfig::build_from_map(map).unwrap(); + + // Explicit value should override derived + assert_eq!(config.jwks_uri, "https://custom.example.com/jwks"); + // Other endpoints still derived + assert_eq!( + config.oauth_authorize_url, + "https://auth.example.com/realms/test/protocol/openid-connect/auth" + ); + } + + #[test] + fn test_build_from_map_cache_settings() { + let mut map = HashMap::new(); + map.insert("jwks_cache_duration".to_string(), "7200".to_string()); + map.insert("jwks_auto_refresh".to_string(), "false".to_string()); + map.insert("request_timeout".to_string(), "10000".to_string()); + + let config = AuthServerConfig::build_from_map(map).unwrap(); + + assert_eq!(config.jwks_cache_duration, 7200); + assert!(!config.jwks_auto_refresh); + assert_eq!(config.request_timeout, 10000); + } + + #[test] + fn test_build_from_map_boolean_parsing() { + // Test 
"1" as true + let mut map = HashMap::new(); + map.insert("auth_disabled".to_string(), "1".to_string()); + let config = AuthServerConfig::build_from_map(map).unwrap(); + assert!(config.auth_disabled); + + // Test "true" as true + let mut map = HashMap::new(); + map.insert("jwks_auto_refresh".to_string(), "true".to_string()); + let config = AuthServerConfig::build_from_map(map).unwrap(); + assert!(config.jwks_auto_refresh); + } + + #[test] + fn test_validate_passes_with_valid_config() { + let mut map = HashMap::new(); + map.insert("client_id".to_string(), "my-client".to_string()); + map.insert("client_secret".to_string(), "secret".to_string()); + map.insert( + "auth_server_url".to_string(), + "https://auth.example.com".to_string(), + ); + + let config = AuthServerConfig::build_from_map(map).unwrap(); + assert!(config.validate().is_ok()); + } + + #[test] + fn test_validate_fails_missing_client_id() { + let mut map = HashMap::new(); + map.insert("client_secret".to_string(), "secret".to_string()); + map.insert("jwks_uri".to_string(), "https://example.com/jwks".to_string()); + + let config = AuthServerConfig::build_from_map(map).unwrap(); + let result = config.validate(); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("client_id")); + } + + #[test] + fn test_validate_fails_missing_client_secret() { + let mut map = HashMap::new(); + map.insert("client_id".to_string(), "my-client".to_string()); + map.insert("jwks_uri".to_string(), "https://example.com/jwks".to_string()); + + let config = AuthServerConfig::build_from_map(map).unwrap(); + let result = config.validate(); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("client_secret")); + } + + #[test] + fn test_validate_fails_missing_jwks_uri() { + let mut map = HashMap::new(); + map.insert("client_id".to_string(), "my-client".to_string()); + map.insert("client_secret".to_string(), "secret".to_string()); + + let config = 
AuthServerConfig::build_from_map(map).unwrap(); + let result = config.validate(); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("jwks_uri")); + } + + #[test] + fn test_validate_skipped_when_auth_disabled() { + // Empty config should fail validation normally + let config = AuthServerConfig::default(); + assert!(config.validate().is_err()); + + // But should pass when auth is disabled + let config = AuthServerConfig::default_disabled(); + assert!(config.validate().is_ok()); + } + + #[test] + fn test_from_file_success() { + use std::io::Write; + + // Create a temporary config file + let dir = std::env::temp_dir(); + let path = dir.join("test_config.ini"); + + let content = r#" +# Test configuration +host=127.0.0.1 +port=8080 +client_id=test-client +client_secret=test-secret +auth_server_url=https://auth.example.com/realms/test +"#; + + let mut file = std::fs::File::create(&path).unwrap(); + file.write_all(content.as_bytes()).unwrap(); + + let config = AuthServerConfig::from_file(&path).unwrap(); + + assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.port, 8080); + assert_eq!(config.client_id, "test-client"); + assert_eq!(config.client_secret, "test-secret"); + assert_eq!( + config.jwks_uri, + "https://auth.example.com/realms/test/protocol/openid-connect/certs" + ); + + // Cleanup + std::fs::remove_file(&path).ok(); + } + + #[test] + fn test_from_file_missing_file() { + let result = AuthServerConfig::from_file("/nonexistent/path/config.ini"); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Failed to read")); + } + + #[test] + fn test_from_file_validation_error() { + use std::io::Write; + + // Create a config file missing required fields + let dir = std::env::temp_dir(); + let path = dir.join("test_invalid_config.ini"); + + let content = r#" +host=127.0.0.1 +port=8080 +# Missing client_id, client_secret, auth_server_url +"#; + + let mut file = std::fs::File::create(&path).unwrap(); + 
file.write_all(content.as_bytes()).unwrap(); + + let result = AuthServerConfig::from_file(&path); + + assert!(result.is_err()); + // Should fail validation + assert!(result.unwrap_err().to_string().contains("required")); + + // Cleanup + std::fs::remove_file(&path).ok(); + } +} diff --git a/examples/auth/src/cors.rs b/examples/auth/src/cors.rs new file mode 100644 index 00000000..ff129455 --- /dev/null +++ b/examples/auth/src/cors.rs @@ -0,0 +1,166 @@ +//! CORS utilities for the auth MCP server. +//! +//! Provides CORS headers and handlers matching the TypeScript implementation +//! for browser compatibility with MCP clients. + +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, +}; + +/// CORS allowed origin (permissive for MCP clients). +pub const CORS_ALLOW_ORIGIN: &str = "*"; + +/// CORS allowed methods. +pub const CORS_ALLOW_METHODS: &str = "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"; + +/// CORS allowed headers (includes MCP-specific headers). +pub const CORS_ALLOW_HEADERS: &str = "Accept, Accept-Language, Content-Language, Content-Type, \ + Authorization, X-Requested-With, Origin, Cache-Control, Pragma, \ + Mcp-Session-Id, Mcp-Protocol-Version"; + +/// CORS exposed headers. +pub const CORS_EXPOSE_HEADERS: &str = "WWW-Authenticate, Content-Length, Content-Type"; + +/// CORS max age in seconds (24 hours). +pub const CORS_MAX_AGE: &str = "86400"; + +/// OPTIONS handler for CORS preflight requests. +/// +/// Returns 204 No Content with all CORS headers set. +pub async fn options_handler() -> impl IntoResponse { + ( + StatusCode::NO_CONTENT, + [ + ("Access-Control-Allow-Origin", CORS_ALLOW_ORIGIN), + ("Access-Control-Allow-Methods", CORS_ALLOW_METHODS), + ("Access-Control-Allow-Headers", CORS_ALLOW_HEADERS), + ("Access-Control-Expose-Headers", CORS_EXPOSE_HEADERS), + ("Access-Control-Max-Age", CORS_MAX_AGE), + ("Content-Length", "0"), + ], + ) +} + +/// Add CORS headers to a response. 
+/// +/// Wraps any response type and adds the necessary CORS headers +/// for cross-origin requests. +pub fn with_cors_headers(response: T) -> Response { + let mut res = response.into_response(); + let headers = res.headers_mut(); + + headers.insert( + "Access-Control-Allow-Origin", + CORS_ALLOW_ORIGIN.parse().unwrap(), + ); + headers.insert( + "Access-Control-Allow-Methods", + CORS_ALLOW_METHODS.parse().unwrap(), + ); + headers.insert( + "Access-Control-Allow-Headers", + CORS_ALLOW_HEADERS.parse().unwrap(), + ); + headers.insert( + "Access-Control-Expose-Headers", + CORS_EXPOSE_HEADERS.parse().unwrap(), + ); + headers.insert("Access-Control-Max-Age", CORS_MAX_AGE.parse().unwrap()); + + res +} + +#[cfg(test)] +mod tests { + use super::*; + use http_body_util::BodyExt; + + #[tokio::test] + async fn test_options_handler_status() { + let response = options_handler().await.into_response(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); + } + + #[tokio::test] + async fn test_options_handler_headers() { + let response = options_handler().await.into_response(); + let headers = response.headers(); + + assert_eq!( + headers.get("Access-Control-Allow-Origin").unwrap(), + "*" + ); + assert_eq!( + headers.get("Access-Control-Allow-Methods").unwrap(), + CORS_ALLOW_METHODS + ); + assert!(headers + .get("Access-Control-Allow-Headers") + .unwrap() + .to_str() + .unwrap() + .contains("Mcp-Session-Id")); + assert!(headers + .get("Access-Control-Allow-Headers") + .unwrap() + .to_str() + .unwrap() + .contains("Mcp-Protocol-Version")); + assert_eq!( + headers.get("Access-Control-Max-Age").unwrap(), + "86400" + ); + assert_eq!(headers.get("Content-Length").unwrap(), "0"); + } + + #[tokio::test] + async fn test_options_handler_empty_body() { + let response = options_handler().await.into_response(); + let body = response.into_body().collect().await.unwrap().to_bytes(); + assert!(body.is_empty()); + } + + #[tokio::test] + async fn test_with_cors_headers() { + let original = 
(StatusCode::OK, "Hello"); + let response = with_cors_headers(original); + + assert_eq!(response.status(), StatusCode::OK); + + let headers = response.headers(); + assert_eq!( + headers.get("Access-Control-Allow-Origin").unwrap(), + "*" + ); + assert_eq!( + headers.get("Access-Control-Allow-Methods").unwrap(), + CORS_ALLOW_METHODS + ); + assert_eq!( + headers.get("Access-Control-Max-Age").unwrap(), + "86400" + ); + } + + #[tokio::test] + async fn test_with_cors_headers_preserves_body() { + use axum::Json; + use serde_json::json; + + let original = Json(json!({"message": "test"})); + let response = with_cors_headers(original); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["message"], "test"); + } + + #[test] + fn test_cors_headers_include_mcp_headers() { + // Verify MCP-specific headers are in the allowed list + assert!(CORS_ALLOW_HEADERS.contains("Mcp-Session-Id")); + assert!(CORS_ALLOW_HEADERS.contains("Mcp-Protocol-Version")); + } +} diff --git a/examples/auth/src/error.rs b/examples/auth/src/error.rs new file mode 100644 index 00000000..f47ff06a --- /dev/null +++ b/examples/auth/src/error.rs @@ -0,0 +1,198 @@ +//! Error types for the auth MCP server. +//! +//! Defines the main error enum with variants for configuration, +//! authentication, FFI, JSON-RPC, and internal errors. + +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::Serialize; +use thiserror::Error; + +// Re-export gopher_mcp_rust error for convenience +pub use gopher_mcp_rust::Error as GopherOrchError; + +/// Application error type. +#[derive(Error, Debug)] +pub enum AppError { + /// Configuration error. + #[error("Configuration error: {0}")] + Config(String), + + /// Authentication error. + #[error("Auth error: {0}")] + Auth(String), + + /// FFI/native library error. 
+ #[error("FFI error: {0}")] + Ffi(String), + + /// JSON-RPC protocol error. + #[error("JSON-RPC error: {message}")] + JsonRpc { code: i32, message: String }, + + /// Internal server error. + #[error("Internal error: {0}")] + Internal(String), +} + +impl From for AppError { + fn from(err: GopherOrchError) -> Self { + match err { + GopherOrchError::Auth(msg) => AppError::Auth(msg), + GopherOrchError::Library(msg) => AppError::Ffi(msg), + other => AppError::Ffi(other.to_string()), + } + } +} + +/// JSON error response body. +#[derive(Serialize)] +struct ErrorResponse { + error: String, + #[serde(skip_serializing_if = "Option::is_none")] + code: Option, +} + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + let (status, error_response) = match &self { + AppError::Config(msg) => ( + StatusCode::INTERNAL_SERVER_ERROR, + ErrorResponse { + error: msg.clone(), + code: None, + }, + ), + AppError::Auth(msg) => ( + StatusCode::UNAUTHORIZED, + ErrorResponse { + error: msg.clone(), + code: None, + }, + ), + AppError::Ffi(msg) => ( + StatusCode::INTERNAL_SERVER_ERROR, + ErrorResponse { + error: msg.clone(), + code: None, + }, + ), + AppError::JsonRpc { code, message } => ( + StatusCode::BAD_REQUEST, + ErrorResponse { + error: message.clone(), + code: Some(*code), + }, + ), + AppError::Internal(msg) => ( + StatusCode::INTERNAL_SERVER_ERROR, + ErrorResponse { + error: msg.clone(), + code: None, + }, + ), + }; + + (status, Json(error_response)).into_response() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http_body_util::BodyExt; + + #[tokio::test] + async fn test_config_error_response() { + let error = AppError::Config("missing field".to_string()); + let response = error.into_response(); + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["error"], "missing 
field"); + assert!(json.get("code").is_none() || json["code"].is_null()); + } + + #[tokio::test] + async fn test_auth_error_response() { + let error = AppError::Auth("invalid token".to_string()); + let response = error.into_response(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["error"], "invalid token"); + } + + #[tokio::test] + async fn test_ffi_error_response() { + let error = AppError::Ffi("library not found".to_string()); + let response = error.into_response(); + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["error"], "library not found"); + } + + #[tokio::test] + async fn test_jsonrpc_error_response() { + let error = AppError::JsonRpc { + code: -32600, + message: "Invalid Request".to_string(), + }; + let response = error.into_response(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["error"], "Invalid Request"); + assert_eq!(json["code"], -32600); + } + + #[tokio::test] + async fn test_internal_error_response() { + let error = AppError::Internal("unexpected error".to_string()); + let response = error.into_response(); + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["error"], "unexpected error"); + } + + #[test] + fn test_error_display() { + let config_err = AppError::Config("test config".to_string()); + assert_eq!(format!("{}", 
config_err), "Configuration error: test config"); + + let auth_err = AppError::Auth("test auth".to_string()); + assert_eq!(format!("{}", auth_err), "Auth error: test auth"); + + let ffi_err = AppError::Ffi("test ffi".to_string()); + assert_eq!(format!("{}", ffi_err), "FFI error: test ffi"); + + let jsonrpc_err = AppError::JsonRpc { + code: -32700, + message: "Parse error".to_string(), + }; + assert_eq!(format!("{}", jsonrpc_err), "JSON-RPC error: Parse error"); + + let internal_err = AppError::Internal("test internal".to_string()); + assert_eq!(format!("{}", internal_err), "Internal error: test internal"); + } +} diff --git a/examples/auth/src/ffi/mod.rs b/examples/auth/src/ffi/mod.rs new file mode 100644 index 00000000..4cce2c69 --- /dev/null +++ b/examples/auth/src/ffi/mod.rs @@ -0,0 +1,72 @@ +//! FFI bindings module. +//! +//! Re-exports gopher-auth types from the gopher-mcp-rust library with a thin +//! wrapper to support testing without the native library. + +use crate::error::AppError; + +// Re-export payload types directly +pub use gopher_mcp_rust::TokenPayload; +pub use gopher_mcp_rust::ValidationResult; + +/// Wrapper around gopher-mcp-rust's GopherAuthClient. +/// +/// Provides the same interface but allows creating dummy instances for testing. +pub struct GopherAuthClient { + inner: Option, +} + +unsafe impl Send for GopherAuthClient {} +unsafe impl Sync for GopherAuthClient {} + +impl GopherAuthClient { + /// Create a new client. + pub fn new(jwks_uri: &str, issuer: &str) -> Result { + let inner = gopher_mcp_rust::GopherAuthClient::new(jwks_uri, issuer)?; + Ok(Self { inner: Some(inner) }) + } + + /// Validate a JWT token. + pub fn validate_token(&self, token: &str, clock_skew: u32) -> ValidationResult { + match &self.inner { + Some(client) => client.validate_token(token, clock_skew), + None => ValidationResult::failure(-1, "Client not initialized"), + } + } + + /// Extract payload from a JWT token. 
+ pub fn extract_payload(&self, token: &str) -> Result { + match &self.inner { + Some(client) => client.extract_payload(token).map_err(Into::into), + None => Err(AppError::Ffi("Client not initialized".to_string())), + } + } + + /// Set a client option. + pub fn set_option(&self, key: &str, value: &str) -> Result<(), AppError> { + match &self.inner { + Some(client) => client.set_option(key, value).map_err(Into::into), + None => Err(AppError::Ffi("Client not initialized".to_string())), + } + } + + /// Destroy the client. + pub fn destroy(&mut self) { + if let Some(ref mut client) = self.inner { + client.destroy(); + } + self.inner = None; + } + + /// Create a dummy client for testing. + #[cfg(test)] + pub fn dummy() -> Self { + Self { inner: None } + } +} + +impl Drop for GopherAuthClient { + fn drop(&mut self) { + self.destroy(); + } +} diff --git a/examples/auth/src/main.rs b/examples/auth/src/main.rs new file mode 100644 index 00000000..ae848791 --- /dev/null +++ b/examples/auth/src/main.rs @@ -0,0 +1,324 @@ +//! Rust Auth MCP Server +//! +//! An OAuth-protected MCP server example demonstrating JWT token validation +//! and scope-based access control for MCP tools. 
+ +mod config; +mod cors; +mod error; +mod ffi; +mod middleware; +mod routes; +mod tools; + +use std::env; +use std::sync::Arc; + +use axum::{ + extract::FromRef, + middleware as axum_middleware, + routing::{get, post}, + Router, +}; +use tokio::net::TcpListener; +use tokio::sync::watch; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use crate::config::AuthServerConfig; +use crate::ffi::GopherAuthClient; +use crate::middleware::{auth_middleware, AuthState}; +use crate::routes::health::{health_handler, HealthState}; +use crate::routes::mcp_handler::{mcp_handler, mcp_options, McpHandler}; +use crate::routes::oauth_endpoints::{ + authorization_server_metadata, oauth_authorize, oauth_register, openid_configuration, + protected_resource_metadata, +}; +use crate::tools::weather_tools::register_weather_tools; + +/// Combined application state. +#[derive(Clone)] +pub struct AppState { + /// Server configuration. + pub config: Arc, + /// Health endpoint state. + pub health: Arc, + /// MCP handler state. + pub mcp: Arc, +} + +// Implement FromRef to extract individual states +impl FromRef for Arc { + fn from_ref(state: &AppState) -> Self { + state.config.clone() + } +} + +impl FromRef for Arc { + fn from_ref(state: &AppState) -> Self { + state.health.clone() + } +} + +impl FromRef for Arc { + fn from_ref(state: &AppState) -> Self { + state.mcp.clone() + } +} + +/// Print server banner. +fn print_banner() { + println!(); + println!("╔══════════════════════════════════════════════════════════╗"); + println!("║ Rust Auth MCP Server ║"); + println!("║ OAuth-Protected MCP Server Example ║"); + println!("╚══════════════════════════════════════════════════════════╝"); + println!(); +} + +/// Print available endpoints. 
+fn print_endpoints(config: &AuthServerConfig) { + println!("Available Endpoints:"); + println!("────────────────────────────────────────────────────────────"); + println!(" Health:"); + println!(" GET {}/health", config.server_url); + println!(); + println!(" OAuth Discovery:"); + println!(" GET {}/.well-known/oauth-protected-resource", config.server_url); + println!(" GET {}/.well-known/oauth-protected-resource/mcp", config.server_url); + println!(" GET {}/.well-known/oauth-authorization-server", config.server_url); + println!(" GET {}/.well-known/openid-configuration", config.server_url); + println!(); + println!(" OAuth:"); + println!(" GET {}/oauth/authorize", config.server_url); + println!(" POST {}/oauth/register", config.server_url); + println!(); + println!(" MCP (Protected):"); + println!(" POST {}/mcp", config.server_url); + println!(" POST {}/rpc", config.server_url); + println!("────────────────────────────────────────────────────────────"); + println!(); +} + +/// Wait for shutdown signal. +/// +/// Handles SIGINT (Ctrl+C) and SIGTERM (on Unix) for graceful shutdown. +async fn shutdown_signal(mut shutdown_rx: watch::Receiver) { + let ctrl_c = async { + tokio::signal::ctrl_c() + .await + .expect("Failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("Failed to install SIGTERM handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + let shutdown_watch = async { + // Wait for explicit shutdown signal + while !*shutdown_rx.borrow_and_update() { + if shutdown_rx.changed().await.is_err() { + break; + } + } + }; + + tokio::select! 
{ + _ = ctrl_c => { + tracing::info!("Received Ctrl+C signal"); + } + _ = terminate => { + tracing::info!("Received SIGTERM signal"); + } + _ = shutdown_watch => { + tracing::info!("Received shutdown signal"); + } + } +} + +#[tokio::main] +async fn main() { + // Initialize tracing + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + print_banner(); + + // Load configuration + let config_path = env::args().nth(1).unwrap_or_else(|| "server.config".to_string()); + + let config = match AuthServerConfig::from_file(&config_path) { + Ok(c) => { + tracing::info!("Configuration loaded from {}", config_path); + c + } + Err(e) => { + tracing::warn!("Failed to load config from {}: {}", config_path, e); + tracing::info!("Using default configuration with auth disabled"); + AuthServerConfig::default_disabled() + } + }; + + // Initialize auth client if auth is enabled + let auth_client: Option> = if config.auth_disabled { + tracing::info!("Authentication is DISABLED"); + None + } else { + tracing::info!("Initializing auth library..."); + match GopherAuthClient::new(&config.jwks_uri, &config.issuer) { + Ok(client) => { + // Set client options + if let Err(e) = client.set_option( + "cache_duration", + &config.jwks_cache_duration.to_string(), + ) { + tracing::warn!("Failed to set cache_duration: {}", e); + } + if let Err(e) = client.set_option( + "auto_refresh", + if config.jwks_auto_refresh { "true" } else { "false" }, + ) { + tracing::warn!("Failed to set auto_refresh: {}", e); + } + if let Err(e) = client.set_option( + "request_timeout", + &config.request_timeout.to_string(), + ) { + tracing::warn!("Failed to set request_timeout: {}", e); + } + tracing::info!("Auth library initialized successfully"); + Some(Arc::new(client)) + } + Err(e) => { + tracing::warn!("Failed to initialize auth library: {}", e); + tracing::warn!("Continuing with 
auth disabled"); + None + } + } + }; + + // Create shared state + let config = Arc::new(config); + let health_state = Arc::new(HealthState::new(Some("1.0.0".to_string()))); + let auth_state = Arc::new(AuthState::new(auth_client.clone(), (*config).clone())); + + // Create MCP handler and register tools + let mut mcp = McpHandler::new(); + register_weather_tools(&mut mcp, config.auth_disabled); + let mcp_state = Arc::new(mcp); + + // Create combined app state + let app_state = AppState { + config: config.clone(), + health: health_state, + mcp: mcp_state, + }; + + // Build router + let app = Router::new() + // Health endpoint + .route("/health", get(health_handler)) + // OAuth discovery endpoints + .route( + "/.well-known/oauth-protected-resource", + get(protected_resource_metadata).options(options_handler), + ) + .route( + "/.well-known/oauth-protected-resource/mcp", + get(protected_resource_metadata).options(options_handler), + ) + .route( + "/.well-known/oauth-authorization-server", + get(authorization_server_metadata).options(options_handler), + ) + .route( + "/.well-known/openid-configuration", + get(openid_configuration).options(options_handler), + ) + // OAuth endpoints + .route( + "/oauth/authorize", + get(oauth_authorize).options(options_handler), + ) + .route( + "/oauth/register", + post(oauth_register).options(options_handler), + ) + // MCP endpoints + .route("/mcp", post(mcp_handler).options(mcp_options)) + .route("/rpc", post(mcp_handler).options(mcp_options)) + // Add combined state + .with_state(app_state) + // Add auth middleware + .layer(axum_middleware::from_fn_with_state( + auth_state, + auth_middleware, + )); + + // Print startup information + let addr = format!("{}:{}", config.host, config.port); + print_endpoints(&config); + + if config.auth_disabled { + println!("⚠️ Authentication is DISABLED - all requests are allowed"); + } else { + println!("🔒 Authentication is ENABLED"); + println!(" JWKS URI: {}", config.jwks_uri); + println!(" Issuer: {}", 
config.issuer); + } + println!(); + + tracing::info!("Server starting on {}", addr); + println!("🚀 Server listening on http://{}", addr); + println!(); + println!("Press Ctrl+C to shutdown"); + println!(); + + // Create shutdown channel + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + // Start server with graceful shutdown + let listener = TcpListener::bind(&addr).await.expect("Failed to bind address"); + let server = axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal(shutdown_rx)); + + // Run server + if let Err(e) = server.await { + tracing::error!("Server error: {}", e); + } + + // Shutdown sequence + println!(); + println!("Shutting down..."); + tracing::info!("Server shutdown initiated"); + + // Signal shutdown (in case it wasn't already signaled) + let _ = shutdown_tx.send(true); + + // Cleanup auth client if present + if auth_client.is_some() { + tracing::info!("Cleaning up auth client..."); + // The Arc will be dropped when all references are gone + // The Drop implementation on GopherAuthClient will call destroy() + println!("Auth client destroyed"); + } + + tracing::info!("Server shutdown complete"); + println!("Goodbye!"); +} + +/// Generic OPTIONS handler for CORS preflight. +async fn options_handler() -> impl axum::response::IntoResponse { + cors::options_handler().await +} diff --git a/examples/auth/src/middleware/mod.rs b/examples/auth/src/middleware/mod.rs new file mode 100644 index 00000000..2dc3759c --- /dev/null +++ b/examples/auth/src/middleware/mod.rs @@ -0,0 +1,11 @@ +//! Authentication middleware module. +//! +//! Provides OAuth/JWT authentication middleware for protecting routes. 
+ +pub mod oauth_auth; + +// Re-export commonly used types +pub use oauth_auth::{ + auth_middleware, cors_preflight_response, extract_token, unauthorized_response, AuthContext, + AuthState, +}; diff --git a/examples/auth/src/middleware/oauth_auth.rs b/examples/auth/src/middleware/oauth_auth.rs new file mode 100644 index 00000000..3743c267 --- /dev/null +++ b/examples/auth/src/middleware/oauth_auth.rs @@ -0,0 +1,522 @@ +//! OAuth Authentication Middleware +//! +//! Provides OAuth/JWT authentication middleware for protecting MCP routes. +//! Handles token extraction, validation, and scope-based access control. + +use axum::{ + body::Body, + extract::State, + http::{Method, Request, StatusCode}, + middleware::Next, + response::{IntoResponse, Response}, +}; +use serde_json::json; +use std::sync::Arc; + +use crate::config::AuthServerConfig; +use crate::cors::{ + CORS_ALLOW_HEADERS, CORS_ALLOW_METHODS, CORS_ALLOW_ORIGIN, CORS_EXPOSE_HEADERS, CORS_MAX_AGE, +}; +use crate::ffi::GopherAuthClient; + +/// Authentication context from JWT token validation. +/// +/// Contains user information extracted from a validated token. +#[derive(Debug, Clone, Default)] +pub struct AuthContext { + /// User identifier from token subject. + pub user_id: String, + /// Space-separated list of scopes. + pub scopes: String, + /// Token audience. + pub audience: String, + /// Token expiration timestamp (unix seconds). + pub token_expiry: u64, + /// Whether the user is authenticated. + pub authenticated: bool, +} + +impl AuthContext { + /// Check if a specific scope is present. + pub fn has_scope(&self, scope: &str) -> bool { + self.scopes.split_whitespace().any(|s| s == scope) + } +} + +/// Shared state for the auth middleware. +pub struct AuthState { + /// Optional auth client (None if auth is disabled). + pub auth_client: Option>, + /// Server configuration. + pub config: AuthServerConfig, +} + +impl AuthState { + /// Create a new auth state. 
+ pub fn new(auth_client: Option>, config: AuthServerConfig) -> Self { + Self { + auth_client, + config, + } + } + + /// Check if a path requires authentication. + /// + /// Returns false for public paths, true for protected paths. + pub fn requires_auth(&self, path: &str) -> bool { + // If auth is disabled or no auth client, nothing requires auth + if self.config.auth_disabled || self.auth_client.is_none() { + return false; + } + + // Public paths that never require auth + let public_prefixes = [ + "/.well-known/", + "/oauth/", + "/authorize", + "/health", + "/favicon.ico", + ]; + + for prefix in &public_prefixes { + if path.starts_with(prefix) || path == *prefix { + return false; + } + } + + // Exact match for public paths + if path == "/health" || path == "/favicon.ico" { + return false; + } + + // Protected paths + let protected_prefixes = ["/mcp", "/rpc", "/events", "/sse"]; + + for prefix in &protected_prefixes { + if path.starts_with(prefix) { + return true; + } + } + + // Default: protected (fail secure) + true + } +} + +/// Extract bearer token from a request. +/// +/// Looks for the token in: +/// 1. Authorization header (Bearer prefix) +/// 2. 
access_token query parameter (fallback) +/// +/// # Arguments +/// +/// * `request` - The HTTP request +/// +/// # Returns +/// +/// The token string if found, None otherwise +pub fn extract_token(request: &Request) -> Option { + // Try Authorization header first + if let Some(auth_header) = request.headers().get("authorization") { + if let Ok(auth_str) = auth_header.to_str() { + if let Some(token) = auth_str.strip_prefix("Bearer ") { + return Some(token.to_string()); + } + // Also try lowercase prefix + if let Some(token) = auth_str.strip_prefix("bearer ") { + return Some(token.to_string()); + } + } + } + + // Try access_token query parameter as fallback + if let Some(query) = request.uri().query() { + for pair in query.split('&') { + if let Some(token) = pair.strip_prefix("access_token=") { + return Some(token.to_string()); + } + } + } + + None +} + +/// Create a CORS preflight response. +/// +/// Returns 204 No Content with full CORS headers for OPTIONS requests. +pub fn cors_preflight_response() -> Response { + Response::builder() + .status(StatusCode::NO_CONTENT) + .header("Access-Control-Allow-Origin", CORS_ALLOW_ORIGIN) + .header("Access-Control-Allow-Methods", CORS_ALLOW_METHODS) + .header("Access-Control-Allow-Headers", CORS_ALLOW_HEADERS) + .header("Access-Control-Expose-Headers", CORS_EXPOSE_HEADERS) + .header("Access-Control-Max-Age", CORS_MAX_AGE) + .body(Body::empty()) + .unwrap() +} + +/// Create an unauthorized response with WWW-Authenticate header. +/// +/// Returns 401 Unauthorized with RFC 6750 Bearer scheme header and JSON body. 
+/// +/// # Arguments +/// +/// * `config` - Server configuration for resource metadata URL +/// * `error` - OAuth error code (e.g., "invalid_token", "invalid_request") +/// * `description` - Human-readable error description +pub fn unauthorized_response(config: &AuthServerConfig, error: &str, description: &str) -> Response { + // Build WWW-Authenticate header value per RFC 6750 + let www_authenticate = format!( + r#"Bearer realm="{server_url}", resource_metadata="{server_url}/.well-known/oauth-protected-resource", scope="{scopes}", error="{error}", error_description="{description}""#, + server_url = config.server_url, + scopes = config.allowed_scopes, + error = error, + description = description + ); + + let body = json!({ + "error": error, + "error_description": description + }); + + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("WWW-Authenticate", www_authenticate) + .header("Content-Type", "application/json") + .header("Access-Control-Allow-Origin", CORS_ALLOW_ORIGIN) + .header("Access-Control-Expose-Headers", CORS_EXPOSE_HEADERS) + .body(Body::from(serde_json::to_string(&body).unwrap_or_default())) + .unwrap() +} + +/// Authentication middleware. +/// +/// Validates bearer tokens and injects AuthContext into request extensions. +/// +/// # Flow +/// +/// 1. OPTIONS requests → CORS preflight response +/// 2. Public paths → pass through +/// 3. Extract token → 401 if missing +/// 4. Validate token (placeholder) → 401 if invalid +/// 5. 
Inject AuthContext → continue to handler +pub async fn auth_middleware( + State(state): State>, + mut request: Request, + next: Next, +) -> Result { + // Handle CORS preflight + if request.method() == Method::OPTIONS { + return Ok(cors_preflight_response()); + } + + let path = request.uri().path().to_string(); + + // Check if this path requires authentication + if !state.requires_auth(&path) { + // Insert default auth context for public paths + request.extensions_mut().insert(AuthContext::default()); + return Ok(next.run(request).await); + } + + // Extract bearer token + let token = match extract_token(&request) { + Some(t) => t, + None => { + return Ok(unauthorized_response( + &state.config, + "invalid_request", + "Missing bearer token", + )); + } + }; + + // TODO: Validate token using gopher-auth FFI + // For now, create a mock auth context if a token is present + // In the real implementation, this would call: + // - state.auth_client.validate_token(&token, clock_skew) + // - state.auth_client.extract_payload(&token) + + // Placeholder: accept any token and extract mock claims + // This will be replaced with actual validation in the FFI implementation + let auth_context = if state.config.auth_disabled { + // Auth disabled: grant full access + AuthContext { + user_id: "anonymous".to_string(), + scopes: state.config.allowed_scopes.clone(), + audience: state.config.server_url.clone(), + token_expiry: u64::MAX, + authenticated: false, + } + } else { + // Placeholder for real token validation + // In production, this would parse and validate the JWT + AuthContext { + user_id: "user".to_string(), + scopes: state.config.allowed_scopes.clone(), + audience: state.config.server_url.clone(), + token_expiry: chrono::Utc::now().timestamp() as u64 + 3600, + authenticated: true, + } + }; + + // Insert auth context into request extensions + request.extensions_mut().insert(auth_context); + + // Continue to the next handler + Ok(next.run(request).await) +} + +#[cfg(test)] +mod 
tests { + use super::*; + use axum::http::Request; + + #[test] + fn test_auth_context_default() { + let ctx = AuthContext::default(); + assert_eq!(ctx.user_id, ""); + assert_eq!(ctx.scopes, ""); + assert_eq!(ctx.audience, ""); + assert_eq!(ctx.token_expiry, 0); + assert!(!ctx.authenticated); + } + + #[test] + fn test_auth_context_has_scope() { + let ctx = AuthContext { + scopes: "openid profile mcp:read mcp:admin".to_string(), + ..Default::default() + }; + + assert!(ctx.has_scope("openid")); + assert!(ctx.has_scope("profile")); + assert!(ctx.has_scope("mcp:read")); + assert!(ctx.has_scope("mcp:admin")); + assert!(!ctx.has_scope("mcp:write")); + assert!(!ctx.has_scope("")); + } + + #[test] + fn test_auth_context_has_scope_empty() { + let ctx = AuthContext::default(); + assert!(!ctx.has_scope("openid")); + assert!(!ctx.has_scope("")); + } + + #[test] + fn test_auth_state_requires_auth_disabled() { + let config = AuthServerConfig { + auth_disabled: true, + ..Default::default() + }; + let state = AuthState::new(Some(Arc::new(GopherAuthClient::dummy())), config); + + // Nothing requires auth when auth is disabled + assert!(!state.requires_auth("/mcp")); + assert!(!state.requires_auth("/rpc")); + assert!(!state.requires_auth("/health")); + } + + #[test] + fn test_auth_state_requires_auth_no_client() { + let config = AuthServerConfig { + auth_disabled: false, + ..Default::default() + }; + let state = AuthState::new(None, config); + + // Nothing requires auth when no client + assert!(!state.requires_auth("/mcp")); + assert!(!state.requires_auth("/rpc")); + } + + #[test] + fn test_auth_state_requires_auth_public_paths() { + let config = AuthServerConfig { + auth_disabled: false, + ..Default::default() + }; + let state = AuthState::new(Some(Arc::new(GopherAuthClient::dummy())), config); + + // Public paths don't require auth + assert!(!state.requires_auth("/.well-known/oauth-protected-resource")); + assert!(!state.requires_auth("/.well-known/oauth-authorization-server")); + 
assert!(!state.requires_auth("/.well-known/openid-configuration")); + assert!(!state.requires_auth("/oauth/authorize")); + assert!(!state.requires_auth("/oauth/register")); + assert!(!state.requires_auth("/health")); + assert!(!state.requires_auth("/favicon.ico")); + } + + #[test] + fn test_auth_state_requires_auth_protected_paths() { + let config = AuthServerConfig { + auth_disabled: false, + ..Default::default() + }; + let state = AuthState::new(Some(Arc::new(GopherAuthClient::dummy())), config); + + // Protected paths require auth + assert!(state.requires_auth("/mcp")); + assert!(state.requires_auth("/mcp/messages")); + assert!(state.requires_auth("/rpc")); + assert!(state.requires_auth("/events")); + assert!(state.requires_auth("/sse")); + } + + #[test] + fn test_auth_state_requires_auth_unknown_paths() { + let config = AuthServerConfig { + auth_disabled: false, + ..Default::default() + }; + let state = AuthState::new(Some(Arc::new(GopherAuthClient::dummy())), config); + + // Unknown paths default to protected + assert!(state.requires_auth("/api/unknown")); + assert!(state.requires_auth("/foo/bar")); + } + + #[test] + fn test_extract_token_from_header() { + let request = Request::builder() + .uri("/mcp") + .header("authorization", "Bearer mytoken123") + .body(()) + .unwrap(); + + assert_eq!(extract_token(&request), Some("mytoken123".to_string())); + } + + #[test] + fn test_extract_token_from_header_lowercase() { + let request = Request::builder() + .uri("/mcp") + .header("authorization", "bearer mytoken456") + .body(()) + .unwrap(); + + assert_eq!(extract_token(&request), Some("mytoken456".to_string())); + } + + #[test] + fn test_extract_token_from_query() { + let request = Request::builder() + .uri("/mcp?access_token=querytoken789") + .body(()) + .unwrap(); + + assert_eq!(extract_token(&request), Some("querytoken789".to_string())); + } + + #[test] + fn test_extract_token_from_query_with_other_params() { + let request = Request::builder() + 
.uri("/mcp?foo=bar&access_token=querytoken&baz=qux") + .body(()) + .unwrap(); + + assert_eq!(extract_token(&request), Some("querytoken".to_string())); + } + + #[test] + fn test_extract_token_header_priority() { + let request = Request::builder() + .uri("/mcp?access_token=querytoken") + .header("authorization", "Bearer headertoken") + .body(()) + .unwrap(); + + // Header takes priority + assert_eq!(extract_token(&request), Some("headertoken".to_string())); + } + + #[test] + fn test_extract_token_missing() { + let request = Request::builder().uri("/mcp").body(()).unwrap(); + + assert_eq!(extract_token(&request), None); + } + + #[test] + fn test_extract_token_invalid_header() { + let request = Request::builder() + .uri("/mcp") + .header("authorization", "Basic dXNlcjpwYXNz") + .body(()) + .unwrap(); + + assert_eq!(extract_token(&request), None); + } + + #[test] + fn test_cors_preflight_response_status() { + let response = cors_preflight_response(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); + } + + #[test] + fn test_cors_preflight_response_headers() { + let response = cors_preflight_response(); + let headers = response.headers(); + + assert!(headers.get("Access-Control-Allow-Origin").is_some()); + assert!(headers.get("Access-Control-Allow-Methods").is_some()); + assert!(headers.get("Access-Control-Allow-Headers").is_some()); + assert!(headers.get("Access-Control-Expose-Headers").is_some()); + assert!(headers.get("Access-Control-Max-Age").is_some()); + } + + #[test] + fn test_unauthorized_response_status() { + let config = AuthServerConfig { + server_url: "http://localhost:3001".to_string(), + allowed_scopes: "openid mcp:read".to_string(), + ..Default::default() + }; + + let response = unauthorized_response(&config, "invalid_token", "Token expired"); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + #[test] + fn test_unauthorized_response_www_authenticate() { + let config = AuthServerConfig { + server_url: 
"http://localhost:3001".to_string(), + allowed_scopes: "openid mcp:read".to_string(), + ..Default::default() + }; + + let response = unauthorized_response(&config, "invalid_token", "Token expired"); + let www_auth = response.headers().get("WWW-Authenticate").unwrap().to_str().unwrap(); + + assert!(www_auth.contains("Bearer")); + assert!(www_auth.contains("realm=")); + assert!(www_auth.contains("resource_metadata=")); + assert!(www_auth.contains("error=\"invalid_token\"")); + assert!(www_auth.contains("error_description=\"Token expired\"")); + } + + #[test] + fn test_unauthorized_response_cors_headers() { + let config = AuthServerConfig::default(); + let response = unauthorized_response(&config, "invalid_request", "Missing token"); + + assert!(response.headers().get("Access-Control-Allow-Origin").is_some()); + assert!(response.headers().get("Access-Control-Expose-Headers").is_some()); + } + + #[test] + fn test_unauthorized_response_content_type() { + let config = AuthServerConfig::default(); + let response = unauthorized_response(&config, "invalid_request", "Missing token"); + + assert_eq!( + response.headers().get("Content-Type").unwrap(), + "application/json" + ); + } +} diff --git a/examples/auth/src/routes/health.rs b/examples/auth/src/routes/health.rs new file mode 100644 index 00000000..144fd12c --- /dev/null +++ b/examples/auth/src/routes/health.rs @@ -0,0 +1,128 @@ +//! Health check endpoint. +//! +//! Provides a simple health endpoint for monitoring server status. + +use axum::{extract::State, response::IntoResponse, Json}; +use serde::Serialize; +use std::sync::Arc; +use std::time::Instant; + +/// Health check response. +#[derive(Serialize)] +pub struct HealthResponse { + /// Server status (always "ok" when responding). + status: String, + /// ISO 8601 timestamp. + timestamp: String, + /// Optional server version. + #[serde(skip_serializing_if = "Option::is_none")] + version: Option, + /// Optional uptime in seconds. 
+ #[serde(skip_serializing_if = "Option::is_none")] + uptime: Option, +} + +/// Shared state for the health endpoint. +pub struct HealthState { + /// Server start time for uptime calculation. + start_time: Instant, + /// Optional version string. + version: Option, +} + +impl HealthState { + /// Create a new health state with the given version. + pub fn new(version: Option) -> Self { + Self { + start_time: Instant::now(), + version, + } + } +} + +/// Health endpoint handler. +/// +/// Returns JSON with server status, timestamp, optional version, and uptime. +pub async fn health_handler(State(state): State>) -> impl IntoResponse { + let uptime = state.start_time.elapsed().as_secs(); + + Json(HealthResponse { + status: "ok".to_string(), + timestamp: chrono::Utc::now().to_rfc3339(), + version: state.version.clone(), + uptime: Some(uptime), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use http_body_util::BodyExt; + + #[tokio::test] + async fn test_health_response_format() { + let state = Arc::new(HealthState::new(Some("1.0.0".to_string()))); + let response = health_handler(State(state)).await.into_response(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["status"], "ok"); + assert!(json["timestamp"].is_string()); + assert_eq!(json["version"], "1.0.0"); + assert!(json["uptime"].is_number()); + } + + #[tokio::test] + async fn test_health_response_without_version() { + let state = Arc::new(HealthState::new(None)); + let response = health_handler(State(state)).await.into_response(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["status"], "ok"); + // Version should be absent (not null) due to skip_serializing_if + assert!(json.get("version").is_none()); + } + + #[tokio::test] + async fn test_health_state_uptime() { + let state = 
HealthState::new(None); + + // Wait a tiny bit + std::thread::sleep(std::time::Duration::from_millis(10)); + + let elapsed = state.start_time.elapsed(); + assert!(elapsed.as_millis() >= 10); + } + + #[test] + fn test_health_response_serialization() { + let response = HealthResponse { + status: "ok".to_string(), + timestamp: "2024-01-01T00:00:00Z".to_string(), + version: Some("1.0.0".to_string()), + uptime: Some(100), + }; + + let json = serde_json::to_string(&response).unwrap(); + assert!(json.contains("\"status\":\"ok\"")); + assert!(json.contains("\"version\":\"1.0.0\"")); + assert!(json.contains("\"uptime\":100")); + } + + #[test] + fn test_health_response_omits_none_fields() { + let response = HealthResponse { + status: "ok".to_string(), + timestamp: "2024-01-01T00:00:00Z".to_string(), + version: None, + uptime: None, + }; + + let json = serde_json::to_string(&response).unwrap(); + assert!(!json.contains("version")); + assert!(!json.contains("uptime")); + } +} diff --git a/examples/auth/src/routes/mcp_handler.rs b/examples/auth/src/routes/mcp_handler.rs new file mode 100644 index 00000000..04f91ccc --- /dev/null +++ b/examples/auth/src/routes/mcp_handler.rs @@ -0,0 +1,524 @@ +//! MCP (Model Context Protocol) handler. +//! +//! Implements JSON-RPC 2.0 protocol for MCP tool execution. + +use std::collections::HashMap; +use std::sync::Arc; + +use axum::{ + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, + Extension, Json, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +use crate::cors::{options_handler as cors_options, with_cors_headers}; + +// Re-export AuthContext for backward compatibility with tools +pub use crate::middleware::AuthContext; + +/// JSON-RPC 2.0 error codes. +pub mod error_codes { + /// Parse error - Invalid JSON was received. + pub const PARSE_ERROR: i32 = -32700; + /// Invalid Request - The JSON sent is not a valid Request object. 
+ pub const INVALID_REQUEST: i32 = -32600; + /// Method not found - The method does not exist / is not available. + pub const METHOD_NOT_FOUND: i32 = -32601; + /// Invalid params - Invalid method parameter(s). + pub const INVALID_PARAMS: i32 = -32602; + /// Internal error - Internal JSON-RPC error. + pub const INTERNAL_ERROR: i32 = -32603; +} + +/// JSON-RPC 2.0 Request. +#[derive(Debug, Deserialize)] +pub struct JsonRpcRequest { + /// JSON-RPC version (must be "2.0"). + pub jsonrpc: String, + /// Request identifier (optional for notifications). + #[serde(default)] + pub id: Option, + /// Method name to invoke. + pub method: String, + /// Method parameters (optional). + #[serde(default)] + pub params: Option, +} + +/// JSON-RPC 2.0 Response. +#[derive(Debug, Serialize)] +pub struct JsonRpcResponse { + /// JSON-RPC version (always "2.0"). + pub jsonrpc: &'static str, + /// Request identifier (copied from request). + pub id: Option, + /// Result value (present on success). + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + /// Error value (present on failure). + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl JsonRpcResponse { + /// Create a success response. + pub fn success(id: Option, result: Value) -> Self { + Self { + jsonrpc: "2.0", + id, + result: Some(result), + error: None, + } + } + + /// Create an error response. + pub fn error(id: Option, code: i32, message: impl Into) -> Self { + Self { + jsonrpc: "2.0", + id, + result: None, + error: Some(JsonRpcError { + code, + message: message.into(), + data: None, + }), + } + } + + /// Create an error response with data. + pub fn error_with_data( + id: Option, + code: i32, + message: impl Into, + data: Value, + ) -> Self { + Self { + jsonrpc: "2.0", + id, + result: None, + error: Some(JsonRpcError { + code, + message: message.into(), + data: Some(data), + }), + } + } +} + +/// JSON-RPC 2.0 Error. 
+#[derive(Debug, Serialize)] +pub struct JsonRpcError { + /// Error code. + pub code: i32, + /// Error message. + pub message: String, + /// Additional error data (optional). + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +/// MCP Tool specification. +#[derive(Debug, Clone, Serialize)] +pub struct ToolSpec { + /// Tool name (unique identifier). + pub name: String, + /// Human-readable description. + pub description: String, + /// JSON Schema for input parameters. + #[serde(rename = "inputSchema")] + pub input_schema: Value, +} + +/// MCP Tool execution result. +#[derive(Debug, Serialize)] +pub struct ToolResult { + /// Result content items. + pub content: Vec, + /// Whether this result represents an error. + #[serde(rename = "isError", skip_serializing_if = "Option::is_none")] + pub is_error: Option, +} + +impl ToolResult { + /// Create a text result. + pub fn text(text: impl Into) -> Self { + Self { + content: vec![ToolContent::text(text)], + is_error: None, + } + } + + /// Create an error result. + pub fn error(message: impl Into) -> Self { + Self { + content: vec![ToolContent::text(message)], + is_error: Some(true), + } + } +} + +/// MCP Tool content item. +#[derive(Debug, Serialize)] +pub struct ToolContent { + /// Content type ("text", "image", "resource"). + #[serde(rename = "type")] + pub content_type: String, + /// Text content (for type "text"). + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + /// Base64-encoded data (for type "image"). + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + /// MIME type (for type "image"). + #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")] + pub mime_type: Option, +} + +impl ToolContent { + /// Create a text content item. + pub fn text(text: impl Into) -> Self { + Self { + content_type: "text".to_string(), + text: Some(text.into()), + data: None, + mime_type: None, + } + } + + /// Create an image content item. 
+ pub fn image(data: impl Into, mime_type: impl Into) -> Self { + Self { + content_type: "image".to_string(), + text: None, + data: Some(data.into()), + mime_type: Some(mime_type.into()), + } + } +} + +/// Tool handler function type. +pub type ToolHandler = Arc ToolResult + Send + Sync>; + +/// Server information for MCP initialize response. +#[derive(Debug, Serialize)] +struct ServerInfo { + name: &'static str, + version: &'static str, +} + +/// MCP Handler for JSON-RPC 2.0 requests. +pub struct McpHandler { + /// Registered tools: name -> (spec, handler) + tools: HashMap, + /// Server information + server_info: ServerInfo, +} + +impl McpHandler { + /// Create a new MCP handler. + pub fn new() -> Self { + Self { + tools: HashMap::new(), + server_info: ServerInfo { + name: "rust-auth-mcp-server", + version: "1.0.0", + }, + } + } + + /// Handle a JSON-RPC request. + pub fn handle_request(&self, body: Value, auth_context: &AuthContext) -> JsonRpcResponse { + // Parse request + let request: JsonRpcRequest = match serde_json::from_value(body) { + Ok(r) => r, + Err(_) => { + return JsonRpcResponse::error( + None, + error_codes::INVALID_REQUEST, + "Invalid request: expected JSON-RPC object", + ); + } + }; + + // Validate jsonrpc version + if request.jsonrpc != "2.0" { + return JsonRpcResponse::error( + request.id, + error_codes::INVALID_REQUEST, + "Invalid request: jsonrpc must be \"2.0\"", + ); + } + + // Dispatch method + match request.method.as_str() { + "initialize" => self.handle_initialize(request.id), + "tools/list" => self.handle_tools_list(request.id), + "tools/call" => self.handle_tools_call(request.id, request.params, auth_context), + "ping" => self.handle_ping(request.id), + _ => JsonRpcResponse::error( + request.id, + error_codes::METHOD_NOT_FOUND, + format!("Method not found: {}", request.method), + ), + } + } + + /// Handle initialize method. 
+ fn handle_initialize(&self, id: Option) -> JsonRpcResponse { + JsonRpcResponse::success( + id, + json!({ + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "serverInfo": self.server_info + }), + ) + } + + /// Handle tools/list method. + fn handle_tools_list(&self, id: Option) -> JsonRpcResponse { + let tools: Vec<&ToolSpec> = self.tools.values().map(|(spec, _)| spec).collect(); + JsonRpcResponse::success(id, json!({ "tools": tools })) + } + + /// Handle tools/call method. + fn handle_tools_call( + &self, + id: Option, + params: Option, + auth_context: &AuthContext, + ) -> JsonRpcResponse { + let params = match params { + Some(p) => p, + None => { + return JsonRpcResponse::error(id, error_codes::INVALID_PARAMS, "Missing params"); + } + }; + + let name = match params.get("name").and_then(|v| v.as_str()) { + Some(n) => n, + None => { + return JsonRpcResponse::error( + id, + error_codes::INVALID_PARAMS, + "Invalid params: name must be a string", + ); + } + }; + + let arguments = params.get("arguments").cloned().unwrap_or(json!({})); + + let (_, handler) = match self.tools.get(name) { + Some(t) => t, + None => { + return JsonRpcResponse::error( + id, + error_codes::METHOD_NOT_FOUND, + format!("Tool not found: {}", name), + ); + } + }; + + let result = handler(arguments, auth_context); + JsonRpcResponse::success(id, serde_json::to_value(result).unwrap_or(json!(null))) + } + + /// Handle ping method. + fn handle_ping(&self, id: Option) -> JsonRpcResponse { + JsonRpcResponse::success(id, json!({})) + } + + /// Register a tool with the handler. + pub fn register_tool(&mut self, name: &str, spec: ToolSpec, handler: F) + where + F: Fn(Value, &AuthContext) -> ToolResult + Send + Sync + 'static, + { + self.tools.insert(name.to_string(), (spec, Arc::new(handler))); + } +} + +impl Default for McpHandler { + fn default() -> Self { + Self::new() + } +} + +/// MCP POST endpoint handler. +/// +/// Handles JSON-RPC requests at /mcp and /rpc endpoints. 
+pub async fn mcp_handler( + State(handler): State>, + Extension(auth_context): Extension, + Json(body): Json, +) -> Response { + let response = handler.handle_request(body, &auth_context); + with_cors_headers(Json(response)) +} + +/// MCP OPTIONS endpoint handler for CORS preflight. +pub async fn mcp_options() -> impl IntoResponse { + cors_options().await +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_error_codes() { + assert_eq!(error_codes::PARSE_ERROR, -32700); + assert_eq!(error_codes::INVALID_REQUEST, -32600); + assert_eq!(error_codes::METHOD_NOT_FOUND, -32601); + assert_eq!(error_codes::INVALID_PARAMS, -32602); + assert_eq!(error_codes::INTERNAL_ERROR, -32603); + } + + #[test] + fn test_jsonrpc_request_deserialization() { + let json = r#"{ + "jsonrpc": "2.0", + "id": 1, + "method": "test", + "params": {"key": "value"} + }"#; + + let request: JsonRpcRequest = serde_json::from_str(json).unwrap(); + + assert_eq!(request.jsonrpc, "2.0"); + assert_eq!(request.id, Some(json!(1))); + assert_eq!(request.method, "test"); + assert!(request.params.is_some()); + } + + #[test] + fn test_jsonrpc_request_minimal() { + let json = r#"{"jsonrpc": "2.0", "method": "ping"}"#; + + let request: JsonRpcRequest = serde_json::from_str(json).unwrap(); + + assert_eq!(request.jsonrpc, "2.0"); + assert!(request.id.is_none()); + assert_eq!(request.method, "ping"); + assert!(request.params.is_none()); + } + + #[test] + fn test_jsonrpc_response_success() { + let response = JsonRpcResponse::success(Some(json!(1)), json!({"status": "ok"})); + let json = serde_json::to_value(&response).unwrap(); + + assert_eq!(json["jsonrpc"], "2.0"); + assert_eq!(json["id"], 1); + assert_eq!(json["result"]["status"], "ok"); + assert!(json.get("error").is_none()); + } + + #[test] + fn test_jsonrpc_response_error() { + let response = JsonRpcResponse::error( + Some(json!(1)), + error_codes::METHOD_NOT_FOUND, + "Method not found", + ); + let json = 
serde_json::to_value(&response).unwrap(); + + assert_eq!(json["jsonrpc"], "2.0"); + assert_eq!(json["id"], 1); + assert!(json.get("result").is_none()); + assert_eq!(json["error"]["code"], -32601); + assert_eq!(json["error"]["message"], "Method not found"); + } + + #[test] + fn test_jsonrpc_error_serialization() { + let error = JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: Some(json!({"details": "missing field"})), + }; + let json = serde_json::to_value(&error).unwrap(); + + assert_eq!(json["code"], -32600); + assert_eq!(json["message"], "Invalid Request"); + assert_eq!(json["data"]["details"], "missing field"); + } + + #[test] + fn test_jsonrpc_error_omits_none_data() { + let error = JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: None, + }; + let json = serde_json::to_string(&error).unwrap(); + + assert!(!json.contains("data")); + } + + #[test] + fn test_tool_spec_serialization() { + let spec = ToolSpec { + name: "test-tool".to_string(), + description: "A test tool".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "input": {"type": "string"} + } + }), + }; + let json = serde_json::to_value(&spec).unwrap(); + + assert_eq!(json["name"], "test-tool"); + assert_eq!(json["description"], "A test tool"); + assert_eq!(json["inputSchema"]["type"], "object"); + } + + #[test] + fn test_tool_result_text() { + let result = ToolResult::text("Hello, world!"); + let json = serde_json::to_value(&result).unwrap(); + + assert_eq!(json["content"][0]["type"], "text"); + assert_eq!(json["content"][0]["text"], "Hello, world!"); + assert!(json.get("isError").is_none()); + } + + #[test] + fn test_tool_result_error() { + let result = ToolResult::error("Something went wrong"); + let json = serde_json::to_value(&result).unwrap(); + + assert_eq!(json["content"][0]["type"], "text"); + assert_eq!(json["content"][0]["text"], "Something went wrong"); + assert_eq!(json["isError"], true); + } + + #[test] 
+ fn test_tool_content_text() { + let content = ToolContent::text("Hello"); + let json = serde_json::to_value(&content).unwrap(); + + assert_eq!(json["type"], "text"); + assert_eq!(json["text"], "Hello"); + assert!(json.get("data").is_none()); + assert!(json.get("mimeType").is_none()); + } + + #[test] + fn test_tool_content_image() { + let content = ToolContent::image("base64data", "image/png"); + let json = serde_json::to_value(&content).unwrap(); + + assert_eq!(json["type"], "image"); + assert!(json.get("text").is_none()); + assert_eq!(json["data"], "base64data"); + assert_eq!(json["mimeType"], "image/png"); + } +} diff --git a/examples/auth/src/routes/mod.rs b/examples/auth/src/routes/mod.rs new file mode 100644 index 00000000..d10a002a --- /dev/null +++ b/examples/auth/src/routes/mod.rs @@ -0,0 +1,7 @@ +//! HTTP route handlers module. +//! +//! Contains handlers for health, OAuth discovery, and MCP endpoints. + +pub mod health; +pub mod mcp_handler; +pub mod oauth_endpoints; diff --git a/examples/auth/src/routes/oauth_endpoints.rs b/examples/auth/src/routes/oauth_endpoints.rs new file mode 100644 index 00000000..354d262b --- /dev/null +++ b/examples/auth/src/routes/oauth_endpoints.rs @@ -0,0 +1,566 @@ +//! OAuth discovery endpoints. +//! +//! Implements OAuth 2.0 and OpenID Connect discovery endpoints per RFC specifications: +//! - RFC 9728: Protected Resource Metadata +//! - RFC 8414: Authorization Server Metadata +//! - OpenID Connect Discovery 1.0 +//! - RFC 7591: Dynamic Client Registration + +use std::collections::HashMap; +use std::sync::Arc; + +use axum::{ + extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Redirect, Response}, + Json, +}; +use serde::{Deserialize, Serialize}; + +use crate::config::AuthServerConfig; +use crate::cors::with_cors_headers; + +/// RFC 9728: Protected Resource Metadata. +/// +/// Describes the OAuth 2.0 protected resource and its requirements. 
+#[derive(Debug, Clone, Serialize)] +pub struct ProtectedResourceMetadata { + /// The protected resource identifier (URL). + pub resource: String, + /// List of authorization server URLs. + pub authorization_servers: Vec, + /// Supported OAuth scopes. + #[serde(skip_serializing_if = "Option::is_none")] + pub scopes_supported: Option>, + /// Supported bearer token methods. + #[serde(skip_serializing_if = "Option::is_none")] + pub bearer_methods_supported: Option>, + /// URL to resource documentation. + #[serde(skip_serializing_if = "Option::is_none")] + pub resource_documentation: Option, +} + +/// RFC 8414: Authorization Server Metadata. +/// +/// Describes the OAuth 2.0 authorization server configuration. +#[derive(Debug, Clone, Serialize)] +pub struct AuthorizationServerMetadata { + /// Authorization server issuer identifier. + pub issuer: String, + /// URL of the authorization endpoint. + pub authorization_endpoint: String, + /// URL of the token endpoint. + pub token_endpoint: String, + /// URL of the JWKS endpoint. + #[serde(skip_serializing_if = "Option::is_none")] + pub jwks_uri: Option, + /// URL of the dynamic client registration endpoint. + #[serde(skip_serializing_if = "Option::is_none")] + pub registration_endpoint: Option, + /// Supported OAuth scopes. + #[serde(skip_serializing_if = "Option::is_none")] + pub scopes_supported: Option>, + /// Supported response types. + pub response_types_supported: Vec, + /// Supported grant types. + #[serde(skip_serializing_if = "Option::is_none")] + pub grant_types_supported: Option>, + /// Supported token endpoint authentication methods. + #[serde(skip_serializing_if = "Option::is_none")] + pub token_endpoint_auth_methods_supported: Option>, + /// Supported PKCE code challenge methods. + #[serde(skip_serializing_if = "Option::is_none")] + pub code_challenge_methods_supported: Option>, +} + +/// OpenID Connect Discovery 1.0 Configuration. +/// +/// Extends RFC 8414 with OIDC-specific fields. 
+#[derive(Debug, Clone, Serialize)] +pub struct OpenIDConfiguration { + /// Base authorization server metadata. + #[serde(flatten)] + pub base: AuthorizationServerMetadata, + /// URL of the userinfo endpoint. + #[serde(skip_serializing_if = "Option::is_none")] + pub userinfo_endpoint: Option, + /// Supported subject identifier types. + #[serde(skip_serializing_if = "Option::is_none")] + pub subject_types_supported: Option>, + /// Supported ID token signing algorithms. + #[serde(skip_serializing_if = "Option::is_none")] + pub id_token_signing_alg_values_supported: Option>, +} + +/// RFC 7591: Client Registration Response. +/// +/// Response returned from dynamic client registration endpoint. +#[derive(Debug, Clone, Serialize)] +pub struct ClientRegistrationResponse { + /// Assigned client identifier. + pub client_id: String, + /// Assigned client secret (if confidential client). + #[serde(skip_serializing_if = "Option::is_none")] + pub client_secret: Option, + /// Unix timestamp when client_id was issued. + pub client_id_issued_at: u64, + /// Unix timestamp when client_secret expires (0 = never). + pub client_secret_expires_at: u64, + /// Registered redirect URIs. + pub redirect_uris: Vec, + /// Supported grant types. + pub grant_types: Vec, + /// Supported response types. + pub response_types: Vec, + /// Token endpoint authentication method. + pub token_endpoint_auth_method: String, +} + +/// Client registration request body. +#[derive(Debug, Clone, Deserialize)] +pub struct ClientRegistrationRequest { + /// Requested redirect URIs. + #[serde(default)] + pub redirect_uris: Vec, +} + +/// Parse scopes from a space-separated string. +fn parse_scopes(scopes: &str) -> Vec { + scopes + .split_whitespace() + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect() +} + +/// Protected resource metadata endpoint handler. +/// +/// Returns RFC 9728 compliant metadata describing this MCP resource. 
+/// Serves both `/.well-known/oauth-protected-resource` and +/// `/.well-known/oauth-protected-resource/mcp`. +pub async fn protected_resource_metadata( + State(config): State>, +) -> impl IntoResponse { + let scopes = parse_scopes(&config.allowed_scopes); + + let metadata = ProtectedResourceMetadata { + resource: format!("{}/mcp", config.server_url), + authorization_servers: vec![config.server_url.clone()], + scopes_supported: if scopes.is_empty() { + None + } else { + Some(scopes) + }, + bearer_methods_supported: Some(vec!["header".to_string(), "query".to_string()]), + resource_documentation: Some(format!("{}/docs", config.server_url)), + }; + + with_cors_headers(Json(metadata)) +} + +/// Authorization server metadata endpoint handler. +/// +/// Returns RFC 8414 compliant metadata describing the authorization server. +/// Serves `/.well-known/oauth-authorization-server`. +pub async fn authorization_server_metadata( + State(config): State>, +) -> impl IntoResponse { + let scopes = parse_scopes(&config.allowed_scopes); + + // Use issuer from config, or fall back to server_url + let issuer = if config.issuer.is_empty() { + config.server_url.clone() + } else { + config.issuer.clone() + }; + + // Use configured OAuth URLs or fall back to auth_server_url derived URLs + let authorization_endpoint = if !config.oauth_authorize_url.is_empty() { + config.oauth_authorize_url.clone() + } else if !config.auth_server_url.is_empty() { + format!("{}/protocol/openid-connect/auth", config.auth_server_url) + } else { + format!("{}/oauth/authorize", config.server_url) + }; + + let token_endpoint = if !config.oauth_token_url.is_empty() { + config.oauth_token_url.clone() + } else if !config.auth_server_url.is_empty() { + format!("{}/protocol/openid-connect/token", config.auth_server_url) + } else { + format!("{}/oauth/token", config.server_url) + }; + + let metadata = AuthorizationServerMetadata { + issuer, + authorization_endpoint, + token_endpoint, + jwks_uri: if 
config.jwks_uri.is_empty() { + None + } else { + Some(config.jwks_uri.clone()) + }, + registration_endpoint: Some(format!("{}/oauth/register", config.server_url)), + scopes_supported: if scopes.is_empty() { + None + } else { + Some(scopes) + }, + response_types_supported: vec!["code".to_string()], + grant_types_supported: Some(vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ]), + token_endpoint_auth_methods_supported: Some(vec![ + "client_secret_basic".to_string(), + "client_secret_post".to_string(), + "none".to_string(), + ]), + code_challenge_methods_supported: Some(vec!["S256".to_string()]), + }; + + with_cors_headers(Json(metadata)) +} + +/// Merge base OIDC scopes with configured scopes, deduplicating. +fn merge_oidc_scopes(configured_scopes: &[String]) -> Vec { + let base_scopes = ["openid", "profile", "email"]; + let mut scopes: Vec = base_scopes.iter().map(|s| s.to_string()).collect(); + + for scope in configured_scopes { + if !scopes.contains(scope) { + scopes.push(scope.clone()); + } + } + + scopes +} + +/// OpenID Connect discovery endpoint handler. +/// +/// Returns OIDC Discovery 1.0 compliant configuration extending RFC 8414. +/// Serves `/.well-known/openid-configuration`. 
+pub async fn openid_configuration( + State(config): State>, +) -> impl IntoResponse { + let configured_scopes = parse_scopes(&config.allowed_scopes); + let scopes = merge_oidc_scopes(&configured_scopes); + + // Use issuer from config, or fall back to server_url + let issuer = if config.issuer.is_empty() { + config.server_url.clone() + } else { + config.issuer.clone() + }; + + // Use configured OAuth URLs or fall back to auth_server_url derived URLs + let authorization_endpoint = if !config.oauth_authorize_url.is_empty() { + config.oauth_authorize_url.clone() + } else if !config.auth_server_url.is_empty() { + format!("{}/protocol/openid-connect/auth", config.auth_server_url) + } else { + format!("{}/oauth/authorize", config.server_url) + }; + + let token_endpoint = if !config.oauth_token_url.is_empty() { + config.oauth_token_url.clone() + } else if !config.auth_server_url.is_empty() { + format!("{}/protocol/openid-connect/token", config.auth_server_url) + } else { + format!("{}/oauth/token", config.server_url) + }; + + // Userinfo endpoint (only if auth_server_url is configured) + let userinfo_endpoint = if !config.auth_server_url.is_empty() { + Some(format!( + "{}/protocol/openid-connect/userinfo", + config.auth_server_url + )) + } else { + None + }; + + let base = AuthorizationServerMetadata { + issuer, + authorization_endpoint, + token_endpoint, + jwks_uri: if config.jwks_uri.is_empty() { + None + } else { + Some(config.jwks_uri.clone()) + }, + registration_endpoint: Some(format!("{}/oauth/register", config.server_url)), + scopes_supported: Some(scopes), + response_types_supported: vec!["code".to_string()], + grant_types_supported: Some(vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ]), + token_endpoint_auth_methods_supported: Some(vec![ + "client_secret_basic".to_string(), + "client_secret_post".to_string(), + "none".to_string(), + ]), + code_challenge_methods_supported: Some(vec!["S256".to_string()]), + }; + + let oidc_config = 
OpenIDConfiguration { + base, + userinfo_endpoint, + subject_types_supported: Some(vec!["public".to_string()]), + id_token_signing_alg_values_supported: Some(vec!["RS256".to_string()]), + }; + + with_cors_headers(Json(oidc_config)) +} + +/// Query parameters for OAuth authorize endpoint. +#[derive(Debug, Deserialize)] +pub struct AuthorizeParams { + /// Capture all query parameters to forward. + #[serde(flatten)] + pub params: HashMap, +} + +/// OAuth authorize redirect endpoint handler. +/// +/// Redirects to the configured authorization server with all query parameters. +/// Serves `/oauth/authorize`. +pub async fn oauth_authorize( + State(config): State>, + Query(params): Query, +) -> Response { + // Determine authorization endpoint + let auth_endpoint = if !config.oauth_authorize_url.is_empty() { + &config.oauth_authorize_url + } else if !config.auth_server_url.is_empty() { + // Will construct below + &format!("{}/protocol/openid-connect/auth", config.auth_server_url) + } else { + // Fallback error - no authorization endpoint configured + return with_cors_headers(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "server_error", + "error_description": "No authorization endpoint configured" + })), + )); + }; + + // Parse and build redirect URL + let mut url = match url::Url::parse(auth_endpoint) { + Ok(u) => u, + Err(_) => { + return with_cors_headers(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "server_error", + "error_description": "Failed to parse authorization URL" + })), + )); + } + }; + + // Forward all query parameters + for (key, value) in ¶ms.params { + url.query_pairs_mut().append_pair(key, value); + } + + // Return redirect with CORS headers + with_cors_headers(Redirect::to(url.as_str())) +} + +/// OAuth dynamic client registration endpoint handler. +/// +/// Returns client credentials (stateless - uses configured values). +/// Serves `POST /oauth/register`. 
+pub async fn oauth_register(
+    State(config): State>,
+    Json(body): Json<ClientRegistrationRequest>,
+) -> Response {
+    let now = std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap_or_default()
+        .as_secs();
+
+    // Determine auth method based on whether secret is configured
+    let (client_secret, token_endpoint_auth_method) = if config.client_secret.is_empty() {
+        (None, "none".to_string())
+    } else {
+        (
+            Some(config.client_secret.clone()),
+            "client_secret_post".to_string(),
+        )
+    };
+
+    let response = ClientRegistrationResponse {
+        client_id: config.client_id.clone(),
+        client_secret,
+        client_id_issued_at: now,
+        client_secret_expires_at: 0, // Never expires
+        redirect_uris: body.redirect_uris,
+        grant_types: vec![
+            "authorization_code".to_string(),
+            "refresh_token".to_string(),
+        ],
+        response_types: vec!["code".to_string()],
+        token_endpoint_auth_method,
+    };
+
+    with_cors_headers((StatusCode::CREATED, Json(response)))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_protected_resource_metadata_serialization() {
+        let metadata = ProtectedResourceMetadata {
+            resource: "https://example.com/mcp".to_string(),
+            authorization_servers: vec!["https://example.com".to_string()],
+            scopes_supported: Some(vec!["mcp:read".to_string(), "mcp:admin".to_string()]),
+            bearer_methods_supported: Some(vec!["header".to_string(), "query".to_string()]),
+            resource_documentation: Some("https://example.com/docs".to_string()),
+        };
+
+        let json = serde_json::to_value(&metadata).unwrap();
+
+        assert_eq!(json["resource"], "https://example.com/mcp");
+        assert_eq!(json["authorization_servers"][0], "https://example.com");
+        assert_eq!(json["scopes_supported"][0], "mcp:read");
+        assert_eq!(json["bearer_methods_supported"][0], "header");
+    }
+
+    #[test]
+    fn test_protected_resource_metadata_omits_none() {
+        let metadata = ProtectedResourceMetadata {
+            resource: "https://example.com/mcp".to_string(),
+            authorization_servers: vec!["https://example.com".to_string()],
+ scopes_supported: None, + bearer_methods_supported: None, + resource_documentation: None, + }; + + let json = serde_json::to_string(&metadata).unwrap(); + + assert!(!json.contains("scopes_supported")); + assert!(!json.contains("bearer_methods_supported")); + assert!(!json.contains("resource_documentation")); + } + + #[test] + fn test_authorization_server_metadata_serialization() { + let metadata = AuthorizationServerMetadata { + issuer: "https://auth.example.com".to_string(), + authorization_endpoint: "https://auth.example.com/authorize".to_string(), + token_endpoint: "https://auth.example.com/token".to_string(), + jwks_uri: Some("https://auth.example.com/jwks".to_string()), + registration_endpoint: Some("https://example.com/oauth/register".to_string()), + scopes_supported: Some(vec!["openid".to_string(), "profile".to_string()]), + response_types_supported: vec!["code".to_string()], + grant_types_supported: Some(vec!["authorization_code".to_string()]), + token_endpoint_auth_methods_supported: Some(vec!["client_secret_post".to_string()]), + code_challenge_methods_supported: Some(vec!["S256".to_string()]), + }; + + let json = serde_json::to_value(&metadata).unwrap(); + + assert_eq!(json["issuer"], "https://auth.example.com"); + assert_eq!(json["response_types_supported"][0], "code"); + assert_eq!(json["code_challenge_methods_supported"][0], "S256"); + } + + #[test] + fn test_openid_configuration_flattens_base() { + let base = AuthorizationServerMetadata { + issuer: "https://auth.example.com".to_string(), + authorization_endpoint: "https://auth.example.com/authorize".to_string(), + token_endpoint: "https://auth.example.com/token".to_string(), + jwks_uri: None, + registration_endpoint: None, + scopes_supported: None, + response_types_supported: vec!["code".to_string()], + grant_types_supported: None, + token_endpoint_auth_methods_supported: None, + code_challenge_methods_supported: None, + }; + + let oidc = OpenIDConfiguration { + base, + userinfo_endpoint: 
Some("https://auth.example.com/userinfo".to_string()), + subject_types_supported: Some(vec!["public".to_string()]), + id_token_signing_alg_values_supported: Some(vec!["RS256".to_string()]), + }; + + let json = serde_json::to_value(&oidc).unwrap(); + + // Base fields should be flattened + assert_eq!(json["issuer"], "https://auth.example.com"); + assert_eq!(json["authorization_endpoint"], "https://auth.example.com/authorize"); + // OIDC-specific fields + assert_eq!(json["userinfo_endpoint"], "https://auth.example.com/userinfo"); + assert_eq!(json["subject_types_supported"][0], "public"); + assert_eq!(json["id_token_signing_alg_values_supported"][0], "RS256"); + } + + #[test] + fn test_client_registration_response_serialization() { + let response = ClientRegistrationResponse { + client_id: "test-client".to_string(), + client_secret: Some("secret".to_string()), + client_id_issued_at: 1704067200, + client_secret_expires_at: 0, + redirect_uris: vec!["https://example.com/callback".to_string()], + grant_types: vec!["authorization_code".to_string(), "refresh_token".to_string()], + response_types: vec!["code".to_string()], + token_endpoint_auth_method: "client_secret_post".to_string(), + }; + + let json = serde_json::to_value(&response).unwrap(); + + assert_eq!(json["client_id"], "test-client"); + assert_eq!(json["client_secret"], "secret"); + assert_eq!(json["client_id_issued_at"], 1704067200); + assert_eq!(json["client_secret_expires_at"], 0); + assert_eq!(json["grant_types"][0], "authorization_code"); + } + + #[test] + fn test_client_registration_response_omits_none_secret() { + let response = ClientRegistrationResponse { + client_id: "public-client".to_string(), + client_secret: None, + client_id_issued_at: 1704067200, + client_secret_expires_at: 0, + redirect_uris: vec![], + grant_types: vec!["authorization_code".to_string()], + response_types: vec!["code".to_string()], + token_endpoint_auth_method: "none".to_string(), + }; + + let json = 
serde_json::to_value(&response).unwrap(); + + // client_secret field should be absent (not null) + assert!(json.get("client_secret").is_none()); + // But client_secret_expires_at should still be present + assert!(json.get("client_secret_expires_at").is_some()); + } + + #[test] + fn test_client_registration_request_deserialization() { + let json = r#"{"redirect_uris": ["https://example.com/callback"]}"#; + let request: ClientRegistrationRequest = serde_json::from_str(json).unwrap(); + + assert_eq!(request.redirect_uris.len(), 1); + assert_eq!(request.redirect_uris[0], "https://example.com/callback"); + } + + #[test] + fn test_client_registration_request_empty() { + let json = r#"{}"#; + let request: ClientRegistrationRequest = serde_json::from_str(json).unwrap(); + + assert!(request.redirect_uris.is_empty()); + } +} diff --git a/examples/auth/src/tools/mod.rs b/examples/auth/src/tools/mod.rs new file mode 100644 index 00000000..c6ca03ef --- /dev/null +++ b/examples/auth/src/tools/mod.rs @@ -0,0 +1,5 @@ +//! MCP tools module. +//! +//! Contains example tools with scope-based access control. + +pub mod weather_tools; diff --git a/examples/auth/src/tools/weather_tools.rs b/examples/auth/src/tools/weather_tools.rs new file mode 100644 index 00000000..0357c8d9 --- /dev/null +++ b/examples/auth/src/tools/weather_tools.rs @@ -0,0 +1,555 @@ +//! Weather Tools +//! +//! Example MCP tools demonstrating OAuth scope-based access control. +//! Mirrors the weather tools from the TypeScript and C++ auth examples. + +use serde_json::{json, Value}; + +use crate::routes::mcp_handler::{AuthContext, McpHandler, ToolContent, ToolResult, ToolSpec}; + +/// Weather conditions for simulation. +const CONDITIONS: [&str; 6] = [ + "Sunny", + "Cloudy", + "Rainy", + "Partly Cloudy", + "Windy", + "Stormy", +]; + +/// Day names for forecast. +const DAYS: [&str; 5] = ["Today", "Tomorrow", "Day 3", "Day 4", "Day 5"]; + +/// Check if a scope is present in a space-separated scope string. 
+/// +/// # Arguments +/// +/// * `scopes` - Space-separated scope string +/// * `required` - Required scope to check for +/// +/// # Returns +/// +/// true if the required scope is present +pub fn has_scope(scopes: &str, required: &str) -> bool { + if scopes.is_empty() || required.is_empty() { + return false; + } + scopes.split_whitespace().any(|s| s == required) +} + +/// Calculate a simple hash of a city name. +/// +/// # Arguments +/// +/// * `city` - City name +/// +/// # Returns +/// +/// Sum of character byte values +pub fn city_hash(city: &str) -> usize { + city.bytes().map(|b| b as usize).sum() +} + +/// Get a deterministic but varying condition based on city name. +/// +/// # Arguments +/// +/// * `city` - City name +/// * `offset` - Offset for variation (e.g., day number) +/// +/// # Returns +/// +/// Weather condition string +pub fn get_condition(city: &str, offset: usize) -> &'static str { + let hash = city_hash(city); + CONDITIONS[(hash + offset) % CONDITIONS.len()] +} + +/// Get a deterministic but varying temperature based on city name. +/// +/// # Arguments +/// +/// * `city` - City name +/// * `offset` - Offset for variation (e.g., day number) +/// +/// # Returns +/// +/// Temperature in Celsius (10-35°C range) +pub fn get_temp(city: &str, offset: usize) -> i32 { + let hash = city_hash(city); + 10 + ((hash + offset * 7) % 26) as i32 +} + +/// Create an access denied error result. +/// +/// # Arguments +/// +/// * `scope` - Required scope that was missing +/// +/// # Returns +/// +/// ToolResult with error +pub fn access_denied(scope: &str) -> ToolResult { + ToolResult { + content: vec![ToolContent::text( + serde_json::to_string(&json!({ + "error": "access_denied", + "message": format!("Access denied. Required scope: {}", scope), + })) + .unwrap_or_else(|_| "{}".to_string()), + )], + is_error: Some(true), + } +} + +/// Get simulated current weather for a city. 
+fn get_simulated_weather(city: &str) -> Value {
+    let hash = city_hash(city);
+
+    json!({
+        "city": city,
+        "temperature": get_temp(city, 0),
+        "condition": get_condition(city, 0),
+        "humidity": 40 + (hash % 40), // 40-80%
+        "windSpeed": 5 + (hash % 25), // 5-30 km/h
+    })
+}
+
+/// Get simulated 5-day forecast for a city.
+fn get_simulated_forecast(city: &str) -> Value {
+    let forecast: Vec<Value> = DAYS
+        .iter()
+        .enumerate()
+        .map(|(index, day)| {
+            json!({
+                "day": day,
+                "high": get_temp(city, index) + 5,
+                "low": get_temp(city, index) - 5,
+                "condition": get_condition(city, index),
+            })
+        })
+        .collect();
+
+    json!({
+        "city": city,
+        "forecast": forecast,
+    })
+}
+
+/// Get simulated weather alerts for a region.
+fn get_simulated_alerts(region: &str) -> Value {
+    let hash = city_hash(region);
+
+    let alerts: Vec<Value> = if hash % 3 == 0 {
+        vec![json!({
+            "type": "Heat Warning",
+            "severity": "moderate",
+            "message": format!("High temperatures expected in {}. Stay hydrated.", region),
+        })]
+    } else if hash % 3 == 1 {
+        vec![
+            json!({
+                "type": "Storm Watch",
+                "severity": "high",
+                "message": format!("Severe thunderstorms possible in {}. Seek shelter if needed.", region),
+            }),
+            json!({
+                "type": "Wind Advisory",
+                "severity": "low",
+                "message": format!("Strong winds expected in {}. Secure loose objects.", region),
+            }),
+        ]
+    } else {
+        vec![] // No alerts
+    };
+
+    json!({
+        "region": region,
+        "alerts": alerts,
+    })
+}
+
+/// Register weather tools with the MCP handler.
+///
+/// # Arguments
+///
+/// * `mcp` - MCP handler instance
+/// * `auth_disabled` - Whether authentication is disabled
+pub fn register_weather_tools(mcp: &mut McpHandler, auth_disabled: bool) {
+    // get-weather - No authentication required
+    // Returns current weather for a specified city.
+    mcp.register_tool(
+        "get-weather",
+        ToolSpec {
+            name: "get-weather".to_string(),
+            description: "Get current weather for a city.
No authentication required.".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name to get weather for" + } + }, + "required": ["city"] + }), + }, + |args, _auth_context| { + let city = args + .get("city") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown"); + + let weather = get_simulated_weather(city); + ToolResult::text(serde_json::to_string_pretty(&weather).unwrap_or_default()) + }, + ); + + // get-forecast - Requires mcp:read scope + // Returns 5-day weather forecast for a specified city. + let auth_disabled_forecast = auth_disabled; + mcp.register_tool( + "get-forecast", + ToolSpec { + name: "get-forecast".to_string(), + description: "Get 5-day weather forecast for a city. Requires mcp:read scope." + .to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name to get forecast for" + } + }, + "required": ["city"] + }), + }, + move |args, auth_context| { + // Check scope if auth is enabled + if !auth_disabled_forecast && !has_scope(&auth_context.scopes, "mcp:read") { + return access_denied("mcp:read"); + } + + let city = args + .get("city") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown"); + + let forecast = get_simulated_forecast(city); + ToolResult::text(serde_json::to_string_pretty(&forecast).unwrap_or_default()) + }, + ); + + // get-weather-alerts - Requires mcp:admin scope + // Returns weather alerts for a specified region. + let auth_disabled_alerts = auth_disabled; + mcp.register_tool( + "get-weather-alerts", + ToolSpec { + name: "get-weather-alerts".to_string(), + description: "Get weather alerts for a region. 
Requires mcp:admin scope.".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "region": { + "type": "string", + "description": "Region name to get alerts for" + } + }, + "required": ["region"] + }), + }, + move |args, auth_context| { + // Check scope if auth is enabled + if !auth_disabled_alerts && !has_scope(&auth_context.scopes, "mcp:admin") { + return access_denied("mcp:admin"); + } + + let region = args + .get("region") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown"); + + let alerts = get_simulated_alerts(region); + ToolResult::text(serde_json::to_string_pretty(&alerts).unwrap_or_default()) + }, + ); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_has_scope_present() { + assert!(has_scope("openid profile mcp:read", "mcp:read")); + assert!(has_scope("openid profile mcp:read", "openid")); + assert!(has_scope("openid profile mcp:read", "profile")); + } + + #[test] + fn test_has_scope_absent() { + assert!(!has_scope("openid profile", "mcp:read")); + assert!(!has_scope("openid profile mcp:read", "mcp:admin")); + } + + #[test] + fn test_has_scope_empty() { + assert!(!has_scope("", "mcp:read")); + assert!(!has_scope("openid profile", "")); + assert!(!has_scope("", "")); + } + + #[test] + fn test_city_hash() { + // "NYC" = 78 + 89 + 67 = 234 + assert_eq!(city_hash("NYC"), 234); + // Hash should be consistent + assert_eq!(city_hash("NYC"), city_hash("NYC")); + // Different cities have different hashes + assert_ne!(city_hash("NYC"), city_hash("LA")); + } + + #[test] + fn test_get_condition_deterministic() { + let condition1 = get_condition("NYC", 0); + let condition2 = get_condition("NYC", 0); + assert_eq!(condition1, condition2); + } + + #[test] + fn test_get_condition_varies_with_offset() { + // Different offsets should give different conditions (for most cities) + let conditions: Vec<&str> = (0..6).map(|i| get_condition("NYC", i)).collect(); + // At least some should be different + let unique_count = conditions + 
.iter()
+            .collect::<std::collections::HashSet<_>>()
+            .len();
+        assert!(unique_count > 1);
+    }
+
+    #[test]
+    fn test_get_temp_range() {
+        // Temperature should be in 10-35 range
+        for city in &["NYC", "LA", "Chicago", "Seattle", "Miami"] {
+            let temp = get_temp(city, 0);
+            assert!(temp >= 10 && temp <= 35, "Temp for {} was {}", city, temp);
+        }
+    }
+
+    #[test]
+    fn test_get_temp_deterministic() {
+        assert_eq!(get_temp("NYC", 0), get_temp("NYC", 0));
+    }
+
+    #[test]
+    fn test_access_denied_result() {
+        let result = access_denied("mcp:read");
+        assert_eq!(result.is_error, Some(true));
+        assert_eq!(result.content.len(), 1);
+
+        let text = result.content[0].text.as_ref().unwrap();
+        assert!(text.contains("access_denied"));
+        assert!(text.contains("mcp:read"));
+    }
+
+    #[test]
+    fn test_simulated_weather() {
+        let weather = get_simulated_weather("NYC");
+        assert_eq!(weather["city"], "NYC");
+        assert!(weather["temperature"].is_number());
+        assert!(weather["condition"].is_string());
+        assert!(weather["humidity"].is_number());
+        assert!(weather["windSpeed"].is_number());
+    }
+
+    #[test]
+    fn test_simulated_forecast() {
+        let forecast = get_simulated_forecast("NYC");
+        assert_eq!(forecast["city"], "NYC");
+        let days = forecast["forecast"].as_array().unwrap();
+        assert_eq!(days.len(), 5);
+
+        for day in days {
+            assert!(day["day"].is_string());
+            assert!(day["high"].is_number());
+            assert!(day["low"].is_number());
+            assert!(day["condition"].is_string());
+        }
+    }
+
+    #[test]
+    fn test_simulated_alerts() {
+        // Test different regions to exercise all branches
+        let alerts1 = get_simulated_alerts("Region1");
+        let alerts2 = get_simulated_alerts("Region2");
+        let alerts3 = get_simulated_alerts("Region3");
+
+        // All should have region and alerts fields
+        assert!(alerts1.get("region").is_some());
+        assert!(alerts1.get("alerts").is_some());
+        assert!(alerts2.get("region").is_some());
+        assert!(alerts2.get("alerts").is_some());
+        assert!(alerts3.get("region").is_some());
assert!(alerts3.get("alerts").is_some()); + } + + #[test] + fn test_register_weather_tools() { + let mut mcp = McpHandler::new(); + register_weather_tools(&mut mcp, true); + + // Test get-weather tool + let auth_context = AuthContext::default(); + let result = mcp.handle_request( + json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get-weather", + "arguments": {"city": "NYC"} + } + }), + &auth_context, + ); + + assert!(result.result.is_some()); + } + + #[test] + fn test_get_forecast_with_scope() { + let mut mcp = McpHandler::new(); + register_weather_tools(&mut mcp, false); + + // Without proper scope should fail + let auth_context = AuthContext { + user_id: "user1".to_string(), + scopes: "openid".to_string(), + ..Default::default() + }; + + let result = mcp.handle_request( + json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get-forecast", + "arguments": {"city": "NYC"} + } + }), + &auth_context, + ); + + // Result should contain access_denied error + let result_value = result.result.unwrap(); + assert_eq!(result_value["isError"], true); + + // With proper scope should succeed + let auth_context = AuthContext { + user_id: "user1".to_string(), + scopes: "openid mcp:read".to_string(), + ..Default::default() + }; + + let result = mcp.handle_request( + json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get-forecast", + "arguments": {"city": "NYC"} + } + }), + &auth_context, + ); + + let result_value = result.result.unwrap(); + assert!(result_value.get("isError").is_none()); + } + + #[test] + fn test_get_weather_alerts_with_scope() { + let mut mcp = McpHandler::new(); + register_weather_tools(&mut mcp, false); + + // Without proper scope should fail + let auth_context = AuthContext { + user_id: "user1".to_string(), + scopes: "openid mcp:read".to_string(), + ..Default::default() + }; + + let result = mcp.handle_request( + json!({ + "jsonrpc": "2.0", + 
"id": 1, + "method": "tools/call", + "params": { + "name": "get-weather-alerts", + "arguments": {"region": "Northeast"} + } + }), + &auth_context, + ); + + // Result should contain access_denied error + let result_value = result.result.unwrap(); + assert_eq!(result_value["isError"], true); + + // With proper scope should succeed + let auth_context = AuthContext { + user_id: "user1".to_string(), + scopes: "openid mcp:admin".to_string(), + ..Default::default() + }; + + let result = mcp.handle_request( + json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get-weather-alerts", + "arguments": {"region": "Northeast"} + } + }), + &auth_context, + ); + + let result_value = result.result.unwrap(); + assert!(result_value.get("isError").is_none()); + } + + #[test] + fn test_tools_list() { + let mut mcp = McpHandler::new(); + register_weather_tools(&mut mcp, true); + + let auth_context = AuthContext::default(); + let result = mcp.handle_request( + json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }), + &auth_context, + ); + + let result_value = result.result.unwrap(); + let tools = result_value["tools"].as_array().unwrap(); + assert_eq!(tools.len(), 3); + + let tool_names: Vec<&str> = tools + .iter() + .filter_map(|t| t["name"].as_str()) + .collect(); + assert!(tool_names.contains(&"get-weather")); + assert!(tool_names.contains(&"get-forecast")); + assert!(tool_names.contains(&"get-weather-alerts")); + } +} diff --git a/install-native.sh b/install-native.sh new file mode 100755 index 00000000..ec821252 --- /dev/null +++ b/install-native.sh @@ -0,0 +1,195 @@ +#!/bin/bash +# +# install-native.sh - Download and install native libraries for gopher-mcp-rust +# +# Usage: +# ./install-native.sh [VERSION] [INSTALL_DIR] +# +# Arguments: +# VERSION - Version to install (default: latest) +# INSTALL_DIR - Installation directory (default: ./native) +# +# Examples: +# ./install-native.sh # Install latest to ./native +# 
./install-native.sh v0.1.2 # Install specific version +# ./install-native.sh latest /usr/local # Install to /usr/local +# + +set -e + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +CYAN='\033[0;36m' +NC='\033[0m' + +VERSION="${1:-latest}" +INSTALL_DIR="${2:-./native}" + +echo -e "${CYAN}========================================${NC}" +echo -e "${CYAN} gopher-mcp-rust Native Library Installer${NC}" +echo -e "${CYAN}========================================${NC}" +echo "" + +# Check for gh CLI +if ! command -v gh &> /dev/null; then + echo -e "${RED}Error: GitHub CLI (gh) is not installed${NC}" + echo "Install it with: brew install gh" + echo "Then authenticate: gh auth login" + exit 1 +fi + +# Detect platform +OS=$(uname -s | tr '[:upper:]' '[:lower:]') +ARCH=$(uname -m) + +case "$OS" in + darwin) OS_NAME="macos" ;; + linux) OS_NAME="linux" ;; + mingw*|msys*|cygwin*) OS_NAME="windows" ;; + *) echo -e "${RED}Error: Unsupported OS: $OS${NC}"; exit 1 ;; +esac + +case "$ARCH" in + x86_64|amd64) ARCH_NAME="x64" ;; + arm64|aarch64) ARCH_NAME="arm64" ;; + *) echo -e "${RED}Error: Unsupported architecture: $ARCH${NC}"; exit 1 ;; +esac + +PLATFORM="${OS_NAME}-${ARCH_NAME}" +echo -e "Detected platform: ${GREEN}${PLATFORM}${NC}" + +# Determine file extension +if [ "$OS_NAME" = "windows" ]; then + ARCHIVE_EXT="zip" +else + ARCHIVE_EXT="tar.gz" +fi + +ARCHIVE_NAME="libgopher-orch-${PLATFORM}.${ARCHIVE_EXT}" + +# Get version if "latest" +if [ "$VERSION" = "latest" ]; then + echo -e "${YELLOW}Fetching latest version...${NC}" + VERSION=$(gh release view -R GopherSecurity/gopher-mcp-rust --json tagName -q '.tagName' 2>/dev/null) || { + echo -e "${RED}Error: Could not fetch latest release${NC}" + echo "Make sure the repository has releases and you have access." 
+ exit 1 + } +fi + +echo -e "Version: ${GREEN}${VERSION}${NC}" +echo -e "Archive: ${GREEN}${ARCHIVE_NAME}${NC}" +echo -e "Install directory: ${GREEN}${INSTALL_DIR}${NC}" +echo "" + +# Create temp directory +TEMP_DIR=$(mktemp -d) +trap "rm -rf $TEMP_DIR" EXIT + +cd "$TEMP_DIR" + +# Download +echo -e "${YELLOW}Downloading native library...${NC}" +gh release download "$VERSION" \ + -R GopherSecurity/gopher-mcp-rust \ + -p "$ARCHIVE_NAME" || { + echo -e "${RED}Error: Could not download $ARCHIVE_NAME${NC}" + echo "" + echo "Available assets for $VERSION:" + gh release view "$VERSION" -R GopherSecurity/gopher-mcp-rust --json assets -q '.assets[].name' + exit 1 +} + +echo -e "${GREEN}✓ Downloaded${NC}" + +# Extract +echo -e "${YELLOW}Extracting...${NC}" + +if [ "$ARCHIVE_EXT" = "zip" ]; then + unzip -o "$ARCHIVE_NAME" +else + tar -xzf "$ARCHIVE_NAME" +fi + +echo -e "${GREEN}✓ Extracted${NC}" + +# Get absolute path for install directory +ORIGINAL_DIR=$(pwd) +cd - > /dev/null +INSTALL_DIR=$(cd "$(dirname "$INSTALL_DIR")" 2>/dev/null && pwd)/$(basename "$INSTALL_DIR") || INSTALL_DIR="$PWD/$INSTALL_DIR" +cd "$TEMP_DIR" + +# Install +echo -e "${YELLOW}Installing to ${INSTALL_DIR}...${NC}" + +# Check if we need sudo +NEED_SUDO="" +if [ ! -w "$(dirname "$INSTALL_DIR")" ] 2>/dev/null && [ ! -d "$INSTALL_DIR" ]; then + if [ ! 
-w "$INSTALL_DIR" ] 2>/dev/null; then + NEED_SUDO="sudo" + echo -e "${YELLOW} (requires sudo)${NC}" + fi +fi + +# Create directories +$NEED_SUDO mkdir -p "${INSTALL_DIR}/lib" +$NEED_SUDO mkdir -p "${INSTALL_DIR}/include" + +# Copy libraries +if [ -d "lib" ]; then + $NEED_SUDO cp -P lib/* "${INSTALL_DIR}/lib/" 2>/dev/null || true +fi + +# Copy headers +if [ -d "include" ]; then + $NEED_SUDO cp -r include/* "${INSTALL_DIR}/include/" 2>/dev/null || true +fi + +# Handle flat structure (files directly in archive) +$NEED_SUDO cp -P *.dylib "${INSTALL_DIR}/lib/" 2>/dev/null || true +$NEED_SUDO cp -P *.so* "${INSTALL_DIR}/lib/" 2>/dev/null || true +$NEED_SUDO cp -P *.dll "${INSTALL_DIR}/lib/" 2>/dev/null || true +$NEED_SUDO cp -P *.h "${INSTALL_DIR}/include/" 2>/dev/null || true + +echo -e "${GREEN}✓ Installed${NC}" +echo "" + +# Show installed files +echo -e "${CYAN}Installed files:${NC}" +echo " Libraries:" +ls -la "${INSTALL_DIR}/lib/"*gopher* 2>/dev/null | sed 's/^/ /' || echo " (none found)" +echo " Headers:" +ls -la "${INSTALL_DIR}/include/"*gopher* 2>/dev/null | sed 's/^/ /' || \ +ls -la "${INSTALL_DIR}/include/"orch* 2>/dev/null | sed 's/^/ /' || echo " (none found)" +echo "" + +# Print environment setup +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN} Installation Complete!${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo -e "${YELLOW}For Rust projects:${NC}" +echo "" +echo " # Add to Cargo.toml" +echo " [dependencies]" +echo " gopher-orch = \"${VERSION#v}\"" +echo "" +echo -e "${YELLOW}Set environment variables:${NC}" +echo "" + +if [ "$OS_NAME" = "macos" ]; then + echo " export DYLD_LIBRARY_PATH=\"${INSTALL_DIR}/lib:\$DYLD_LIBRARY_PATH\"" +elif [ "$OS_NAME" = "linux" ]; then + echo " export LD_LIBRARY_PATH=\"${INSTALL_DIR}/lib:\$LD_LIBRARY_PATH\"" + echo "" + echo -e "${YELLOW}Or add to system library path:${NC}" + echo " echo '${INSTALL_DIR}/lib' | sudo tee 
/etc/ld.so.conf.d/gopher-orch.conf"
+  echo "  sudo ldconfig"
+fi
+
+echo ""
+echo -e "${CYAN}To verify installation:${NC}"
+echo "  cargo build"
+echo ""
diff --git a/src/error.rs b/src/error.rs
index cd558b40..721b40f0 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -29,6 +29,10 @@ pub enum Error {
 
     /// Agent has been disposed.
     Disposed,
+
+    /// Authentication error (gopher-auth).
+    #[cfg(feature = "auth")]
+    Auth(String),
 }
 
 impl fmt::Display for Error {
@@ -41,6 +45,8 @@ impl fmt::Display for Error {
             Error::Timeout(msg) => write!(f, "Timeout error: {}", msg),
             Error::Config(msg) => write!(f, "Configuration error: {}", msg),
             Error::Disposed => write!(f, "Agent has been disposed"),
+            #[cfg(feature = "auth")]
+            Error::Auth(msg) => write!(f, "Auth error: {}", msg),
         }
     }
 }
@@ -72,4 +78,10 @@ impl Error {
     pub fn config<S: Into<String>>(msg: S) -> Self {
         Error::Config(msg.into())
     }
+
+    /// Create a new auth error.
+    #[cfg(feature = "auth")]
+    pub fn auth<S: Into<String>>(msg: S) -> Self {
+        Error::Auth(msg.into())
+    }
 }
diff --git a/src/ffi/auth.rs b/src/ffi/auth.rs
new file mode 100644
index 00000000..ca165670
--- /dev/null
+++ b/src/ffi/auth.rs
@@ -0,0 +1,637 @@
+//! FFI Bindings to gopher-auth
+//!
+//! Provides safe Rust bindings to the gopher-auth native library for JWT validation.
+//!
+//! # Example
+//!
+//! ```ignore
+//! use gopher_orch::ffi::auth::GopherAuthClient;
+//!
+//! let client = GopherAuthClient::new(
+//!     "https://auth.example.com/.well-known/jwks.json",
+//!     "https://auth.example.com"
+//! )?;
+//!
+//! let result = client.validate_token("eyJ...", 60);
+//! if result.valid {
+//!     println!("Token is valid!");
+//! }
+//! ```
+
+use std::ffi::{c_char, c_int, c_uint, c_void, CStr, CString};
+use std::ptr;
+use std::sync::Arc;
+
+use libloading::{Library, Symbol};
+
+use crate::error::Error;
+
+/// Result of token validation.
+#[derive(Debug, Clone)]
+pub struct ValidationResult {
+    /// Whether the token is valid.
+    pub valid: bool,
+    /// Error code (0 for success).
+    pub error_code: i32,
+    /// Error message if validation failed.
+    pub error_message: Option<String>,
+}
+
+impl ValidationResult {
+    /// Create a successful validation result.
+    pub fn success() -> Self {
+        Self {
+            valid: true,
+            error_code: 0,
+            error_message: None,
+        }
+    }
+
+    /// Create a failed validation result.
+    pub fn failure(code: i32, message: impl Into<String>) -> Self {
+        Self {
+            valid: false,
+            error_code: code,
+            error_message: Some(message.into()),
+        }
+    }
+}
+
+/// Extracted token payload.
+#[derive(Debug, Clone)]
+pub struct TokenPayload {
+    /// Token subject (user ID).
+    pub subject: String,
+    /// Space-separated scopes.
+    pub scopes: String,
+    /// Token audience.
+    pub audience: String,
+    /// Expiration timestamp (unix seconds).
+    pub expiration: u64,
+}
+
+// FFI function type definitions
+type GopherAuthInitFn = unsafe extern "C" fn() -> c_int;
+type GopherAuthClientCreateFn = unsafe extern "C" fn(
+    out: *mut *mut c_void,
+    jwks_uri: *const c_char,
+    issuer: *const c_char,
+) -> c_int;
+type GopherAuthClientDestroyFn = unsafe extern "C" fn(client: *mut c_void);
+type GopherAuthSetOptionFn =
+    unsafe extern "C" fn(client: *mut c_void, key: *const c_char, value: *const c_char) -> c_int;
+type GopherAuthValidateTokenFn = unsafe extern "C" fn(
+    client: *mut c_void,
+    token: *const c_char,
+    clock_skew: c_uint,
+    out_valid: *mut c_int,
+    out_error: *mut *mut c_char,
+) -> c_int;
+type GopherAuthExtractPayloadFn =
+    unsafe extern "C" fn(client: *mut c_void, token: *const c_char, out: *mut *mut c_void) -> c_int;
+type GopherAuthPayloadGetSubjectFn = unsafe extern "C" fn(payload: *mut c_void) -> *const c_char;
+type GopherAuthPayloadGetScopesFn = unsafe extern "C" fn(payload: *mut c_void) -> *const c_char;
+type GopherAuthPayloadGetAudienceFn = unsafe extern "C" fn(payload: *mut c_void) -> *const c_char;
+type GopherAuthPayloadGetExpirationFn = unsafe extern "C" fn(payload: *mut c_void) -> u64;
+type GopherAuthPayloadDestroyFn = unsafe extern "C" fn(payload: *mut
c_void);
+type GopherAuthFreeStringFn = unsafe extern "C" fn(s: *mut c_char);
+
+/// Client for gopher-auth native library.
+///
+/// Provides JWT validation and payload extraction using the gopher-auth C library.
+///
+/// # Thread Safety
+///
+/// The client is `Send` and `Sync`, allowing it to be shared across threads.
+pub struct GopherAuthClient {
+    /// Opaque handle to the native client.
+    handle: *mut c_void,
+    /// Reference to the loaded library.
+    library: Arc<Library>,
+}
+
+// Safety: The native library handles are thread-safe when used correctly
+unsafe impl Send for GopherAuthClient {}
+unsafe impl Sync for GopherAuthClient {}
+
+impl GopherAuthClient {
+    /// Create a new gopher-auth client.
+    ///
+    /// # Arguments
+    ///
+    /// * `jwks_uri` - URI to fetch JWKS from
+    /// * `issuer` - Expected token issuer
+    ///
+    /// # Returns
+    ///
+    /// A new client instance or an error if initialization failed.
+    ///
+    /// # Example
+    ///
+    /// ```ignore
+    /// let client = GopherAuthClient::new(
+    ///     "https://auth.example.com/.well-known/jwks.json",
+    ///     "https://auth.example.com"
+    /// )?;
+    /// ```
+    pub fn new(jwks_uri: &str, issuer: &str) -> Result<Self, Error> {
+        // Load the native library
+        let library = Self::load_library()?;
+        let library = Arc::new(library);
+
+        // Initialize the library
+        unsafe {
+            let init: Symbol<GopherAuthInitFn> = library
+                .get(b"gopher_auth_init\0")
+                .map_err(|e| Error::auth(format!("Failed to load gopher_auth_init: {}", e)))?;
+
+            let result = init();
+            if result != 0 {
+                return Err(Error::auth(format!(
+                    "gopher_auth_init failed with code {}",
+                    result
+                )));
+            }
+        }
+
+        // Create the client
+        let jwks_uri_c =
+            CString::new(jwks_uri).map_err(|e| Error::auth(format!("Invalid jwks_uri: {}", e)))?;
+        let issuer_c =
+            CString::new(issuer).map_err(|e| Error::auth(format!("Invalid issuer: {}", e)))?;
+
+        let handle = unsafe {
+            let create: Symbol<GopherAuthClientCreateFn> =
+                library.get(b"gopher_auth_client_create\0").map_err(|e| {
+                    Error::auth(format!("Failed to load gopher_auth_client_create: {}", e))
+                })?;
+
+            let mut handle: *mut c_void = ptr::null_mut();
+            let result = create(&mut handle, jwks_uri_c.as_ptr(), issuer_c.as_ptr());
+
+            if result != 0 || handle.is_null() {
+                return Err(Error::auth(format!(
+                    "gopher_auth_client_create failed with code {}",
+                    result
+                )));
+            }
+
+            handle
+        };
+
+        Ok(Self { handle, library })
+    }
+
+    /// Load the native library from known locations.
+    fn load_library() -> Result<Library, Error> {
+        // Try multiple library names - the auth functions are in libgopher-orch
+        let lib_names = if cfg!(target_os = "macos") {
+            vec![
+                "libgopher-orch.dylib",
+                "libgopher-orch.0.dylib",
+                "libgopher_orch.dylib",
+            ]
+        } else if cfg!(target_os = "windows") {
+            vec!["gopher-orch.dll", "libgopher-orch.dll", "gopher_orch.dll"]
+        } else {
+            vec![
+                "libgopher-orch.so",
+                "libgopher-orch.so.0",
+                "libgopher_orch.so",
+            ]
+        };
+
+        // Build search paths including environment-specified locations
+        let mut search_paths = vec![
+            String::new(), // Current directory / system paths
+            String::from("./"),
+            String::from("./native/lib/"),
+            String::from("../native/lib/"),
+        ];
+
+        // Add paths from DYLD_LIBRARY_PATH / LD_LIBRARY_PATH
+        if let Ok(lib_path) = std::env::var("DYLD_LIBRARY_PATH") {
+            for path in lib_path.split(':') {
+                if !path.is_empty() {
+                    let mut p = path.to_string();
+                    if !p.ends_with('/') {
+                        p.push('/');
+                    }
+                    search_paths.push(p);
+                }
+            }
+        }
+        if let Ok(lib_path) = std::env::var("LD_LIBRARY_PATH") {
+            for path in lib_path.split(':') {
+                if !path.is_empty() {
+                    let mut p = path.to_string();
+                    if !p.ends_with('/') {
+                        p.push('/');
+                    }
+                    search_paths.push(p);
+                }
+            }
+        }
+
+        // Add standard system paths
+        search_paths.push(String::from("/usr/local/lib/"));
+        search_paths.push(String::from("/usr/lib/"));
+
+        for path in &search_paths {
+            for name in &lib_names {
+                let full_path = format!("{}{}", path, name);
+                if let Ok(lib) = unsafe { Library::new(&full_path) } {
+                    return Ok(lib);
+                }
+            }
+        }
+
+        Err(Error::auth(format!(
+            "Failed to load
gopher-auth library. Tried: {:?}", + lib_names + ))) + } + + /// Check if the gopher-auth library is available. + /// + /// # Returns + /// + /// `true` if the library can be loaded, `false` otherwise. + pub fn is_available() -> bool { + Self::load_library().is_ok() + } + + /// Validate a JWT token. + /// + /// # Arguments + /// + /// * `token` - The JWT token string + /// * `clock_skew` - Allowed clock skew in seconds + /// + /// # Returns + /// + /// Validation result indicating success or failure. + /// + /// # Example + /// + /// ```ignore + /// let result = client.validate_token("eyJ...", 60); + /// if result.valid { + /// println!("Token is valid!"); + /// } else { + /// println!("Validation failed: {:?}", result.error_message); + /// } + /// ``` + pub fn validate_token(&self, token: &str, clock_skew: u32) -> ValidationResult { + let token_c = match CString::new(token) { + Ok(s) => s, + Err(e) => return ValidationResult::failure(-1, format!("Invalid token string: {}", e)), + }; + + unsafe { + let validate: Symbol = + match self.library.get(b"gopher_auth_validate_token\0") { + Ok(f) => f, + Err(e) => { + return ValidationResult::failure( + -1, + format!("Failed to load validate function: {}", e), + ) + } + }; + + let mut valid: c_int = 0; + let mut error: *mut c_char = ptr::null_mut(); + + let result = validate( + self.handle, + token_c.as_ptr(), + clock_skew, + &mut valid, + &mut error, + ); + + if result != 0 { + let error_msg = if !error.is_null() { + let msg = CStr::from_ptr(error).to_string_lossy().into_owned(); + self.free_string(error); + msg + } else { + format!("Validation failed with code {}", result) + }; + return ValidationResult::failure(result, error_msg); + } + + if valid != 0 { + ValidationResult::success() + } else { + let error_msg = if !error.is_null() { + let msg = CStr::from_ptr(error).to_string_lossy().into_owned(); + self.free_string(error); + msg + } else { + "Token validation failed".to_string() + }; + ValidationResult::failure(-2, 
error_msg) + } + } + } + + /// Extract payload from a JWT token. + /// + /// # Arguments + /// + /// * `token` - The JWT token string + /// + /// # Returns + /// + /// Extracted token payload or an error. + /// + /// # Example + /// + /// ```ignore + /// let payload = client.extract_payload("eyJ...")?; + /// println!("User: {}", payload.subject); + /// println!("Scopes: {}", payload.scopes); + /// ``` + pub fn extract_payload(&self, token: &str) -> Result { + let token_c = + CString::new(token).map_err(|e| Error::auth(format!("Invalid token string: {}", e)))?; + + unsafe { + let extract: Symbol = self + .library + .get(b"gopher_auth_extract_payload\0") + .map_err(|e| Error::auth(format!("Failed to load extract function: {}", e)))?; + + let mut payload: *mut c_void = ptr::null_mut(); + let result = extract(self.handle, token_c.as_ptr(), &mut payload); + + if result != 0 || payload.is_null() { + return Err(Error::auth(format!( + "Failed to extract payload, code {}", + result + ))); + } + + // Extract fields from payload + let subject = self.get_payload_string(payload, b"gopher_auth_payload_get_subject\0")?; + let scopes = self.get_payload_string(payload, b"gopher_auth_payload_get_scopes\0")?; + let audience = + self.get_payload_string(payload, b"gopher_auth_payload_get_audience\0")?; + let expiration = self.get_payload_expiration(payload)?; + + // Destroy the payload + self.destroy_payload(payload); + + Ok(TokenPayload { + subject, + scopes, + audience, + expiration, + }) + } + } + + /// Get a string field from a payload. + unsafe fn get_payload_string( + &self, + payload: *mut c_void, + fn_name: &[u8], + ) -> Result { + let get_fn: Symbol = self + .library + .get(fn_name) + .map_err(|e| Error::auth(format!("Failed to load getter function: {}", e)))?; + + let ptr = get_fn(payload); + if ptr.is_null() { + return Ok(String::new()); + } + + Ok(CStr::from_ptr(ptr).to_string_lossy().into_owned()) + } + + /// Get the expiration field from a payload. 
+ unsafe fn get_payload_expiration(&self, payload: *mut c_void) -> Result { + let get_fn: Symbol = self + .library + .get(b"gopher_auth_payload_get_expiration\0") + .map_err(|e| Error::auth(format!("Failed to load expiration getter: {}", e)))?; + + Ok(get_fn(payload)) + } + + /// Destroy a payload handle. + unsafe fn destroy_payload(&self, payload: *mut c_void) { + if let Ok(destroy) = self + .library + .get::(b"gopher_auth_payload_destroy\0") + { + destroy(payload); + } + } + + /// Free a string allocated by the native library. + unsafe fn free_string(&self, s: *mut c_char) { + if let Ok(free) = self + .library + .get::(b"gopher_auth_free_string\0") + { + free(s); + } + } + + /// Set a client option. + /// + /// # Arguments + /// + /// * `key` - Option key + /// * `value` - Option value + /// + /// # Returns + /// + /// Ok if successful, Err otherwise. + /// + /// # Common Options + /// + /// - `cache_duration` - JWKS cache duration in seconds + /// - `auto_refresh` - Enable automatic JWKS refresh ("true"/"false") + /// - `request_timeout` - HTTP request timeout in seconds + pub fn set_option(&self, key: &str, value: &str) -> Result<(), Error> { + let key_c = + CString::new(key).map_err(|e| Error::auth(format!("Invalid option key: {}", e)))?; + let value_c = + CString::new(value).map_err(|e| Error::auth(format!("Invalid option value: {}", e)))?; + + unsafe { + let set_option: Symbol = self + .library + .get(b"gopher_auth_client_set_option\0") + .map_err(|e| Error::auth(format!("Failed to load set_option: {}", e)))?; + + let result = set_option(self.handle, key_c.as_ptr(), value_c.as_ptr()); + if result != 0 { + return Err(Error::auth(format!( + "Failed to set option '{}', code {}", + key, result + ))); + } + } + + Ok(()) + } + + /// Explicitly destroy the client handle. + /// + /// This is called automatically by Drop, but can be called manually + /// to release resources early. 
+ pub fn destroy(&mut self) { + if !self.handle.is_null() { + unsafe { + if let Ok(destroy) = self + .library + .get::(b"gopher_auth_client_destroy\0") + { + destroy(self.handle); + } + } + self.handle = ptr::null_mut(); + } + } + + /// Create a dummy client for testing purposes. + /// + /// This client has a null handle and no library, and should only be + /// used in tests that need to check if a client exists without actually + /// performing any operations. + #[cfg(test)] + pub fn dummy() -> Self { + Self { + handle: ptr::null_mut(), + library: Arc::new(unsafe { + // Create a dummy library reference that won't be used + // This is safe because we never call any functions on it + Library::new("/dev/null").unwrap_or_else(|_| { + // If /dev/null doesn't work, try a path that definitely exists + #[cfg(target_os = "macos")] + { + Library::new("/usr/lib/libSystem.B.dylib") + .expect("Failed to load system library for test dummy") + } + #[cfg(target_os = "linux")] + { + Library::new("/lib/x86_64-linux-gnu/libc.so.6") + .or_else(|_| Library::new("/lib/libc.so.6")) + .expect("Failed to load system library for test dummy") + } + #[cfg(target_os = "windows")] + { + Library::new("kernel32.dll") + .expect("Failed to load system library for test dummy") + } + }) + }), + } + } +} + +impl Drop for GopherAuthClient { + fn drop(&mut self) { + self.destroy(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validation_result_success() { + let result = ValidationResult::success(); + assert!(result.valid); + assert_eq!(result.error_code, 0); + assert!(result.error_message.is_none()); + } + + #[test] + fn test_validation_result_failure() { + let result = ValidationResult::failure(-1, "Token expired"); + assert!(!result.valid); + assert_eq!(result.error_code, -1); + assert_eq!(result.error_message, Some("Token expired".to_string())); + } + + #[test] + fn test_token_payload_fields() { + let payload = TokenPayload { + subject: "user123".to_string(), + scopes: 
"openid profile".to_string(), + audience: "my-app".to_string(), + expiration: 1234567890, + }; + + assert_eq!(payload.subject, "user123"); + assert_eq!(payload.scopes, "openid profile"); + assert_eq!(payload.audience, "my-app"); + assert_eq!(payload.expiration, 1234567890); + } + + #[test] + fn test_token_payload_clone() { + let payload = TokenPayload { + subject: "user".to_string(), + scopes: "read write".to_string(), + audience: "api".to_string(), + expiration: 9999999999, + }; + + let cloned = payload.clone(); + assert_eq!(payload.subject, cloned.subject); + assert_eq!(payload.scopes, cloned.scopes); + } + + #[test] + fn test_is_available() { + // Just check it doesn't panic - result depends on library presence + let _ = GopherAuthClient::is_available(); + } + + // Note: The following tests require the native library to be installed + // They are marked as ignored by default and can be run with: + // cargo test --features auth --ignored + + #[test] + #[ignore] + fn test_client_creation() { + let result = GopherAuthClient::new( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ); + + // This may fail if the library is not installed + // That's expected in CI environments without the native library + if let Err(e) = result { + println!( + "Client creation failed (expected without native lib): {}", + e + ); + } + } + + #[test] + #[ignore] + fn test_client_validate_token() { + let client = match GopherAuthClient::new( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) { + Ok(c) => c, + Err(_) => return, // Skip if library not available + }; + + // This will fail because the token is invalid + let result = client.validate_token("invalid.token.here", 0); + assert!(!result.valid); + } +} diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs new file mode 100644 index 00000000..b08bee77 --- /dev/null +++ b/src/ffi/mod.rs @@ -0,0 +1,13 @@ +//! FFI bindings to native Gopher libraries. +//! +//! 
This module provides safe Rust bindings to: +//! - `gopher-orch` - AI agent orchestration library +//! - `gopher-auth` - OAuth/JWT authentication library (optional, requires `auth` feature) + +pub mod orch; + +#[cfg(feature = "auth")] +pub mod auth; + +// Re-export orch types at module level for backward compatibility +pub use orch::*; diff --git a/src/ffi.rs b/src/ffi/orch.rs similarity index 100% rename from src/ffi.rs rename to src/ffi/orch.rs diff --git a/src/lib.rs b/src/lib.rs index 4232b604..35bedc4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,6 +36,10 @@ pub use config::{Config, ConfigBuilder}; pub use error::{Error, Result}; pub use result::{AgentResult, AgentResultStatus}; +// Re-export auth types when the auth feature is enabled +#[cfg(feature = "auth")] +pub use ffi::auth::{GopherAuthClient, TokenPayload, ValidationResult}; + use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Once; diff --git a/third_party/gopher-orch b/third_party/gopher-orch index 6b45ffbb..c8e7c406 160000 --- a/third_party/gopher-orch +++ b/third_party/gopher-orch @@ -1 +1 @@ -Subproject commit 6b45ffbbee74d5ae034008fc2cb2a927f3131992 +Subproject commit c8e7c40606db330142632ecf90aaa8777bc42a3a