diff --git a/pom.xml b/pom.xml index 42a8068..33e7653 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ cn.bigmodel.openapi oapi-java-sdk - release-V4-2.4.4 + release-V4-2.4.5 jar diff --git a/src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java b/src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java index c611585..0fae408 100644 --- a/src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java +++ b/src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java @@ -22,6 +22,8 @@ public class ChatTool extends ObjectNode { @JsonProperty("web_search") private WebSearch web_search; + private MCPTool mcp; + public ChatTool(){ super(JsonNodeFactory.instance); } @@ -48,4 +50,9 @@ public void setWeb_search(WebSearch web_search) { this.web_search = web_search; this.putPOJO("web_search",web_search); } + + public void setMcp(MCPTool mcp) { + this.mcp = mcp; + this.putPOJO("mcp", mcp); + } } diff --git a/src/main/java/com/zhipu/oapi/service/v4/model/ChatToolType.java b/src/main/java/com/zhipu/oapi/service/v4/model/ChatToolType.java index 2fef94b..863522b 100644 --- a/src/main/java/com/zhipu/oapi/service/v4/model/ChatToolType.java +++ b/src/main/java/com/zhipu/oapi/service/v4/model/ChatToolType.java @@ -6,7 +6,10 @@ public enum ChatToolType { RETRIEVAL("retrieval"), - FUNCTION("function"); + FUNCTION("function"), + + MCP("mcp"), + ; private final String value; diff --git a/src/main/java/com/zhipu/oapi/service/v4/model/MCPTool.java b/src/main/java/com/zhipu/oapi/service/v4/model/MCPTool.java new file mode 100644 index 0000000..a27ba3a --- /dev/null +++ b/src/main/java/com/zhipu/oapi/service/v4/model/MCPTool.java @@ -0,0 +1,75 @@ +package com.zhipu.oapi.service.v4.model; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.Getter; +import lombok.NoArgsConstructor; + +import java.io.Serializable; +import java.util.Map; +import java.util.Set; + +@Getter +public class MCPTool extends ObjectNode { + /** + * mcp server 的标识,用于区分不同的 mcp server,必填 + */ + private String server_label; + + /** + * mcp server 的 url,非必填 + * 默认(若该字段为空):以 server_label 作为 mcpCode,连接智谱AI的 mcp servers, + */ + private String server_url; + + /** + * mcp 调用的传输方式:sse/streamable-http,默认为 streamable-http + */ + private String transport_type; + + /** + * 允许调用的工具列表,默认为空,即允许所有工具 + */ + private Set allowed_tools; + + /** + * 连接 mcp server 的 headers,鉴权使用 + */ + private Map headers; + + public MCPTool() { + super(JsonNodeFactory.instance); + } + + public MCPTool(JsonNodeFactory nc, Map kids) { + super(nc, kids); + } + + public void setServer_label(String server_label) { + this.server_label = server_label; + this.put("server_label", server_label); + } + + public void setServer_url(String server_url) { + this.server_url = server_url; + this.put("server_url", server_url); + } + + public void setTransport_type(String transport_type) { + this.transport_type = transport_type; + this.put("transport_type", transport_type); + } + + public void setAllowed_tools(Set allowed_tools) { + this.allowed_tools = allowed_tools; + this.putPOJO("allowed_tools", allowed_tools); + } + + public void setHeaders(Map headers) { + this.headers = headers; + this.putPOJO("headers", headers); + } +} diff --git a/src/main/java/com/zhipu/oapi/service/v4/model/McpToolTransportType.java b/src/main/java/com/zhipu/oapi/service/v4/model/McpToolTransportType.java new file mode 100644 index 0000000..d8f8cff --- /dev/null +++ b/src/main/java/com/zhipu/oapi/service/v4/model/McpToolTransportType.java @@ -0,0 +1,16 @@ +package com.zhipu.oapi.service.v4.model; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@AllArgsConstructor +@Getter +public enum McpToolTransportType { + + SSE("sse", "SSE"), + STREAMABLE_HTTP("streamable-http", "可流式传输的HTTP"); + + private final String code; + private final String value; + +} diff --git a/src/test/java/com/zhipu/oapi/McpTest.java b/src/test/java/com/zhipu/oapi/McpTest.java new file mode 100644 index 0000000..33ebcf4 --- /dev/null +++ b/src/test/java/com/zhipu/oapi/McpTest.java @@ -0,0 +1,156 @@ +package com.zhipu.oapi; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.zhipu.oapi.service.v4.model.*; +import com.zhipu.oapi.utils.StringUtils; +import io.reactivex.Flowable; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.*; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +@Testcontainers +public class McpTest { + + private final static Logger logger = LoggerFactory.getLogger(TestAssistantClientApiService.class); + private static final String ZHIPUAI_API_KEY = getTestApiKey(); + + private static ClientV4 client = null; + + private static final String requestIdTemplate = "mycompany-%d"; + + private static final ObjectMapper mapper = new ObjectMapper(); + + static { + client = new ClientV4.Builder(ZHIPUAI_API_KEY) + .enableTokenCache() + .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS) + .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS)) + .build(); + } + + private static String getTestApiKey() { + String apiKey = Constants.getApiKey(); + return apiKey != null ? apiKey : "test-api-key.test-api-secret"; + } + + @Test + void testMcpTool_ServerUrl_SSE() throws JsonProcessingException { + // MCP 参数构建部分 + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer" + ZHIPUAI_API_KEY); + MCPTool mcpTool = new MCPTool(); + mcpTool.setServer_label("sougou_search"); + mcpTool.setServer_url("https://open.bigmodel.cn/api/mcp/sogou/sse"); + mcpTool.setTransport_type(McpToolTransportType.SSE.getCode()); + mcpTool.setHeaders(headers); + + ChatTool chatTool = new ChatTool(); + chatTool.setType(ChatToolType.MCP.value()); + chatTool.setMcp(mcpTool); + + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "今天是几月几号?"); + + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelChatGLM4) + .stream(Boolean.FALSE) + .messages(messages) + .requestId(requestId) + .invokeMethod(Constants.invokeMethod) + .tools(Collections.singletonList(chatTool)) + .build(); + ModelApiResponse modelApiResp = client.invokeModelApi(chatCompletionRequest); + logger.info("model output: {}", mapper.writeValueAsString(modelApiResp)); + } + + @Test + void testMcpTool_ServerLabel() throws JsonProcessingException { + // MCP 参数构建部分 + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + ZHIPUAI_API_KEY); + MCPTool mcpTool = new MCPTool(); + mcpTool.setServer_label("aviation"); + mcpTool.setHeaders(headers); + + ChatTool chatTool = new ChatTool(); + chatTool.setType(ChatToolType.MCP.value()); + chatTool.setMcp(mcpTool); + + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "北京现在天气怎么样?"); + + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelChatGLM4) + .stream(Boolean.TRUE) + .messages(messages) + .requestId(requestId) + .invokeMethod(Constants.invokeMethod) + .tools(Collections.singletonList(chatTool)) + .build(); + ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest); + if (sseModelApiResp.isSuccess()) { + AtomicBoolean isFirst = new AtomicBoolean(true); + List choices = new ArrayList<>(); + AtomicReference lastAccumulator = new AtomicReference<>(); + + mapStreamToAccumulator(sseModelApiResp.getFlowable()) + .doOnNext(accumulator -> { + { + if (isFirst.getAndSet(false)) { + logger.info("Response: "); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) { + String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls()); + logger.info("tool_calls: {}", jsonString); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) { + logger.info(accumulator.getDelta().getContent()); + } + choices.add(accumulator.getChoice()); + lastAccumulator.set(accumulator); + + } + }) + .doOnComplete(() -> System.out.println("Stream completed.")) + .doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors + .blockingSubscribe();// Use blockingSubscribe instead of blockingGet() + + ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get(); + ModelData data = new ModelData(); + data.setChoices(choices); + if (chatMessageAccumulator != null) { + data.setUsage(chatMessageAccumulator.getUsage()); + data.setId(chatMessageAccumulator.getId()); + data.setCreated(chatMessageAccumulator.getCreated()); + } + data.setRequestId(chatCompletionRequest.getRequestId()); + sseModelApiResp.setFlowable(null);// 打印前置空 + sseModelApiResp.setData(data); + } + logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp)); + client.getConfig().getHttpClient().dispatcher().executorService().shutdown(); + + client.getConfig().getHttpClient().connectionPool().evictAll(); + // List all active threads + for (Thread t : Thread.getAllStackTraces().keySet()) { + logger.info("Thread: " + t.getName() + " State: " + t.getState()); + } + } + + public static Flowable mapStreamToAccumulator(Flowable flowable) { + return flowable.map(chunk -> { + return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId()); + }); + } +}