33This module handles connections to one or more MCP servers, including setup,
44initialization, and communication.
55"""
6-
76import os
87import shutil
98from contextlib import AsyncExitStack
10- from typing import Dict , List , Any , Optional , Tuple
11- from rich .console import Console
12- from rich .panel import Panel
9+ from typing import Any , Dict , List , Optional , Tuple
10+
11+ import mcp .types
12+ import ollama
13+ import rich .json
1314from mcp import ClientSession , Tool
14- from mcp .client .stdio import stdio_client , StdioServerParameters
15+ from mcp .client .session import LoggingFnT , SamplingFnT
1516from mcp .client .sse import sse_client
17+ from mcp .client .stdio import StdioServerParameters , stdio_client
1618from mcp .client .streamable_http import streamablehttp_client
19+ from mcp .shared .context import LifespanContextT , RequestContext
20+ from rich .console import Console , Group
21+ from rich .markdown import Markdown
22+ from rich .panel import Panel
1723
1824from .discovery import process_server_paths , parse_server_configs , auto_discover_servers
1925
@@ -25,14 +31,16 @@ class ServerConnector:
2531 tools provided by those servers.
2632 """
2733
28- def __init__ (self , exit_stack : AsyncExitStack , console : Optional [Console ] = None ):
34+ def __init__ (self , exit_stack : AsyncExitStack , default_model : str , ollama_client : ollama . AsyncClient , console : Optional [Console ] = None ):
2935 """Initialize the ServerConnector.
3036
3137 Args:
3238 exit_stack: AsyncExitStack to manage server connections
3339 console: Rich console for output (optional)
3440 """
3541 self .exit_stack = exit_stack
42+ self .ollama = ollama_client
43+ self .default_model = default_model
3644 self .console = console or Console ()
3745 self .sessions = {} # Dict to store multiple sessions
3846 self .available_tools = [] # List to store all available tools
@@ -97,6 +105,45 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis
97105
98106 return self .sessions , self .available_tools , self .enabled_tools
99107
108+ def create_log_callback (self , server_name : str ) -> LoggingFnT :
109+ async def log_callback (params : mcp .types .LoggingMessageNotificationParams ) -> None :
110+ self .console .log (f"[green]\[{ params .level .upper ()} ] - { server_name } :[/green]" , params .data )
111+
112+ return log_callback
113+
114+ def create_sampling_callback (self ) -> SamplingFnT :
115+ async def _sampling_handler (
116+ context : RequestContext [ClientSession , LifespanContextT ],
117+ params : mcp .types .CreateMessageRequestParams ,
118+ ) -> mcp .types .CreateMessageResult | mcp .types .ErrorData :
119+ messages = [{"role" : "system" , "content" : params .systemPrompt }] + [{'role' : msg .role , 'content' : msg .content .text } for msg in params .messages ]
120+ self .console .print (Panel (Group (* (Panel (Markdown (msg ["content" ]), title = msg ["role" ], ) for msg in messages )), title = "🧠 Handling sampling request..." , border_style = "cyan" , expand = False ))
121+ try :
122+ response = await self .ollama .chat (
123+ self .default_model ,
124+ messages ,
125+ options = {
126+ "temperature" : params .temperature ,
127+ "num_predict" : params .maxTokens ,
128+ "stop" : params .stopSequences ,
129+ }
130+ )
131+ except Exception as e :
132+ self .console .print_exception ()
133+ return mcp .types .ErrorData (
134+ code = mcp .types .INTERNAL_ERROR ,
135+ message = str (e ),
136+ )
137+ else :
138+ return mcp .CreateMessageResult (
139+ role = "assistant" ,
140+ model = "fastmcp-client" ,
141+ content = mcp .types .TextContent (type = "text" , text = response .message .content ),
142+ )
143+
144+ return _sampling_handler
145+
146+
100147 async def _connect_to_server (self , server : Dict [str , Any ]) -> bool :
101148 """Connect to a single MCP server
102149
@@ -126,7 +173,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
126173 # Connect using SSE transport
127174 sse_transport = await self .exit_stack .enter_async_context (sse_client (url , headers = headers ))
128175 read_stream , write_stream = sse_transport
129- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
176+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
130177
131178 elif server_type == "streamable_http" :
132179 # Connect to Streamable HTTP server
@@ -142,7 +189,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
142189 streamablehttp_client (url , headers = headers )
143190 )
144191 read_stream , write_stream , session_info = transport
145- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
192+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
146193
147194 # Store session ID if provided
148195 if hasattr (session_info , 'session_id' ) and session_info .session_id :
@@ -156,7 +203,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
156203
157204 stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
158205 read_stream , write_stream = stdio_transport
159- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
206+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
160207
161208 else :
162209 # Connect to config-based server using STDIO
@@ -166,7 +213,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
166213
167214 stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
168215 read_stream , write_stream = stdio_transport
169- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
216+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
170217
171218 # Initialize the session
172219 await session .initialize ()
0 commit comments