Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions backend/python/mlx/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,31 @@ def Health(self, request, context):
"""
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

def _initialize_rdma(self):
"""
Initialize RDMA support using mlx.distributed if enabled.
"""
# Check if RDMA is enabled via environment variable
mlx_rdma_enabled = os.environ.get("MLX_GRPC_SERVERS", "") != ""

if mlx_rdma_enabled:
try:
print("Initializing RDMA with mlx.distributed...", file=sys.stderr)
mx.distributed.init(backend="jaccl")
print(f"RDMA initialized: rank={mx.distributed.rank()}, world_size={mx.distributed.world_size()}", file=sys.stderr)
self.rdma_enabled = True
self.rdma_rank = mx.distributed.rank()
self.rdma_world_size = mx.distributed.world_size()
except Exception as e:
print(f"Failed to initialize RDMA: {e}", file=sys.stderr)
self.rdma_enabled = False
self.rdma_rank = 0
self.rdma_world_size = 1
else:
self.rdma_enabled = False
self.rdma_rank = 0
self.rdma_world_size = 1

async def LoadModel(self, request, context):
"""
Loads a language model using MLX.
Expand Down Expand Up @@ -130,6 +155,9 @@ async def LoadModel(self, request, context):
can_trim_fn=can_trim_prompt_cache,
trim_fn=trim_prompt_cache,
)

# Initialize RDMA support
self._initialize_rdma()

except Exception as err:
print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
Expand Down
3 changes: 2 additions & 1 deletion core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithAgentJobRetentionDays(r.AgentJobRetentionDays),
config.WithTunnelCallback(func(tunnels []string) {
tunnelEnvVar := strings.Join(tunnels, ",")
// TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable
os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
os.Setenv("MLX_GRPC_SERVERS", tunnelEnvVar)
xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar)
xlog.Debug("setting MLX_GRPC_SERVERS", "value", tunnelEnvVar)
}),
}

Expand Down
2 changes: 2 additions & 0 deletions core/cli/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ type WorkerFlags struct {
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
ExtraMLXRDMAArgs string `name:"mlx-rdma-args" env:"LOCALAI_EXTRA_MLX_RDMA_ARGS,EXTRA_MLX_RDMA_ARGS" help:"Extra arguments to pass to mlx-rdma backend"`
}

// Worker groups the CLI subcommands for running LocalAI worker processes.
type Worker struct {
// P2P starts a llama.cpp RPC worker discovered over the P2P network (requires a token).
P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"`
// LLamaCPP starts a llama.cpp RPC worker as a standalone process.
LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
// MLXRDMA starts an MLX RDMA worker as a standalone process.
MLXRDMA MLXRDMA `cmd:"" name:"mlx-rdma" help:"Starts a mlx-rdma worker in standalone mode"`
}
68 changes: 68 additions & 0 deletions core/cli/worker/worker_mlx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package worker

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"

	cliContext "github.com/mudler/LocalAI/core/cli/context"
	"github.com/mudler/LocalAI/pkg/system"
)

// MLXRDMA is the CLI subcommand that runs an MLX RDMA worker in
// standalone mode; it reuses the shared worker flags (backend paths,
// galleries, extra args).
type MLXRDMA struct {
WorkerFlags `embed:""`
}

// Run starts the MLX RDMA worker in standalone mode. It resolves the
// Python interpreter and the MLX backend entry point, propagates the
// llama.cpp tunnel list into MLX_GRPC_SERVERS when the latter is unset,
// and then executes the backend with any extra user-supplied arguments.
// It returns an error if the system state, interpreter, or backend
// cannot be resolved.
func (r *MLXRDMA) Run(ctx *cliContext.Context) error {
	if len(os.Args) < 4 {
		return fmt.Errorf("usage: local-ai worker mlx-rdma -- <mlx-rdma-args>")
	}

	systemState, err := system.GetSystemState(
		system.WithBackendPath(r.BackendsPath),
		system.WithBackendSystemPath(r.BackendsSystemPath),
	)
	if err != nil {
		return err
	}

	// Get the python binary
	pythonPath, err := system.GetPythonBinary(systemState)
	if err != nil {
		return err
	}

	// Get the backend path
	backendPath, err := getMLXBackendPath(systemState, r.BackendGalleries)
	if err != nil {
		return err
	}

	// Prepare the arguments: the backend script first, then any extra
	// user-provided arguments. strings.Fields (rather than Split) avoids
	// injecting a bogus empty-string argument when no extra args are set.
	args := append([]string{backendPath}, strings.Fields(r.ExtraMLXRDMAArgs)...)

	// Fall back to the llama.cpp tunnel list when MLX_GRPC_SERVERS is
	// unset, so P2P-established tunnels are reused for MLX RDMA.
	if os.Getenv("MLX_GRPC_SERVERS") == "" {
		os.Setenv("MLX_GRPC_SERVERS", os.Getenv("LLAMACPP_GRPC_SERVERS"))
	}

	// Execute the backend
	return system.ExecPython(pythonPath, args, os.Environ())
}

// getMLXBackendPath resolves the filesystem path to the MLX backend's
// backend.py entry point under the system backend path, and returns an
// error if no entry point exists at any known location.
//
// TODO: Implement gallery-based backend discovery for MLX (similar to
// llama.cpp); the galleries argument is accepted for forward
// compatibility and is currently unused.
func getMLXBackendPath(systemState *system.SystemState, galleries string) (string, error) {
	_ = galleries // reserved for future gallery-based discovery

	base := systemState.BackendSystemPath

	// Preferred location: <system path>/mlx/backend.py
	candidate := filepath.Join(base, "mlx", "backend.py")
	if _, err := os.Stat(candidate); err == nil {
		return candidate, nil
	}

	// Fallback location: <system path>/backend.py. Unlike the original
	// draft, verify it exists instead of returning an unchecked path
	// that would only fail later at exec time.
	fallback := filepath.Join(base, "backend.py")
	if _, err := os.Stat(fallback); err == nil {
		return fallback, nil
	}

	return "", fmt.Errorf("mlx backend not found: checked %q and %q", candidate, fallback)
}