Skip to content

Commit cf16dcd

Browse files
committed
fix: fetch protocol details per ID to avoid incorrect data when retrieving protocol list from protocols.io API
1 parent fbbdc93 commit cf16dcd

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

src/protocols_io_mcp/tools/protocol.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async def to_string(step: "ProtocolStepInput") -> str:
4848
response_get_protocol = await helpers.access_protocols_io_resource("GET", f"/v4/protocols/{protocol_id}")
4949
if response_get_protocol["status_code"] != 0:
5050
return response_get_protocol["status_text"]
51-
protocol = Protocol.from_api_response(response_get_protocol["payload"])
51+
protocol = await Protocol.from_protocol_id(response_get_protocol["payload"]["id"])
5252
step_content += f"- {protocol.title}[{protocol.id}] {protocol.doi}\n"
5353
return step_content
5454

@@ -118,15 +118,17 @@ class Protocol(BaseModel):
118118
published_on: Annotated[datetime | None, Field(description="Date and time the protocol was published, if the protocol is private, this will be null")] = None
119119

120120
@classmethod
121-
def from_api_response(cls, data: dict) -> "Protocol":
121+
async def from_protocol_id(cls, protocol_id: int) -> "Protocol":
122+
response = await helpers.access_protocols_io_resource("GET", f"/v4/protocols/{protocol_id}?content_format=markdown")
123+
protocol = response["payload"]
122124
return cls(
123-
id=data["id"],
124-
title=data["title"],
125-
description=data.get("description") or "",
126-
doi=data.get("doi") or None,
127-
url=data["url"],
128-
created_on=datetime.fromtimestamp(data.get("created_on"), tz=timezone.utc),
129-
published_on=datetime.fromtimestamp(data.get("published_on"), tz=timezone.utc) if data.get("published_on") else None,
125+
id=protocol_id,
126+
title=protocol["title"],
127+
description=protocol.get("description") or "",
128+
doi=protocol.get("doi") or None,
129+
url=protocol.get("url"),
130+
created_on=datetime.fromtimestamp(protocol.get("created_on"), tz=timezone.utc),
131+
published_on=datetime.fromtimestamp(protocol.get("published_on"), tz=timezone.utc) if protocol.get("published_on") else None
130132
)
131133

132134
class ProtocolSearchResult(BaseModel):
@@ -135,8 +137,8 @@ class ProtocolSearchResult(BaseModel):
135137
total_pages: Annotated[int, Field(description="Total number of pages available for the search results")]
136138

137139
@classmethod
138-
def from_api_response(cls, data: dict) -> "ProtocolSearchResult":
139-
protocols = [Protocol.from_api_response(protocol) for protocol in data["items"]]
140+
async def from_api_response(cls, data: dict) -> "ProtocolSearchResult":
141+
protocols = [await Protocol.from_protocol_id(protocol["id"]) for protocol in data["items"]]
140142
return cls(
141143
protocols=protocols,
142144
current_page=data["pagination"]["current_page"],
@@ -156,10 +158,11 @@ async def search_public_protocols(
156158
page: Annotated[int, Field(description="Page number for pagination, starting from 1")] = 1,
157159
) -> ProtocolSearchResult | ErrorMessage:
158160
"""Search for public protocols on protocols.io using a keyword."""
161+
page = page - 1 # weird bug in protocols.io API where it returns page 2 if page 1 is requested
159162
response = await helpers.access_protocols_io_resource("GET", f"/v3/protocols?filter=public&key={keyword}&page_size=3&page_id={page}")
160163
if response["status_code"] != 0:
161164
return ErrorMessage.from_string(response["error_message"])
162-
search_result = ProtocolSearchResult.from_api_response(response)
165+
search_result = await ProtocolSearchResult.from_api_response(response)
163166
return search_result
164167

165168
@mcp.tool()
@@ -172,7 +175,7 @@ async def get_my_protocols() -> list[Protocol] | ErrorMessage:
172175
response = await helpers.access_protocols_io_resource("GET", f"/v3/researchers/{user.username}/protocols?filter=user_all")
173176
if response["status_code"] != 0:
174177
return ErrorMessage.from_api_response(response["error_message"])
175-
protocols = [Protocol.from_api_response(protocol) for protocol in response.get("items")]
178+
protocols = [await Protocol.from_protocol_id(protocol["id"]) for protocol in response.get("items")]
176179
return protocols
177180

178181
@mcp.tool()
@@ -183,7 +186,7 @@ async def get_protocol(
183186
response = await helpers.access_protocols_io_resource("GET", f"/v4/protocols/{protocol_id}")
184187
if response["status_code"] != 0:
185188
return ErrorMessage.from_string(response["status_text"])
186-
protocol = Protocol.from_api_response(response["payload"])
189+
protocol = await Protocol.from_protocol_id(response["payload"]["id"])
187190
return protocol
188191

189192
@mcp.tool()
@@ -206,15 +209,15 @@ async def create_protocol(
206209
response_create_blank_protocol = await helpers.access_protocols_io_resource("POST", f"/v3/protocols/{uuid.uuid4().hex}", {"type_id": 1})
207210
if response_create_blank_protocol["status_code"] != 0:
208211
return ErrorMessage.from_string(response_create_blank_protocol["status_text"])
209-
protocol = Protocol.from_api_response(response_create_blank_protocol["protocol"])
212+
protocol = await Protocol.from_protocol_id(response_create_blank_protocol["protocol"]["id"])
210213
data = {"title": title, "description": description}
211214
response_update_protocol = await helpers.access_protocols_io_resource("PUT", f"/v4/protocols/{protocol.id}", data)
212215
if response_update_protocol["status_code"] != 0:
213216
return ErrorMessage.from_string(response_update_protocol["status_text"])
214217
response_get_protocol = await helpers.access_protocols_io_resource("GET", f"/v4/protocols/{protocol.id}")
215218
if response_get_protocol["status_code"] != 0:
216219
return ErrorMessage.from_string(response_get_protocol["status_text"])
217-
protocol = Protocol.from_api_response(response_get_protocol["payload"])
220+
protocol = await Protocol.from_protocol_id(response_get_protocol["payload"]["id"])
218221
return protocol
219222

220223
@mcp.tool()
@@ -230,7 +233,7 @@ async def update_protocol_title(
230233
response_get_protocol = await helpers.access_protocols_io_resource("GET", f"/v4/protocols/{protocol_id}")
231234
if response_get_protocol["status_code"] != 0:
232235
return ErrorMessage.from_string(response_get_protocol["status_text"])
233-
protocol = Protocol.from_api_response(response_get_protocol["payload"])
236+
protocol = await Protocol.from_protocol_id(response_get_protocol["payload"]["id"])
234237
return protocol
235238

236239
@mcp.tool()
@@ -246,7 +249,7 @@ async def update_protocol_description(
246249
response_get_protocol = await helpers.access_protocols_io_resource("GET", f"/v4/protocols/{protocol_id}")
247250
if response_get_protocol["status_code"] != 0:
248251
return ErrorMessage.from_string(response_get_protocol["status_text"])
249-
protocol = Protocol.from_api_response(response_get_protocol["payload"])
252+
protocol = await Protocol.from_protocol_id(response_get_protocol["payload"]["id"])
250253
return protocol
251254

252255
@mcp.tool()

0 commit comments

Comments
 (0)