diff --git a/libraries/core/src/conversation/ConversationService/ConversationService.test.ts b/libraries/core/src/conversation/ConversationService/ConversationService.test.ts index d8df2e92738..47d09bf2528 100644 --- a/libraries/core/src/conversation/ConversationService/ConversationService.test.ts +++ b/libraries/core/src/conversation/ConversationService/ConversationService.test.ts @@ -679,6 +679,90 @@ describe('ConversationService', () => { expect(conversationService.joinByExternalCommit).not.toHaveBeenCalled(); expect(establishedConversation.epoch).toEqual(updatedEpoch); }); + + it('allows establishing cross-domain MLS 1:1 conversations', async () => { + const [conversationService, {apiClient, mlsService}] = await buildConversationService(); + + const mockConversationId = {id: 'mock-conversation-id', domain: 'remote.wire.com'}; + const mockGroupId = 'mock-group-id'; + + const selfUser = {user: {id: 'self-user-id', domain: 'local.wire.com'}, client: 'self-user-client-id'}; + const otherUserId = {id: 'other-user-id', domain: 'remote.wire.com'}; + + const remoteEpoch = 0; + const updatedEpoch = 1; + + jest.spyOn(apiClient.api.conversation, 'getMLS1to1Conversation').mockResolvedValueOnce({ + qualified_id: mockConversationId, + protocol: CONVERSATION_PROTOCOL.MLS, + epoch: remoteEpoch, + group_id: mockGroupId, + } as unknown as MLSConversation); + + jest.spyOn(apiClient.api.conversation, 'getMLS1to1Conversation').mockResolvedValueOnce({ + qualified_id: mockConversationId, + protocol: CONVERSATION_PROTOCOL.MLS, + epoch: updatedEpoch, + group_id: mockGroupId, + } as unknown as MLSConversation); + + jest.spyOn(mlsService, 'wipeConversation'); + + const establishedConversation = await conversationService.establishMLS1to1Conversation( + mockGroupId, + selfUser, + otherUserId, + ); + + expect(mlsService.register1to1Conversation).toHaveBeenCalledTimes(1); + expect(mlsService.register1to1Conversation).toHaveBeenCalledWith(mockGroupId, otherUserId, selfUser, undefined); + expect(conversationService.joinByExternalCommit).not.toHaveBeenCalled(); + expect(establishedConversation.epoch).toEqual(updatedEpoch); + }); + }); + + describe('domain mismatch guards', () => { + it('throws when establishing MLS group conversation with mismatched self and conversation domains', async () => { + const [conversationService, {mlsService}] = await buildConversationService(); + + const groupId = 'group-domain-mismatch-establish'; + const selfUserId = {id: 'self-user-id', domain: 'local.wire.com'}; + const conversationQualifiedId = {id: PayloadHelper.getUUID(), domain: 'staging.zinfra.io'}; + + await expect( + conversationService.establishMLSGroupConversation( + groupId, + [], + selfUserId, + 'self-client-id', + conversationQualifiedId, + ), + ).rejects.toThrow('does not match conversation domain'); + + expect(mlsService.registerConversation).not.toHaveBeenCalled(); + }); + + it('throws when resetting MLS conversation if self user domain mismatches conversation domain', async () => { + const [conversationService, {apiClient}] = await buildConversationService(); + + const conversationId = {id: PayloadHelper.getUUID(), domain: 'staging.zinfra.io'}; + jest.spyOn(apiClient, 'domain', 'get').mockReturnValue('staging.zinfra.io'); + + jest.spyOn(apiClient.api.conversation, 'getConversation').mockResolvedValueOnce({ + qualified_id: {id: conversationId.id, domain: 'local.wire.com'}, + protocol: CONVERSATION_PROTOCOL.MLS, + epoch: 1, + group_id: 'group-domain-mismatch-reset', + } as unknown as Conversation); + + const resetSpy = jest.spyOn(apiClient.api.conversation, 'resetMLSConversation'); + + await expect((conversationService as any).resetMLSConversation(conversationId)).rejects.toThrow( + 'does not match conversation domain', + ); + + expect(resetSpy).not.toHaveBeenCalled(); + }); }); describe('handleEvent', () => { @@ -1075,6 +1159,31 @@ describe('ConversationService', () => { expect(conversationService.addUsersToMLSConversation).not.toHaveBeenCalled(); }); + + it('throws when self user domain does not match conversation domain', async () => { + const [conversationService, {mlsService}] = await buildConversationService(); + const selfUserId = {id: 'self-user-id', domain: 'local.wire.com'}; + + const mockConversationId = {id: PayloadHelper.getUUID(), domain: 'staging.zinfra.io'}; + const mockGroupId = 'groupId'; + const otherUsersToAdd = Array(3) + .fill(0) + .map(() => ({id: PayloadHelper.getUUID(), domain: 'local.wire.com'})); + + const addUsersSpy = jest.spyOn(conversationService, 'addUsersToMLSConversation'); + + await expect( + conversationService.tryEstablishingMLSGroup({ + conversationId: mockConversationId, + groupId: mockGroupId, + qualifiedUsers: otherUsersToAdd, + selfUserId, + }), + ).rejects.toThrow('does not match conversation domain'); + + expect(mlsService.tryEstablishingMLSGroup).not.toHaveBeenCalled(); + expect(addUsersSpy).not.toHaveBeenCalled(); + }); }); describe('reactToKeyMaterialUpdateFailure', () => { diff --git a/libraries/core/src/conversation/ConversationService/ConversationService.ts b/libraries/core/src/conversation/ConversationService/ConversationService.ts index 4f701945ec4..a59464ffce5 100644 --- a/libraries/core/src/conversation/ConversationService/ConversationService.ts +++ b/libraries/core/src/conversation/ConversationService/ConversationService.ts @@ -138,6 +138,16 @@ export class ConversationService extends TypedEventEmitter { return this._mlsService; } + private validateDomainMatch(selfUserDomain: string, conversationDomain: string): void { + if (selfUserDomain === conversationDomain) { + return; + } + + const errorMessage = `Self user domain (${selfUserDomain}) does not match conversation domain (${conversationDomain})`; + this.logger.error(errorMessage); + throw new Error(errorMessage); + } + /** * Get a fresh list from backend of clients for all the participants of the conversation. * @fixme there are some case where this method is not enough to detect removed devices @@ -394,6 +404,8 @@ export class ConversationService extends TypedEventEmitter { selfClientId: string; conversationQualifiedId: QualifiedId; }): Promise { + this.validateDomainMatch(selfUserId.domain, conversationQualifiedId.domain); + const failures = await this.mlsService.registerConversation(groupId, userIdsToAdd.concat(selfUserId), { creator: { user: selfUserId, @@ -619,9 +631,27 @@ export class ConversationService extends TypedEventEmitter { private async resetMLSConversation(conversationId: QualifiedId): Promise { this.logger.info(`Resetting MLS conversation with id ${conversationId.id}`); - // STEP 1: Fetch the conversation to retrieve the group ID & epoch + // STEP 1: fetch self user info + this.logger.info( + `Re-establishing the conversation by re-adding all members (conversation_id: ${conversationId.id})`, + ); + const {validatedClientId: clientId, userId, domain: selfUserDomain} = this.apiClient; + + if (!selfUserDomain) { + const errorMessage = 'Could not find domain of the self user'; + this.logger.error(errorMessage, {conversationId}); + throw new Error(errorMessage); + } + + // STEP 2: Fetch the conversation to retrieve the group ID & epoch const conversation = await this.apiClient.api.conversation.getConversation(conversationId); - const {group_id: groupId, epoch} = conversation; + const { + group_id: groupId, + epoch, + qualified_id: {domain: conversationDomain}, + } = conversation; + + this.validateDomainMatch(selfUserDomain, conversationDomain); if (!groupId || !epoch) { const errorMessage = 'Could not find group id or epoch for the conversation'; @@ -629,26 +659,20 @@ export class ConversationService extends TypedEventEmitter { throw new Error(errorMessage); } - // STEP 2: Request backend to reset the conversation + // STEP 3: Request backend to reset the conversation this.logger.info(`Requesting backend to reset the conversation (group_id: ${groupId}, epoch: ${String(epoch)})`); await this.apiClient.api.conversation.resetMLSConversation({ epoch, groupId, }); - // STEP 3: fetch self user info - this.logger.info( - `Re-establishing the conversation by re-adding all members (conversation_id: ${conversationId.id})`, - ); - const {validatedClientId: clientId, userId, domain} = this.apiClient; - - if (!userId || !domain) { + if (!userId || !selfUserDomain) { const errorMessage = 'Could not find userId or domain of the self user'; this.logger.error(errorMessage, {conversationId}); throw new Error(errorMessage); } - const selfUserQualifiedId = {id: userId, domain}; + const selfUserQualifiedId = {id: userId, domain: selfUserDomain}; // STEP 4: Fetch the updated conversation data from backend to retrieve the new group ID const updatedConversation = await this.apiClient.api.conversation.getConversation(conversationId); @@ -912,6 +936,8 @@ export class ConversationService extends TypedEventEmitter { qualifiedUsers: QualifiedId[]; }): Promise { try { + this.validateDomainMatch(selfUserId.domain, conversationId.domain); + const wasGroupEstablishedBySelfClient = await this.mlsService.tryEstablishingMLSGroup(groupId); if (!wasGroupEstablishedBySelfClient) {