diff --git a/app/src/main/java/com/urik/keyboard/UrikInputMethodService.kt b/app/src/main/java/com/urik/keyboard/UrikInputMethodService.kt index 91c6589..b1b494d 100644 --- a/app/src/main/java/com/urik/keyboard/UrikInputMethodService.kt +++ b/app/src/main/java/com/urik/keyboard/UrikInputMethodService.kt @@ -41,6 +41,7 @@ import com.urik.keyboard.service.EnglishPronounI import com.urik.keyboard.service.InputMethod import com.urik.keyboard.service.InputStateManager import com.urik.keyboard.service.LanguageManager +import com.urik.keyboard.service.LastAutocorrection import com.urik.keyboard.service.OutputBridge import com.urik.keyboard.service.PostCommitReplacementState import com.urik.keyboard.service.ProcessingResult @@ -1150,6 +1151,7 @@ class UrikInputMethodService : } inputState.postCommitReplacementState = null + inputState.lastAutocorrection = null inputState.isActivelyEditing = true @@ -1272,6 +1274,7 @@ class UrikInputMethodService : inputState.postCommitReplacementState = null swipeKeyboardView?.clearSuggestions() } + inputState.lastAutocorrection = null if (inputState.spellConfirmationState == SpellConfirmationState.AWAITING_CONFIRMATION) { outputBridge.beginBatchEdit() @@ -2062,11 +2065,22 @@ class UrikInputMethodService : inputState.displayBuffer = word inputState.composingRegionStart = wordStart - suggestionPipeline.requestSuggestions( - buffer = word, - inputMethod = InputMethod.TYPED, - isCharacterInput = false - ) + val autocorrection = inputState.lastAutocorrection + if (autocorrection != null && + word.equals(autocorrection.correctedWord, ignoreCase = true) + ) { + outputBridge.setComposingText(autocorrection.originalTypedWord, 1) + inputState.displayBuffer = autocorrection.originalTypedWord + inputState.pendingSuggestions = emptyList() + swipeKeyboardView?.clearSuggestions() + } else { + inputState.lastAutocorrection = null + suggestionPipeline.requestSuggestions( + buffer = word, + inputMethod = InputMethod.TYPED, + isCharacterInput = false + ) + } } else { coordinateStateClear() } @@ -2192,7 +2206,10 @@ class UrikInputMethodService : return@launch } - if (currentSettings.autocorrectionEnabled && displaySuggestions.isNotEmpty()) { + if (currentSettings.autocorrectionEnabled && + displaySuggestions.isNotEmpty() && + inputState.lastAutocorrection == null + ) { val topSuggestion = displaySuggestions.first() if (isSafeForAutocorrect(topSuggestion)) { val originalWord = inputState.displayBuffer @@ -2209,8 +2226,12 @@ class UrikInputMethodService : originalWord = originalWord, committedWord = topSuggestion ) - inputState.pendingSuggestions = - displaySuggestions.drop(1) + listOf(originalWord) + inputState.lastAutocorrection = + LastAutocorrection( + originalTypedWord = originalWord, + correctedWord = topSuggestion + ) + inputState.pendingSuggestions = listOf(originalWord) swipeKeyboardView?.updateSuggestions(inputState.pendingSuggestions) val textBefore = outputBridge.safeGetTextBeforeCursor(50) diff --git a/app/src/main/java/com/urik/keyboard/service/InputStateManager.kt b/app/src/main/java/com/urik/keyboard/service/InputStateManager.kt index e248e21..12ebe08 100644 --- a/app/src/main/java/com/urik/keyboard/service/InputStateManager.kt +++ b/app/src/main/java/com/urik/keyboard/service/InputStateManager.kt @@ -9,6 +9,8 @@ enum class SpellConfirmationState { data class PostCommitReplacementState(val originalWord: String, val committedWord: String) +data class LastAutocorrection(val originalTypedWord: String, val correctedWord: String) + interface ViewCallback { fun clearSuggestions() @@ -113,6 +115,10 @@ class InputStateManager( var postCommitReplacementState: PostCommitReplacementState? = null internal set + @Volatile + var lastAutocorrection: LastAutocorrection? = null + internal set + val selectionStateTracker = SelectionStateTracker() val requiresDirectCommit: Boolean @@ -170,6 +176,7 @@ class InputStateManager( spellConfirmationState = SpellConfirmationState.NORMAL pendingWordForLearning = null postCommitReplacementState = null + lastAutocorrection = null viewCallback.clearSuggestions() composingRegionStart = -1 composingReassertionCount = 0 @@ -209,6 +216,7 @@ class InputStateManager( spellConfirmationState = SpellConfirmationState.NORMAL pendingWordForLearning = null postCommitReplacementState = null + lastAutocorrection = null viewCallback.clearSuggestions() composingRegionStart = -1 lastKnownCursorPosition = -1 diff --git a/app/src/test/java/com/urik/keyboard/service/InputStateManagerTest.kt b/app/src/test/java/com/urik/keyboard/service/InputStateManagerTest.kt new file mode 100644 index 0000000..be9e2b0 --- /dev/null +++ b/app/src/test/java/com/urik/keyboard/service/InputStateManagerTest.kt @@ -0,0 +1,96 @@ +package com.urik.keyboard.service + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertNull +import org.junit.Before +import org.junit.Test + +class InputStateManagerTest { + private var suggestionsCleared = false + private var lastSuggestions: List = emptyList() + + private lateinit var stateManager: InputStateManager + + @Before + fun setup() { + suggestionsCleared = false + lastSuggestions = emptyList() + + val viewCallback = object : ViewCallback { + override fun clearSuggestions() { + suggestionsCleared = true + } + + override fun updateSuggestions(suggestions: List) { + lastSuggestions = suggestions + } + } + + stateManager = InputStateManager( + viewCallback = viewCallback, + onShiftStateChanged = {}, + isCapsLockOn = { false }, + cancelDebounceJob = {} + ) + } + + @Test + fun `lastAutocorrection persists independently of postCommitReplacementState`() { + stateManager.lastAutocorrection = LastAutocorrection("teh", "the") + stateManager.postCommitReplacementState = PostCommitReplacementState("teh", "the") + + stateManager.postCommitReplacementState = null + + assertNotNull(stateManager.lastAutocorrection) + assertEquals("teh", stateManager.lastAutocorrection?.originalTypedWord) + assertEquals("the", stateManager.lastAutocorrection?.correctedWord) + } + + @Test + fun `clearInternalStateOnly clears lastAutocorrection`() { + stateManager.lastAutocorrection = LastAutocorrection("teh", "the") + + stateManager.clearInternalStateOnly() + + assertNull(stateManager.lastAutocorrection) + } + + @Test + fun `invalidateComposingState clears lastAutocorrection`() { + stateManager.lastAutocorrection = LastAutocorrection("teh", "the") + + stateManager.invalidateComposingState() + + assertNull(stateManager.lastAutocorrection) + } + + @Test + fun `clearInternalStateOnly clears postCommitReplacementState`() { + stateManager.postCommitReplacementState = PostCommitReplacementState("teh", "the") + + stateManager.clearInternalStateOnly() + + assertNull(stateManager.postCommitReplacementState) + } + + @Test + fun `clearBigramPredictions does not affect lastAutocorrection`() { + stateManager.lastAutocorrection = LastAutocorrection("teh", "the") + stateManager.isShowingBigramPredictions = true + + stateManager.clearBigramPredictions() + + assertNotNull(stateManager.lastAutocorrection) + } + + @Test + fun `clearSpellConfirmationFields does not affect lastAutocorrection`() { + stateManager.lastAutocorrection = LastAutocorrection("teh", "the") + stateManager.spellConfirmationState = SpellConfirmationState.AWAITING_CONFIRMATION + + stateManager.clearSpellConfirmationFields() + + assertNotNull(stateManager.lastAutocorrection) + } +} diff --git a/app/src/test/java/com/urik/keyboard/service/SuggestionPipelineTest.kt b/app/src/test/java/com/urik/keyboard/service/SuggestionPipelineTest.kt new file mode 100644 index 0000000..37ae888 --- /dev/null +++ b/app/src/test/java/com/urik/keyboard/service/SuggestionPipelineTest.kt @@ -0,0 +1,223 @@ +@file:Suppress("ktlint:standard:no-wildcard-imports") + +package com.urik.keyboard.service + +import android.view.inputmethod.InputConnection +import com.urik.keyboard.data.WordFrequencyRepository +import com.urik.keyboard.data.database.LearnedWord +import com.urik.keyboard.model.KeyboardState +import com.urik.keyboard.ui.keyboard.components.SwipeDetector +import com.urik.keyboard.utils.CaseTransformer +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.resetMain +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.test.setMain +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNull +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.kotlin.any +import org.mockito.kotlin.eq +import org.mockito.kotlin.mock +import org.mockito.kotlin.never +import org.mockito.kotlin.times +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever +import org.robolectric.RobolectricTestRunner + +@OptIn(ExperimentalCoroutinesApi::class) +@RunWith(RobolectricTestRunner::class) +class SuggestionPipelineTest { + private val testDispatcher = StandardTestDispatcher() + + private lateinit var mockIc: InputConnection + private lateinit var mockSwipeDetector: SwipeDetector + private lateinit var mockSwipeSpaceManager: SwipeSpaceManager + private lateinit var mockTextInputProcessor: TextInputProcessor + private lateinit var mockSpellCheckManager: SpellCheckManager + private lateinit var mockWordLearningEngine: WordLearningEngine + private lateinit var mockWordFrequencyRepository: WordFrequencyRepository + private lateinit var mockLanguageManager: LanguageManager + private lateinit var mockCaseTransformer: CaseTransformer + + private lateinit var inputState: InputStateManager + private lateinit var outputBridge: OutputBridge + private lateinit var pipeline: SuggestionPipeline + + private var capturedSuggestions: List = emptyList() + private var suggestionsCleared = false + + @Before + fun setup() = runBlocking { + Dispatchers.setMain(testDispatcher) + + mockIc = mock() + mockSwipeDetector = mock() + mockSwipeSpaceManager = mock() + mockTextInputProcessor = mock() + mockSpellCheckManager = mock() + mockWordLearningEngine = mock() + mockWordFrequencyRepository = mock() + mockLanguageManager = mock() + mockCaseTransformer = mock() + + whenever(mockIc.beginBatchEdit()).thenReturn(true) + whenever(mockIc.endBatchEdit()).thenReturn(true) + whenever(mockIc.commitText(any(), any())).thenReturn(true) + whenever(mockIc.deleteSurroundingText(any(), any())).thenReturn(true) + whenever(mockIc.finishComposingText()).thenReturn(true) + whenever(mockLanguageManager.currentLanguage).thenReturn( + kotlinx.coroutines.flow.MutableStateFlow("en") + ) + whenever(mockWordLearningEngine.learnWord(any(), any())).thenReturn(Result.success(null as LearnedWord?)) + + val viewCallback = object : ViewCallback { + override fun clearSuggestions() { + suggestionsCleared = true + } + + override fun updateSuggestions(suggestions: List) { + capturedSuggestions = suggestions + } + } + + inputState = InputStateManager( + viewCallback = viewCallback, + onShiftStateChanged = {}, + isCapsLockOn = { false }, + cancelDebounceJob = {} + ) + + outputBridge = OutputBridge( + state = inputState, + swipeDetector = mockSwipeDetector, + swipeSpaceManager = mockSwipeSpaceManager, + icProvider = { mockIc } + ) + + pipeline = SuggestionPipeline( + state = inputState, + outputBridge = outputBridge, + textInputProcessor = mockTextInputProcessor, + spellCheckManager = mockSpellCheckManager, + wordLearningEngine = mockWordLearningEngine, + wordFrequencyRepository = mockWordFrequencyRepository, + languageManager = mockLanguageManager, + caseTransformer = mockCaseTransformer, + serviceScope = kotlinx.coroutines.CoroutineScope(testDispatcher), + showSuggestions = { true }, + effectiveSuggestionCount = { 3 }, + getKeyboardState = { KeyboardState() }, + shouldAutoCapitalize = { false }, + currentLanguageProvider = { "en" } + ) + } + + @After + fun teardown() { + Dispatchers.resetMain() + } + + @Test + fun `coordinatePostCommitReplacement learns word on autocorrect undo`() = runTest(testDispatcher) { + val replacementState = PostCommitReplacementState( + originalWord = "teh", + committedWord = "the" + ) + whenever(mockIc.getTextBeforeCursor(4, 0)).thenReturn("the ") + whenever(mockTextInputProcessor.getCurrentSettings()).thenReturn( + com.urik.keyboard.settings.KeyboardSettings() + ) + whenever(mockSpellCheckManager.isWordInSymSpellDictionary(any())).thenReturn(false) + + pipeline.coordinatePostCommitReplacement( + selectedSuggestion = "teh", + replacementState = replacementState, + checkAutoCapitalization = {} + ) + + verify(mockWordLearningEngine).learnWord("teh", InputMethod.TYPED) + verify(mockWordFrequencyRepository, times(3)).incrementFrequency("teh", "en") + } + + @Test + fun `coordinatePostCommitReplacement does not learn on non-autocorrect replacement`() = runTest(testDispatcher) { + val replacementState = PostCommitReplacementState( + originalWord = "hello", + committedWord = "hello" + ) + whenever(mockIc.getTextBeforeCursor(6, 0)).thenReturn("hello ") + whenever(mockTextInputProcessor.getCurrentSettings()).thenReturn( + com.urik.keyboard.settings.KeyboardSettings() + ) + + pipeline.coordinatePostCommitReplacement( + selectedSuggestion = "help", + replacementState = replacementState, + checkAutoCapitalization = {} + ) + + verify(mockWordLearningEngine, never()).learnWord(any(), any()) + } + + @Test + fun `coordinatePostCommitReplacement clears postCommitReplacementState`() = runTest(testDispatcher) { + inputState.postCommitReplacementState = PostCommitReplacementState("teh", "the") + whenever(mockIc.getTextBeforeCursor(4, 0)).thenReturn("the ") + whenever(mockTextInputProcessor.getCurrentSettings()).thenReturn( + com.urik.keyboard.settings.KeyboardSettings() + ) + whenever(mockSpellCheckManager.isWordInSymSpellDictionary(any())).thenReturn(false) + + pipeline.coordinatePostCommitReplacement( + selectedSuggestion = "teh", + replacementState = inputState.postCommitReplacementState!!, + checkAutoCapitalization = {} + ) + + assertNull(inputState.postCommitReplacementState) + } + + @Test + fun `coordinatePostCommitReplacement aborts on stale text`() = runTest(testDispatcher) { + val replacementState = PostCommitReplacementState( + originalWord = "teh", + committedWord = "the" + ) + whenever(mockIc.getTextBeforeCursor(4, 0)).thenReturn("oops") + + pipeline.coordinatePostCommitReplacement( + selectedSuggestion = "teh", + replacementState = replacementState, + checkAutoCapitalization = {} + ) + + verify(mockIc, never()).deleteSurroundingText(any(), any()) + verify(mockWordLearningEngine, never()).learnWord(any(), any()) + } + + @Test + fun `coordinateSuggestionSelection records word usage`() = runTest(testDispatcher) { + inputState.displayBuffer = "helo" + inputState.composingRegionStart = 0 + whenever(mockIc.getTextBeforeCursor(any(), any())).thenReturn("helo") + whenever(mockIc.commitText(any(), any())).thenReturn(true) + whenever(mockIc.finishComposingText()).thenReturn(true) + + val extractedText = android.view.inputmethod.ExtractedText().apply { + startOffset = 0 + selectionStart = 4 + } + whenever(mockIc.getExtractedText(any(), eq(0))).thenReturn(extractedText) + + pipeline.coordinateSuggestionSelection("hello", checkAutoCapitalization = {}) + + verify(mockWordFrequencyRepository).incrementFrequency("hello", "en") + assertEquals("hello", inputState.lastCommittedWord) + } +}