77import os
88import shutil
99from contextlib import AsyncExitStack
10- from typing import Dict , List , Any , Optional , Tuple
11- from rich .console import Console
12- from rich .panel import Panel
10+ from typing import Any , Dict , List , Optional , Tuple
11+
12+ import mcp .types
13+ import ollama
1314from mcp import ClientSession , Tool
14- from mcp .client .stdio import stdio_client , StdioServerParameters
15+ from mcp .client .session import LoggingFnT , SamplingFnT
1516from mcp .client .sse import sse_client
17+ from mcp .client .stdio import StdioServerParameters , stdio_client
1618from mcp .client .streamable_http import streamablehttp_client
19+ from mcp .shared .context import LifespanContextT , RequestContext
20+ from rich .console import Console
21+ from rich .panel import Panel
1722
1823from .discovery import process_server_paths , parse_server_configs , auto_discover_servers
1924
@@ -25,14 +30,16 @@ class ServerConnector:
2530 tools provided by those servers.
2631 """
2732
28- def __init__ (self , exit_stack : AsyncExitStack , console : Optional [Console ] = None ):
33+ def __init__ (self , exit_stack : AsyncExitStack , default_model : str , ollama_client : ollama . AsyncClient , console : Optional [Console ] = None ):
2934 """Initialize the ServerConnector.
3035
3136 Args:
3237 exit_stack: AsyncExitStack to manage server connections
3338 console: Rich console for output (optional)
3439 """
3540 self .exit_stack = exit_stack
41+ self .ollama = ollama_client
42+ self .default_model = default_model
3643 self .console = console or Console ()
3744 self .sessions = {} # Dict to store multiple sessions
3845 self .available_tools = [] # List to store all available tools
@@ -97,6 +104,40 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis
97104
98105 return self .sessions , self .available_tools , self .enabled_tools
99106
107+ def create_log_callback (self , server_name : str ) -> LoggingFnT :
108+ async def log_callback (params : mcp .types .LoggingMessageNotificationParams ) -> None :
109+ self .console .log (f"[green]\[{ params .level .upper ()} ] - { server_name } :[/green]" , params .data )
110+
111+ return log_callback
112+
113+ def create_sampling_callback (self ) -> SamplingFnT :
114+ async def _sampling_handler (
115+ context : RequestContext [ClientSession , LifespanContextT ],
116+ params : mcp .types .CreateMessageRequestParams ,
117+ ) -> mcp .types .CreateMessageResult | mcp .types .ErrorData :
118+ self .console .print ("[cyan]Handling sampling request...[/cyan]" )
119+ try :
120+ response = await self .ollama .chat (self .default_model , [{'role' : msg .role , 'content' : msg .content .text } for msg in params .messages ], options = {
121+ "temperature" : params .temperature ,
122+ "num_predict" : params .maxTokens ,
123+ "stop" : params .stopSequences ,
124+ })
125+ except Exception as e :
126+ self .console .print_exception ()
127+ return mcp .types .ErrorData (
128+ code = mcp .types .INTERNAL_ERROR ,
129+ message = str (e ),
130+ )
131+ else :
132+ return mcp .CreateMessageResult (
133+ role = "assistant" ,
134+ model = "fastmcp-client" ,
135+ content = mcp .types .TextContent (type = "text" , text = response .message .content ),
136+ )
137+
138+ return _sampling_handler
139+
140+
100141 async def _connect_to_server (self , server : Dict [str , Any ]) -> bool :
101142 """Connect to a single MCP server
102143
@@ -126,7 +167,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
126167 # Connect using SSE transport
127168 sse_transport = await self .exit_stack .enter_async_context (sse_client (url , headers = headers ))
128169 read_stream , write_stream = sse_transport
129- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
170+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
130171
131172 elif server_type == "streamable_http" :
132173 # Connect to Streamable HTTP server
@@ -142,7 +183,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
142183 streamablehttp_client (url , headers = headers )
143184 )
144185 read_stream , write_stream , session_info = transport
145- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
186+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
146187
147188 # Store session ID if provided
148189 if hasattr (session_info , 'session_id' ) and session_info .session_id :
@@ -156,7 +197,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
156197
157198 stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
158199 read_stream , write_stream = stdio_transport
159- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
200+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
160201
161202 else :
162203 # Connect to config-based server using STDIO
@@ -166,7 +207,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
166207
167208 stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
168209 read_stream , write_stream = stdio_transport
169- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
210+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
170211
171212 # Initialize the session
172213 await session .initialize ()
0 commit comments