diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py index aaa0d6f347f8..f05e0cb7b99e 100644 --- a/backend/python/mlx/backend.py +++ b/backend/python/mlx/backend.py @@ -60,6 +60,31 @@ def Health(self, request, context): """ return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + def _initialize_rdma(self): + """ + Initialize RDMA support using mlx.distributed if enabled. + """ + # Check if RDMA is enabled via environment variable + mlx_rdma_enabled = os.environ.get("MLX_GRPC_SERVERS", "") != "" + + if mlx_rdma_enabled: + try: + print("Initializing RDMA with mlx.distributed...", file=sys.stderr) + world = mx.distributed.init(backend="jaccl") + print(f"RDMA initialized: rank={world.rank()}, world_size={world.size()}", file=sys.stderr) + self.rdma_enabled = True + self.rdma_rank = world.rank() + self.rdma_world_size = world.size() + except Exception as e: + print(f"Failed to initialize RDMA: {e}", file=sys.stderr) + self.rdma_enabled = False + self.rdma_rank = 0 + self.rdma_world_size = 1 + else: + self.rdma_enabled = False + self.rdma_rank = 0 + self.rdma_world_size = 1 + async def LoadModel(self, request, context): """ Loads a language model using MLX. 
@@ -130,6 +155,9 @@ async def LoadModel(self, request, context): can_trim_fn=can_trim_prompt_cache, trim_fn=trim_prompt_cache, ) + + # Initialize RDMA support + self._initialize_rdma() except Exception as err: print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr) diff --git a/core/cli/run.go b/core/cli/run.go index a67b35fadc41..92af696eb26e 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -141,9 +141,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithAgentJobRetentionDays(r.AgentJobRetentionDays), config.WithTunnelCallback(func(tunnels []string) { tunnelEnvVar := strings.Join(tunnels, ",") - // TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar) + os.Setenv("MLX_GRPC_SERVERS", tunnelEnvVar) xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar) + xlog.Debug("setting MLX_GRPC_SERVERS", "value", tunnelEnvVar) }), } diff --git a/core/cli/worker/worker.go b/core/cli/worker/worker.go index 0a636c3bfacb..1b374890089c 100644 --- a/core/cli/worker/worker.go +++ b/core/cli/worker/worker.go @@ -5,9 +5,11 @@ type WorkerFlags struct { BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"` ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"` + ExtraMLXRDMAArgs string `name:"mlx-rdma-args" env:"LOCALAI_EXTRA_MLX_RDMA_ARGS,EXTRA_MLX_RDMA_ARGS" help:"Extra arguments to pass to mlx-rdma backend"` } type Worker struct { P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode 
(requires a token)"` LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"` + MLXRDMA MLXRDMA `cmd:"" name:"mlx-rdma" help:"Starts a mlx-rdma worker in standalone mode"` } diff --git a/core/cli/worker/worker_mlx.go b/core/cli/worker/worker_mlx.go new file mode 100644 index 000000000000..3d7888fe600c --- /dev/null +++ b/core/cli/worker/worker_mlx.go @@ -0,0 +1,68 @@ +package worker + +import ( + "fmt" + "os" + "strings" + + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/pkg/system" +) + +type MLXRDMA struct { + WorkerFlags `embed:""` +} + +func (r *MLXRDMA) Run(ctx *cliContext.Context) error { + if len(os.Args) < 4 { + return fmt.Errorf("usage: local-ai worker mlx-rdma -- <mlx-rdma-args>") + } + + systemState, err := system.GetSystemState( + system.WithBackendPath(r.BackendsPath), + system.WithBackendSystemPath(r.BackendsSystemPath), + ) + if err != nil { + return err + } + + // Get the python binary + pythonPath, err := system.GetPythonBinary(systemState) + if err != nil { + return err + } + + // Get the backend path + backendPath, err := getMLXBackendPath(systemState, r.BackendGalleries) + if err != nil { + return err + } + + // Prepare the arguments + args := strings.Fields(r.ExtraMLXRDMAArgs) + args = append([]string{backendPath}, args...) 
+ + // Set environment variables for RDMA + if os.Getenv("MLX_GRPC_SERVERS") == "" { + os.Setenv("MLX_GRPC_SERVERS", os.Getenv("LLAMACPP_GRPC_SERVERS")) + } + + // Execute the backend + return system.ExecPython(pythonPath, args, os.Environ()) +} + +func getMLXBackendPath(systemState *system.SystemState, galleries string) (string, error) { + // TODO: Implement backend discovery for MLX (similar to llama.cpp) + // For now, assume the backend is at a known location + backend := "mlx" + backendPath := systemState.BackendSystemPath + + // Check if backend exists + fullPath := fmt.Sprintf("%s/%s/backend.py", backendPath, backend) + if _, err := os.Stat(fullPath); err == nil { + return fullPath, nil + } + + // Fallback: try to find the backend in the system path + return fmt.Sprintf("%s/backend.py", backendPath), nil +}