diff --git a/Sources/OpenAI/Private/Streaming/ModelResponseEventsStreamInterpreter.swift b/Sources/OpenAI/Private/Streaming/ModelResponseEventsStreamInterpreter.swift index 559073c5..3b472b97 100644 --- a/Sources/OpenAI/Private/Streaming/ModelResponseEventsStreamInterpreter.swift +++ b/Sources/OpenAI/Private/Streaming/ModelResponseEventsStreamInterpreter.swift @@ -55,25 +55,24 @@ final class ModelResponseEventsStreamInterpreter: @unchecked Sendable, StreamInt } private func processEvent(_ event: ServerSentEventsStreamParser.Event) throws { - var finalEvent = event - if event.eventType == "response.output_text.annotation.added" { - // Remove when they have fixed (unified)! - // - // By looking at [API Reference](https://platform.openai.com/docs/api-reference/responses-streaming/response/output_text_annotation/added) - // and generated type `Schemas.ResponseOutputTextAnnotationAddedEvent` - // We can see that "output_text.annotation" is incorrect, whereas output_text_annotation is the correct one - let fixedDataString = event.decodedData.replacingOccurrences(of: "response.output_text.annotation.added", with: "response.output_text_annotation.added") - finalEvent = .init(id: event.id, data: fixedDataString.data(using: .utf8) ?? event.data, decodedData: fixedDataString, eventType: "response.output_text_annotation.added", retry: event.retry) + let finalEvent = event.fixMappingError() + var eventType = finalEvent.eventType + + /// If the SSE `event` property is not specified by the provider service, our parser defaults it to "message" which is not a valid model response type. + /// In this case we check the `data.type` property for a valid model response type. + if eventType == "message" || eventType.isEmpty, + let payloadEventType = finalEvent.getPayloadType() { + eventType = payloadEventType } - - guard let modelResponseEventType = ModelResponseStreamEventType(rawValue: finalEvent.eventType) else { - throw InterpreterError.unknownEventType(finalEvent.eventType) + + guard let modelResponseEventType = ModelResponseStreamEventType(rawValue: eventType) else { + throw InterpreterError.unknownEventType(eventType) } let responseStreamEvent = try responseStreamEvent(modelResponseEventType: modelResponseEventType, data: finalEvent.data) onEventDispatched?(responseStreamEvent) } - + private func processError(_ error: Error) { onError?(error) } @@ -210,3 +209,35 @@ final class ModelResponseEventsStreamInterpreter: @unchecked Sendable, StreamInt try decoder.decode(T.self, from: data) } } + +private extension ServerSentEventsStreamParser.Event { + + // Remove when they have fixed (unified)! + // + // By looking at [API Reference](https://platform.openai.com/docs/api-reference/responses-streaming/response/output_text_annotation/added) + // and generated type `Schemas.ResponseOutputTextAnnotationAddedEvent` + // We can see that "output_text.annotation" is incorrect, whereas output_text_annotation is the correct one + func fixMappingError() -> Self { + let incorrectEventType = "response.output_text.annotation.added" + let correctEventType = "response.output_text_annotation.added" + + guard self.eventType == incorrectEventType || self.getPayloadType() == incorrectEventType else { + return self + } + + let fixedDataString = self.decodedData.replacingOccurrences(of: incorrectEventType, with: correctEventType) + return .init( + id: self.id, + data: fixedDataString.data(using: .utf8) ?? self.data, + decodedData: fixedDataString, + eventType: correctEventType, + retry: self.retry + ) + } + + struct TypeEnvelope: Decodable { let type: String } + + func getPayloadType() -> String? { + try? JSONDecoder().decode(TypeEnvelope.self, from: self.data).type + } +} diff --git a/Tests/OpenAITests/MockServerSentEvent.swift b/Tests/OpenAITests/MockServerSentEvent.swift index 39cea945..f09a183e 100644 --- a/Tests/OpenAITests/MockServerSentEvent.swift +++ b/Tests/OpenAITests/MockServerSentEvent.swift @@ -20,4 +20,18 @@ struct MockServerSentEvent { static func chatCompletionError() -> Data { "{\n \"error\": {\n \"message\": \"The model `o3-mini` does not exist or you do not have access to it.\",\n \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": \"model_not_found\"\n }\n}\n".data(using: .utf8)! } + + static func responseStreamEvent( + itemId: String = "msg_1", + payloadType: String, + outputIndex: Int = 0, + contentIndex: Int = 0, + delta: String = "", + sequenceNumber: Int = 1 + ) -> Data { + let json = """ + {"type":"\(payloadType)","output_index":\(outputIndex),"item_id":"\(itemId)","content_index":\(contentIndex),"delta":"\(delta)","sequence_number":\(sequenceNumber)} + """ + return "data: \(json)\n\n".data(using: .utf8)! + } } diff --git a/Tests/OpenAITests/ModelResponseEventsStreamInterpreterTests.swift b/Tests/OpenAITests/ModelResponseEventsStreamInterpreterTests.swift index 9e3952af..37422ce1 100644 --- a/Tests/OpenAITests/ModelResponseEventsStreamInterpreterTests.swift +++ b/Tests/OpenAITests/ModelResponseEventsStreamInterpreterTests.swift @@ -39,4 +39,49 @@ final class ModelResponseEventsStreamInterpreterTests: XCTestCase { XCTAssertNotNil(receivedError, "Expected an error to be received, but got nil.") XCTAssertTrue(receivedError is APIErrorResponse, "Expected received error to be of type APIErrorResponse.") } + + func testParsesOutputTextDeltaUsingPayloadType() async throws { + let expectation = XCTestExpectation(description: "OutputText delta event received") + var receivedEvent: ResponseStreamEvent? + + interpreter.setCallbackClosures { event in + Task { + await MainActor.run { + receivedEvent = event + expectation.fulfill() + } + } + } onError: { error in + XCTFail("Unexpected error received: \(error)") + } + + interpreter.processData( + MockServerSentEvent.responseStreamEvent( + itemId: "msg_1", + payloadType: "response.output_text.delta", + outputIndex: 0, + contentIndex: 0, + delta: "Hi", + sequenceNumber: 1 + ) + ) + + await fulfillment(of: [expectation], timeout: 1.0) + + guard let receivedEvent else { + XCTFail("No event received") + return + } + + switch receivedEvent { + case .outputText(.delta(let deltaEvent)): + XCTAssertEqual(deltaEvent.itemId, "msg_1") + XCTAssertEqual(deltaEvent.outputIndex, 0) + XCTAssertEqual(deltaEvent.contentIndex, 0) + XCTAssertEqual(deltaEvent.delta, "Hi") + XCTAssertEqual(deltaEvent.sequenceNumber, 1) + default: + XCTFail("Expected .outputText(.delta), got \(receivedEvent)") + } + } }