Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,90 @@ describe('ConversationService', () => {
expect(conversationService.joinByExternalCommit).not.toHaveBeenCalled();
expect(establishedConversation.epoch).toEqual(updatedEpoch);
});

it('allows establishing cross-domain MLS 1:1 conversations', async () => {
const [conversationService, {apiClient, mlsService}] = await buildConversationService();

const mockConversationId = {id: 'mock-conversation-id', domain: 'remote.wire.com'};
const mockGroupId = 'mock-group-id';

const selfUser = {user: {id: 'self-user-id', domain: 'local.wire.com'}, client: 'self-user-client-id'};
const otherUserId = {id: 'other-user-id', domain: 'remote.wire.com'};

const remoteEpoch = 0;
const updatedEpoch = 1;

jest.spyOn(apiClient.api.conversation, 'getMLS1to1Conversation').mockResolvedValueOnce({
qualified_id: mockConversationId,
protocol: CONVERSATION_PROTOCOL.MLS,
epoch: remoteEpoch,
group_id: mockGroupId,
} as unknown as MLSConversation);

jest.spyOn(apiClient.api.conversation, 'getMLS1to1Conversation').mockResolvedValueOnce({
qualified_id: mockConversationId,
protocol: CONVERSATION_PROTOCOL.MLS,
epoch: updatedEpoch,
group_id: mockGroupId,
} as unknown as MLSConversation);

jest.spyOn(mlsService, 'wipeConversation');

const establishedConversation = await conversationService.establishMLS1to1Conversation(
mockGroupId,
selfUser,
otherUserId,
);

expect(mlsService.register1to1Conversation).toHaveBeenCalledTimes(1);
expect(mlsService.register1to1Conversation).toHaveBeenCalledWith(mockGroupId, otherUserId, selfUser, undefined);
expect(conversationService.joinByExternalCommit).not.toHaveBeenCalled();
expect(establishedConversation.epoch).toEqual(updatedEpoch);
});
});

describe('domain mismatch guards', () => {
it('throws when establishing MLS group conversation with mismatched self and conversation domains', async () => {
const [conversationService, {mlsService}] = await buildConversationService();

const groupId = 'group-domain-mismatch-establish';
const selfUserId = {id: 'self-user-id', domain: 'local.wire.com'};
const conversationQualifiedId = {id: PayloadHelper.getUUID(), domain: 'staging.zinfra.io'};

await expect(
conversationService.establishMLSGroupConversation(
groupId,
[],
selfUserId,
'self-client-id',
conversationQualifiedId,
),
).rejects.toThrow('does not match conversation domain');

expect(mlsService.registerConversation).not.toHaveBeenCalled();
});

it('throws when resetting MLS conversation if self user domain mismatches conversation domain', async () => {
const [conversationService, {apiClient}] = await buildConversationService();

const conversationId = {id: PayloadHelper.getUUID(), domain: 'staging.zinfra.io'};
jest.spyOn(apiClient, 'domain', 'get').mockReturnValue('staging.zinfra.io');

jest.spyOn(apiClient.api.conversation, 'getConversation').mockResolvedValueOnce({
qualified_id: {id: conversationId.id, domain: 'local.wire.com'},
protocol: CONVERSATION_PROTOCOL.MLS,
epoch: 1,
group_id: 'group-domain-mismatch-reset',
} as unknown as Conversation);

const resetSpy = jest.spyOn(apiClient.api.conversation, 'resetMLSConversation');

await expect((conversationService as any).resetMLSConversation(conversationId)).rejects.toThrow(
'does not match conversation domain',
);

expect(resetSpy).not.toHaveBeenCalled();
});
});

describe('handleEvent', () => {
Expand Down Expand Up @@ -1075,6 +1159,31 @@ describe('ConversationService', () => {

expect(conversationService.addUsersToMLSConversation).not.toHaveBeenCalled();
});

it('throws when self user domain does not match conversation domain', async () => {
const [conversationService, {mlsService}] = await buildConversationService();
const selfUserId = {id: 'self-user-id', domain: 'local.wire.com'};

const mockConversationId = {id: PayloadHelper.getUUID(), domain: 'staging.zinfra.io'};
const mockGroupId = 'groupId';
const otherUsersToAdd = Array(3)
.fill(0)
.map(() => ({id: PayloadHelper.getUUID(), domain: 'local.wire.com'}));

const addUsersSpy = jest.spyOn(conversationService, 'addUsersToMLSConversation');

await expect(
conversationService.tryEstablishingMLSGroup({
conversationId: mockConversationId,
groupId: mockGroupId,
qualifiedUsers: otherUsersToAdd,
selfUserId,
}),
).rejects.toThrow('does not match conversation domain');

expect(mlsService.tryEstablishingMLSGroup).not.toHaveBeenCalled();
expect(addUsersSpy).not.toHaveBeenCalled();
});
});

describe('reactToKeyMaterialUpdateFailure', () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ export class ConversationService extends TypedEventEmitter<Events> {
return this._mlsService;
}

private validateDomainMatch(selfUserDomain: string, conversationDomain: string): void {
if (selfUserDomain === conversationDomain) {
return;
}

const errorMessage = `Self user domain (${selfUserDomain}) does not match conversation domain (${conversationDomain})`;
this.logger.error(errorMessage);
throw new Error(errorMessage);
}

/**
* Get a fresh list from backend of clients for all the participants of the conversation.
* @fixme there are some case where this method is not enough to detect removed devices
Expand Down Expand Up @@ -394,6 +404,8 @@ export class ConversationService extends TypedEventEmitter<Events> {
selfClientId: string;
conversationQualifiedId: QualifiedId;
}): Promise<BaseCreateConversationResponse> {
this.validateDomainMatch(selfUserId.domain, conversationQualifiedId.domain);

const failures = await this.mlsService.registerConversation(groupId, userIdsToAdd.concat(selfUserId), {
creator: {
user: selfUserId,
Expand Down Expand Up @@ -619,36 +631,48 @@ export class ConversationService extends TypedEventEmitter<Events> {
private async resetMLSConversation(conversationId: QualifiedId): Promise<BaseCreateConversationResponse> {
this.logger.info(`Resetting MLS conversation with id ${conversationId.id}`);

// STEP 1: Fetch the conversation to retrieve the group ID & epoch
// STEP 1: fetch self user info
this.logger.info(
`Re-establishing the conversation by re-adding all members (conversation_id: ${conversationId.id})`,
);
const {validatedClientId: clientId, userId, domain: selfUserDomain} = this.apiClient;

if (!selfUserDomain) {
const errorMessage = 'Could not find domain of the self user';
this.logger.error(errorMessage, {conversationId});
throw new Error(errorMessage);
}

// STEP 2: Fetch the conversation to retrieve the group ID & epoch
const conversation = await this.apiClient.api.conversation.getConversation(conversationId);
const {group_id: groupId, epoch} = conversation;
const {
group_id: groupId,
epoch,
qualified_id: {domain: conversationDomain},
} = conversation;

this.validateDomainMatch(selfUserDomain, conversationDomain);

if (!groupId || !epoch) {
const errorMessage = 'Could not find group id or epoch for the conversation';
this.logger.error(errorMessage, {conversationId});
throw new Error(errorMessage);
}

// STEP 2: Request backend to reset the conversation
// STEP 3: Request backend to reset the conversation
this.logger.info(`Requesting backend to reset the conversation (group_id: ${groupId}, epoch: ${String(epoch)})`);
await this.apiClient.api.conversation.resetMLSConversation({
epoch,
groupId,
});

// STEP 3: fetch self user info
this.logger.info(
`Re-establishing the conversation by re-adding all members (conversation_id: ${conversationId.id})`,
);
const {validatedClientId: clientId, userId, domain} = this.apiClient;

if (!userId || !domain) {
if (!userId || !selfUserDomain) {
const errorMessage = 'Could not find userId or domain of the self user';
this.logger.error(errorMessage, {conversationId});
throw new Error(errorMessage);
}

const selfUserQualifiedId = {id: userId, domain};
const selfUserQualifiedId = {id: userId, domain: selfUserDomain};

// STEP 4: Fetch the updated conversation data from backend to retrieve the new group ID
const updatedConversation = await this.apiClient.api.conversation.getConversation(conversationId);
Expand Down Expand Up @@ -912,6 +936,8 @@ export class ConversationService extends TypedEventEmitter<Events> {
qualifiedUsers: QualifiedId[];
}): Promise<void> {
try {
this.validateDomainMatch(selfUserId.domain, conversationId.domain);

const wasGroupEstablishedBySelfClient = await this.mlsService.tryEstablishingMLSGroup(groupId);

if (!wasGroupEstablishedBySelfClient) {
Expand Down