diff --git a/pom.xml b/pom.xml
index 42a8068..33e7653 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
cn.bigmodel.openapi
oapi-java-sdk
- release-V4-2.4.4
+ release-V4-2.4.5
jar
diff --git a/src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java b/src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java
index c611585..0fae408 100644
--- a/src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java
+++ b/src/main/java/com/zhipu/oapi/service/v4/model/ChatTool.java
@@ -22,6 +22,8 @@ public class ChatTool extends ObjectNode {
@JsonProperty("web_search")
private WebSearch web_search;
+ private MCPTool mcp;
+
public ChatTool(){
super(JsonNodeFactory.instance);
}
@@ -48,4 +50,9 @@ public void setWeb_search(WebSearch web_search) {
this.web_search = web_search;
this.putPOJO("web_search",web_search);
}
+
+ public void setMcp(MCPTool mcp) {
+ this.mcp = mcp;
+ this.putPOJO("mcp", mcp);
+ }
}
diff --git a/src/main/java/com/zhipu/oapi/service/v4/model/ChatToolType.java b/src/main/java/com/zhipu/oapi/service/v4/model/ChatToolType.java
index 2fef94b..863522b 100644
--- a/src/main/java/com/zhipu/oapi/service/v4/model/ChatToolType.java
+++ b/src/main/java/com/zhipu/oapi/service/v4/model/ChatToolType.java
@@ -6,7 +6,10 @@ public enum ChatToolType {
RETRIEVAL("retrieval"),
- FUNCTION("function");
+ FUNCTION("function"),
+
+ MCP("mcp"),
+ ;
private final String value;
diff --git a/src/main/java/com/zhipu/oapi/service/v4/model/MCPTool.java b/src/main/java/com/zhipu/oapi/service/v4/model/MCPTool.java
new file mode 100644
index 0000000..a27ba3a
--- /dev/null
+++ b/src/main/java/com/zhipu/oapi/service/v4/model/MCPTool.java
@@ -0,0 +1,75 @@
+package com.zhipu.oapi.service.v4.model;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.node.JsonNodeFactory;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+
+import java.io.Serializable;
+import java.util.Map;
+import java.util.Set;
+
+@Getter
+public class MCPTool extends ObjectNode {
+ /**
+ * mcp server 的标识,用于区分不同的 mcp server,必填
+ */
+ private String server_label;
+
+ /**
+ * mcp server 的 url,非必填
+ * 默认(若该字段为空):以 server_label 作为 mcpCode,连接智谱AI的 mcp servers,
+ */
+ private String server_url;
+
+ /**
+ * mcp 调用的传输方式:sse/streamable-http,默认为 streamable-http
+ */
+ private String transport_type;
+
+ /**
+ * 允许调用的工具列表,默认为空,即允许所有工具
+ */
+ private Set allowed_tools;
+
+ /**
+ * 连接 mcp server 的 headers,鉴权使用
+ */
+ private Map headers;
+
+ public MCPTool() {
+ super(JsonNodeFactory.instance);
+ }
+
+ public MCPTool(JsonNodeFactory nc, Map kids) {
+ super(nc, kids);
+ }
+
+ public void setServer_label(String server_label) {
+ this.server_label = server_label;
+ this.put("server_label", server_label);
+ }
+
+ public void setServer_url(String server_url) {
+ this.server_url = server_url;
+ this.put("server_url", server_url);
+ }
+
+ public void setTransport_type(String transport_type) {
+ this.transport_type = transport_type;
+ this.put("transport_type", transport_type);
+ }
+
+ public void setAllowed_tools(Set allowed_tools) {
+ this.allowed_tools = allowed_tools;
+ this.putPOJO("allowed_tools", allowed_tools);
+ }
+
+ public void setHeaders(Map headers) {
+ this.headers = headers;
+ this.putPOJO("headers", headers);
+ }
+}
diff --git a/src/main/java/com/zhipu/oapi/service/v4/model/McpToolTransportType.java b/src/main/java/com/zhipu/oapi/service/v4/model/McpToolTransportType.java
new file mode 100644
index 0000000..d8f8cff
--- /dev/null
+++ b/src/main/java/com/zhipu/oapi/service/v4/model/McpToolTransportType.java
@@ -0,0 +1,16 @@
+package com.zhipu.oapi.service.v4.model;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+@AllArgsConstructor
+@Getter
+public enum McpToolTransportType {
+
+ SSE("sse", "SSE"),
+ STREAMABLE_HTTP("streamable-http", "可流式传输的HTTP");
+
+ private final String code;
+ private final String value;
+
+}
diff --git a/src/test/java/com/zhipu/oapi/McpTest.java b/src/test/java/com/zhipu/oapi/McpTest.java
new file mode 100644
index 0000000..33ebcf4
--- /dev/null
+++ b/src/test/java/com/zhipu/oapi/McpTest.java
@@ -0,0 +1,156 @@
+package com.zhipu.oapi;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.zhipu.oapi.service.v4.model.*;
+import com.zhipu.oapi.utils.StringUtils;
+import io.reactivex.Flowable;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testcontainers.junit.jupiter.Testcontainers;
+
+import java.util.*;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+@Testcontainers
+public class McpTest {
+
+ private final static Logger logger = LoggerFactory.getLogger(TestAssistantClientApiService.class);
+ private static final String ZHIPUAI_API_KEY = getTestApiKey();
+
+ private static ClientV4 client = null;
+
+ private static final String requestIdTemplate = "mycompany-%d";
+
+ private static final ObjectMapper mapper = new ObjectMapper();
+
+ static {
+ client = new ClientV4.Builder(ZHIPUAI_API_KEY)
+ .enableTokenCache()
+ .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
+ .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
+ .build();
+ }
+
+ private static String getTestApiKey() {
+ String apiKey = Constants.getApiKey();
+ return apiKey != null ? apiKey : "test-api-key.test-api-secret";
+ }
+
+ @Test
+ void testMcpTool_ServerUrl_SSE() throws JsonProcessingException {
+ // MCP 参数构建部分
+ Map headers = new HashMap<>();
+ headers.put("Authorization", "Bearer" + ZHIPUAI_API_KEY);
+ MCPTool mcpTool = new MCPTool();
+ mcpTool.setServer_label("sougou_search");
+ mcpTool.setServer_url("https://open.bigmodel.cn/api/mcp/sogou/sse");
+ mcpTool.setTransport_type(McpToolTransportType.SSE.getCode());
+ mcpTool.setHeaders(headers);
+
+ ChatTool chatTool = new ChatTool();
+ chatTool.setType(ChatToolType.MCP.value());
+ chatTool.setMcp(mcpTool);
+
+ List messages = new ArrayList<>();
+ ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "今天是几月几号?");
+
+ messages.add(chatMessage);
+ String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
+ ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
+ .model(Constants.ModelChatGLM4)
+ .stream(Boolean.FALSE)
+ .messages(messages)
+ .requestId(requestId)
+ .invokeMethod(Constants.invokeMethod)
+ .tools(Collections.singletonList(chatTool))
+ .build();
+ ModelApiResponse modelApiResp = client.invokeModelApi(chatCompletionRequest);
+ logger.info("model output: {}", mapper.writeValueAsString(modelApiResp));
+ }
+
+ @Test
+ void testMcpTool_ServerLabel() throws JsonProcessingException {
+ // MCP 参数构建部分
+ Map headers = new HashMap<>();
+ headers.put("Authorization", "Bearer " + ZHIPUAI_API_KEY);
+ MCPTool mcpTool = new MCPTool();
+ mcpTool.setServer_label("aviation");
+ mcpTool.setHeaders(headers);
+
+ ChatTool chatTool = new ChatTool();
+ chatTool.setType(ChatToolType.MCP.value());
+ chatTool.setMcp(mcpTool);
+
+ List messages = new ArrayList<>();
+ ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "北京现在天气怎么样?");
+
+ messages.add(chatMessage);
+ String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
+ ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
+ .model(Constants.ModelChatGLM4)
+ .stream(Boolean.TRUE)
+ .messages(messages)
+ .requestId(requestId)
+ .invokeMethod(Constants.invokeMethod)
+ .tools(Collections.singletonList(chatTool))
+ .build();
+ ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
+ if (sseModelApiResp.isSuccess()) {
+ AtomicBoolean isFirst = new AtomicBoolean(true);
+ List choices = new ArrayList<>();
+ AtomicReference lastAccumulator = new AtomicReference<>();
+
+ mapStreamToAccumulator(sseModelApiResp.getFlowable())
+ .doOnNext(accumulator -> {
+ {
+ if (isFirst.getAndSet(false)) {
+ logger.info("Response: ");
+ }
+ if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
+ String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
+ logger.info("tool_calls: {}", jsonString);
+ }
+ if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
+ logger.info(accumulator.getDelta().getContent());
+ }
+ choices.add(accumulator.getChoice());
+ lastAccumulator.set(accumulator);
+
+ }
+ })
+ .doOnComplete(() -> System.out.println("Stream completed."))
+ .doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
+ .blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
+
+ ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get();
+ ModelData data = new ModelData();
+ data.setChoices(choices);
+ if (chatMessageAccumulator != null) {
+ data.setUsage(chatMessageAccumulator.getUsage());
+ data.setId(chatMessageAccumulator.getId());
+ data.setCreated(chatMessageAccumulator.getCreated());
+ }
+ data.setRequestId(chatCompletionRequest.getRequestId());
+ sseModelApiResp.setFlowable(null);// 打印前置空
+ sseModelApiResp.setData(data);
+ }
+ logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
+ client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
+
+ client.getConfig().getHttpClient().connectionPool().evictAll();
+ // List all active threads
+ for (Thread t : Thread.getAllStackTraces().keySet()) {
+ logger.info("Thread: " + t.getName() + " State: " + t.getState());
+ }
+ }
+
+ public static Flowable mapStreamToAccumulator(Flowable flowable) {
+ return flowable.map(chunk -> {
+ return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
+ });
+ }
+}