diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/BufferRequester.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/BufferRequester.java new file mode 100644 index 0000000000000..7b74aaa41134b --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/BufferRequester.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import javax.annotation.Nullable; + +import java.io.IOException; + +/** Supplies per-channel network buffers to the recovery pipeline. */ +@Internal +interface BufferRequester { + + /** Non-blocking; returns {@code null} when no buffer is currently available. */ + @Nullable + Buffer requestBuffer(InputChannelInfo channelInfo) throws IOException; + + Buffer requestBufferBlocking(InputChannelInfo channelInfo) + throws InterruptedException, IOException; + + /** + * Releases exclusive buffers for every channel served by this requester. Idempotent. Must run + * after the dispatcher's drain has finished so the underlying pools are no longer being read + * from. + */ + void releaseExclusiveBuffers() throws IOException; +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java index 4173bb7140e78..55e53d26e4b93 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java @@ -26,6 +26,8 @@ import org.apache.flink.runtime.state.CheckpointStateOutputStream; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.util.CloseableIterator; +import org.apache.flink.util.IOUtils; import org.apache.flink.util.Preconditions; import org.apache.flink.util.function.RunnableWithException; @@ -180,6 +182,42 @@ void writeOutput( } } + /** + * Writes spilled input-channel state chunks as [4-byte length prefix][data bytes], matching the + * buffer-based path. The iterator is closed when done. 
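+ *
+ * <p>Illustrative stream layout for two chunks (hypothetical lengths; the length prefix is a
+ * big-endian int):
+ *
+ * <pre>{@code
+ * [int 1024][1024 bytes of chunk 0][int 512][512 bytes of chunk 1]
+ * }</pre>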
+ */ + void writeInputFromSpill( + JobVertexID jobVertexID, + int subtaskIndex, + CloseableIterator chunks) { + if (isDone()) { + IOUtils.closeQuietly(chunks); + return; + } + ChannelStatePendingResult pendingResult = + getChannelStatePendingResult(jobVertexID, subtaskIndex); + runWithChecks( + () -> { + checkState(!pendingResult.isAllInputsReceived()); + try { + while (chunks.hasNext()) { + FilteredSpillFile.Chunk chunk = chunks.next(); + InputChannelInfo info = chunk.getChannelInfo(); + long offset = checkpointStream.getPos(); + dataStream.writeInt(chunk.getLength()); + dataStream.write(chunk.getData(), 0, chunk.getLength()); + long size = checkpointStream.getPos() - offset; + pendingResult + .getInputChannelOffsets() + .computeIfAbsent(info, unused -> new StateContentMetaInfo()) + .withDataAdded(offset, size); + } + } finally { + chunks.close(); + } + }); + } + private void write( Map offsets, K key, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java index b257c3b40544e..4b03406c80c22 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java @@ -49,17 +49,14 @@ import static org.apache.flink.util.Preconditions.checkNotNull; /** - * Filters recovered channel state buffers during the channel-state-unspilling phase, removing - * records that do not belong to the current subtask after rescaling. - * - *
<p>
Uses a per-gate architecture: each {@link InputGate} gets its own {@link GateFilterHandler} - * with the correct serializer, so multi-input tasks (e.g., TwoInputStreamTask) correctly - * deserialize different record types on different gates. + * Filters recovered channel-state buffers during the unspilling phase, dropping records that don't + * belong to the current subtask after rescaling. Each {@link InputGate} has its own {@link + * GateFilterHandler} with the correct serializer so multi-input tasks deserialize different record + * types on different gates. */ @Internal public class ChannelStateFilteringHandler implements Closeable { - // Wildcard allows heterogeneous record types across gates. private final GateFilterHandler[] gateHandlers; ChannelStateFilteringHandler(GateFilterHandler[] gateHandlers) { @@ -67,14 +64,12 @@ public class ChannelStateFilteringHandler implements Closeable { } /** - * Creates a handler from the recovery context, building per-gate virtual channels based on - * rescaling descriptors. Returns {@code null} if no filtering is needed (e.g., source tasks or - * no rescaling). + * Builds per-gate virtual channels from the rescaling descriptor. Returns {@code null} when no + * filtering is needed (source tasks, no rescaling). */ @Nullable public static ChannelStateFilteringHandler createFromContext( RecordFilterContext filterContext, InputGate[] inputGates) { - // Source tasks have no network inputs if (filterContext.getNumberOfGates() == 0) { return null; } @@ -101,23 +96,17 @@ public static ChannelStateFilteringHandler createFromContext( } /** - * Filters a recovered buffer from the specified virtual channel, returning new buffers - * containing only the records that belong to the current subtask. - * - *
<p>
One source buffer may produce 0 to N result buffers: 0 if all records are filtered out, - * and potentially more than 1 when a spanning record completes in this buffer. The deserializer - * caches partial record data from previous buffers, so the output may contain data that was not - * in the current source buffer, causing the total output size to exceed one buffer capacity. - * This can happen with any spanning record regardless of its size. - * - * @return filtered buffers, possibly empty if all records were filtered out. + * Filters records from {@code sourceBuffer} on the given virtual channel and writes survivors + * to {@code dispatcher}. One source buffer may yield 0..N records (the deserializer caches + * partial-record data across buffers). */ - public List filterAndRewrite( + public void filterAndRewrite( int gateIndex, int oldSubtaskIndex, int oldChannelIndex, Buffer sourceBuffer, - BufferSupplier bufferSupplier) + FilteredBufferDispatcher dispatcher, + InputChannelInfo targetChannelInfo) throws IOException, InterruptedException { if (gateIndex < 0 || gateIndex >= gateHandlers.length) { @@ -135,11 +124,10 @@ public List filterAndRewrite( + gateIndex + ". This gate is not a network input and should not have recovered buffers."); } - return gateHandler.filterAndRewrite( - oldSubtaskIndex, oldChannelIndex, sourceBuffer, bufferSupplier); + gateHandler.filterAndRewrite( + oldSubtaskIndex, oldChannelIndex, sourceBuffer, dispatcher, targetChannelInfo); } - /** Returns {@code true} if any virtual channel has a partial (spanning) record pending. */ public boolean hasPartialData() { for (GateFilterHandler handler : gateHandlers) { if (handler != null && handler.hasPartialData()) { @@ -158,14 +146,7 @@ public void close() { } } - // ------------------------------------------------------------------------------------------- - // Private static helper methods - // ------------------------------------------------------------------------------------------- - - /** - * Creates a {@link GateFilterHandler} for a single gate. The method-level type parameter - * ensures type safety within each gate while allowing different gates to have different types. - */ + /** Method-level {@code } keeps each gate type-safe while allowing heterogeneous gates. */ @SuppressWarnings("unchecked") @Nullable private static GateFilterHandler createGateHandler( @@ -206,7 +187,7 @@ private static GateFilterHandler createGateHandler( continue; } - // Only ambiguous channels need actual filtering; non-ambiguous ones pass through + // Only ambiguous channels need filtering; non-ambiguous ones pass through. boolean isAmbiguous = rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex); RecordFilter recordFilter = @@ -229,10 +210,7 @@ private static GateFilterHandler createGateHandler( return new GateFilterHandler<>(gateVirtualChannels, elementSerializer); } - /** - * Collects all old channel indexes that are mapped from any new channel index in this gate. - * channelMapping is new-to-old, so we iterate new indexes and collect their old counterparts. - */ + /** {@code channelMapping} is new-to-old; iterate new indexes and collect the old ones. 
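For a hypothetical mapping {new 0 → old {0, 2}, new 1 → old {1}}, the result would be {@code [0, 2, 1]}, assuming old indexes are appended in iteration order.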
*/ private static int[] getOldChannelIndexes(RescaleMappings channelMapping, int numChannels) { List oldIndexes = new ArrayList<>(); for (int newIndex = 0; newIndex < numChannels; newIndex++) { @@ -256,20 +234,7 @@ private static RecordDeserializer> create } } - // ------------------------------------------------------------------------------------------- - // Inner classes - // ------------------------------------------------------------------------------------------- - - /** Provides buffers for re-serializing filtered records. Implementations may block. */ - @FunctionalInterface - public interface BufferSupplier { - Buffer requestBufferBlocking() throws IOException, InterruptedException; - } - - /** - * Handles record filtering for a single input gate. Each gate has its own serializer and set of - * virtual channels, allowing different gates to handle different record types independently. - */ + /** Filters records for a single input gate; owns its own serializer and virtual channels. */ static class GateFilterHandler { private final Map> virtualChannels; @@ -287,20 +252,15 @@ static class GateFilterHandler { this.outputSerializer = new DataOutputSerializer(128); } - /** - * Deserializes records from {@code sourceBuffer}, applies the virtual channel's record - * filter, and immediately re-serializes each surviving record into output buffers. - */ - List filterAndRewrite( + void filterAndRewrite( int oldSubtaskIndex, int oldChannelIndex, Buffer sourceBuffer, - BufferSupplier bufferSupplier) + FilteredBufferDispatcher dispatcher, + InputChannelInfo targetChannelInfo) throws IOException, InterruptedException { boolean sourceBufferOwnershipTransferred = false; - List resultBuffers = new ArrayList<>(); - Buffer currentBuffer = null; try { SubtaskConnectionDescriptor key = new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); @@ -319,82 +279,38 @@ List filterAndRewrite( while (true) { DeserializationResult result = vc.getNextRecord(deserializationDelegate); if (result.isFullRecord()) { - if (currentBuffer == null) { - currentBuffer = bufferSupplier.requestBufferBlocking(); - } - currentBuffer = - serializeElement( - deserializationDelegate.getInstance(), - currentBuffer, - resultBuffers, - bufferSupplier); + serializeElement( + deserializationDelegate.getInstance(), + dispatcher, + targetChannelInfo); } if (result.isBufferConsumed()) { break; } } - - if (currentBuffer != null) { - if (currentBuffer.readableBytes() > 0) { - resultBuffers.add(currentBuffer); - } else { - currentBuffer.recycleBuffer(); - } - currentBuffer = null; - } - - return resultBuffers; } catch (Throwable t) { if (!sourceBufferOwnershipTransferred) { sourceBuffer.recycleBuffer(); } - // Avoid double-recycle: currentBuffer may already be the last element in - // resultBuffers if serializeElement added it before the exception. - if (currentBuffer != null - && (resultBuffers.isEmpty() - || resultBuffers.get(resultBuffers.size() - 1) != currentBuffer)) { - currentBuffer.recycleBuffer(); - } - for (Buffer buf : resultBuffers) { - buf.recycleBuffer(); - } - resultBuffers.clear(); throw t; } } - /** - * Serializes a single stream element into the current buffer using the length-prefixed - * format (4-byte big-endian length + record bytes) expected by Flink's record - * deserializers. Spills into new buffers from {@code bufferSupplier} when needed. - * - * @return the buffer to continue writing into (may differ from the input buffer). 
- */ - private Buffer serializeElement( + /** Length-prefixed format: 4-byte big-endian length + record bytes. */ + private void serializeElement( StreamElement element, - Buffer currentBuffer, - List resultBuffers, - BufferSupplier bufferSupplier) + FilteredBufferDispatcher dispatcher, + InputChannelInfo targetChannelInfo) throws IOException, InterruptedException { outputSerializer.clear(); serializer.serialize(element, outputSerializer); int recordLength = outputSerializer.length(); writeLengthToBuffer(recordLength); - currentBuffer = - writeDataToBuffer( - lengthBuffer, 0, 4, currentBuffer, resultBuffers, bufferSupplier); + dispatcher.write(lengthBuffer, 4, targetChannelInfo); byte[] serializedData = outputSerializer.getSharedBuffer(); - currentBuffer = - writeDataToBuffer( - serializedData, - 0, - recordLength, - currentBuffer, - resultBuffers, - bufferSupplier); - return currentBuffer; + dispatcher.write(serializedData, recordLength, targetChannelInfo); } private void writeLengthToBuffer(int length) { @@ -404,49 +320,6 @@ private void writeLengthToBuffer(int length) { lengthBuffer[3] = (byte) length; } - /** - * Writes data to the current buffer, spilling into new buffers from {@code bufferSupplier} - * when the current one is full. - * - * @return the buffer to continue writing into (may differ from the input buffer). - */ - private Buffer writeDataToBuffer( - byte[] data, - int dataOffset, - int dataLength, - Buffer currentBuffer, - List resultBuffers, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { - int offset = dataOffset; - int remaining = dataLength; - - while (remaining > 0) { - int writableBytes = currentBuffer.getMaxCapacity() - currentBuffer.getSize(); - - if (writableBytes == 0) { - // Buffer is full, transfer ownership to resultBuffers - resultBuffers.add(currentBuffer); - currentBuffer = bufferSupplier.requestBufferBlocking(); - writableBytes = currentBuffer.getMaxCapacity(); - } - - int bytesToWrite = Math.min(remaining, writableBytes); - currentBuffer - .getMemorySegment() - .put( - currentBuffer.getMemorySegmentOffset() + currentBuffer.getSize(), - data, - offset, - bytesToWrite); - currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite); - - offset += bytesToWrite; - remaining -= bytesToWrite; - } - return currentBuffer; - } - boolean hasPartialData() { return virtualChannels.values().stream().anyMatch(VirtualChannel::hasPartialData); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java index abef241c325b8..c8b7cc0aec76c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java @@ -234,6 +234,20 @@ static ChannelStateWriteRequest buildWriteRequest( throwable -> iterator.close()); } + static ChannelStateWriteRequest buildSpillWriteRequest( + JobVertexID jobVertexID, + int subtaskIndex, + long checkpointId, + CloseableIterator chunks) { + return new CheckpointInProgressRequest( + "writeInputFromSpill", + jobVertexID, + subtaskIndex, + checkpointId, + writer -> writer.writeInputFromSpill(jobVertexID, subtaskIndex, chunks), + throwable -> chunks.close()); + } + static void checkBufferIsBuffer(Buffer buffer) { try { checkArgument(buffer.isBuffer()); diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java index 6fee1402036d6..a6267869e3994 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.state.InputChannelStateHandle; import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.util.CloseableIterator; +import org.apache.flink.util.IOUtils; import java.io.Closeable; import java.util.Collection; @@ -124,6 +125,13 @@ void addInputData( int startSeqNum, CloseableIterator data); + /** + * Drains spill-file chunks into the checkpoint. Called by the dispatcher once all ready buffers + * are snapshotted and the wait-set is empty. The implementation closes the iterator. + */ + void addInputDataFromSpill( + long checkpointId, CloseableIterator chunks); + /** * Add in-flight buffers from the {@link * org.apache.flink.runtime.io.network.partition.ResultSubpartition ResultSubpartition}. Must be @@ -204,6 +212,12 @@ public void addInputData( int startSeqNum, CloseableIterator data) {} + @Override + public void addInputDataFromSpill( + long checkpointId, CloseableIterator chunks) { + IOUtils.closeQuietly(chunks); + } + @Override public void addOutputData( long checkpointId, ResultSubpartitionInfo info, int startSeqNum, Buffer... data) {} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java index 40d7ddffd1e18..df228f0388742 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java @@ -40,6 +40,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; +import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.buildSpillWriteRequest; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeInput; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeOutput; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.write; @@ -194,6 +195,13 @@ public void addInputData( enqueue(write(jobVertexID, subtaskIndex, checkpointId, info, iterator), false); } + @Override + public void addInputDataFromSpill( + long checkpointId, CloseableIterator chunks) { + LOG.trace("{} adding spill input data, checkpoint {}", taskName, checkpointId); + enqueue(buildSpillWriteRequest(jobVertexID, subtaskIndex, checkpointId, chunks), false); + } + @Override public void addOutputData( long checkpointId, ResultSubpartitionInfo info, int startSeqNum, Buffer... 
data) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/EntryPosition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/EntryPosition.java new file mode 100644 index 0000000000000..cf35371b0f861 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/EntryPosition.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; + +import java.util.Objects; + +/** + * Position of a spill entry inside a {@link FilteredSpillFile}: physical file index in the spill + * file's {@code readers} list plus absolute byte offset within that file. Lexicographic ordering on + * {@code (fileIndex, offset)} matches the FIFO drain sequence. {@link #END} compares strictly + * greater than any real position and serves as the post-drain sentinel. + */ +@Internal +public final class EntryPosition implements Comparable { + + public static final EntryPosition END = new EntryPosition(Integer.MAX_VALUE, Long.MAX_VALUE); + + private final int fileIndex; + private final long offset; + + public EntryPosition(int fileIndex, long offset) { + this.fileIndex = fileIndex; + this.offset = offset; + } + + public int getFileIndex() { + return fileIndex; + } + + public long getOffset() { + return offset; + } + + @Override + public int compareTo(EntryPosition other) { + int byFile = Integer.compare(this.fileIndex, other.fileIndex); + if (byFile != 0) { + return byFile; + } + return Long.compare(this.offset, other.offset); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof EntryPosition)) { + return false; + } + EntryPosition that = (EntryPosition) o; + return fileIndex == that.fileIndex && offset == that.offset; + } + + @Override + public int hashCode() { + return Objects.hash(fileIndex, offset); + } + + @Override + public String toString() { + if (this == END) { + return "EntryPosition{END}"; + } + return "EntryPosition{fileIndex=" + fileIndex + ", offset=" + offset + "}"; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcher.java new file mode 100644 index 0000000000000..f6018d32b1bb0 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcher.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; + +import java.io.IOException; + +/** + * Dispatches filtered channel-state data across multiple channels' {@link + * org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStore}s. {@link + * #write(byte[], int, InputChannelInfo)} pushes data for a target channel; the implementation + * decides whether to use a network buffer (P1), spill to disk (P2), or replay from disk (P3). + */ +@Internal +public interface FilteredBufferDispatcher extends AutoCloseable { + + void write(byte[] data, int length, InputChannelInfo channelInfo) + throws IOException, InterruptedException; + + /** Flushes any buffered data. After flush, no more writes are accepted. */ + void flush() throws IOException; + + /** + * Blocking drain of all remaining disk data into target stores. Must be called after {@link + * #flush()} so all Readers are frozen. Skipped on the abort path: callers that are cancelling + * go straight to {@link #close()}. Does not hold the dispatcher monitor while blocking; + * coordinator callbacks remain free to acquire it. + */ + void drainPendingSpill() throws IOException, InterruptedException; + + /** + * Releases resources held by this dispatcher. Idempotent. Pure resource release: does NOT drain + * remaining disk data — call {@link #drainPendingSpill()} first on the normal path. + */ + @Override + void close() throws IOException; +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcherImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcherImpl.java new file mode 100644 index 0000000000000..e551cc21e19a1 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcherImpl.java @@ -0,0 +1,613 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStoreImpl; +import org.apache.flink.util.CloseableIterator; +import org.apache.flink.util.Preconditions; + +import javax.annotation.concurrent.GuardedBy; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +import static org.apache.flink.util.IOUtils.closeQuietly; + +/** + * {@link FilteredBufferDispatcher} implementation managing three data paths: + * + *
<ul>
+ *   <li>P1: write directly to a network buffer and deliver to the target store
+ *   <li>P2: when no buffer is available, write to a spill file
+ *   <li>P3: when buffers become available later, eagerly replay spilled entries
+ * </ul>
+ *
+ * <p>
A byte[] cache accumulates payload bytes for the active channel. On channel change or cache + * full, {@link #flushCache()} commits the bytes via P1 if the spill writer is idle and a buffer is + * available, otherwise via P2. After {@link #flush()} seals all Readers, {@link + * #drainPendingSpill()} drains the remainder; {@link #close()} releases resources. + * + *
<h2>Lock ordering</h2>
+ *
+ * <p>
Two locks meet on the recovery → checkpoint hand-off: each per-channel store's intrinsic + * monitor (SMALL) and this dispatcher's {@link #dispatcherLock} (BIG). They must + * always be acquired in the order SMALL → BIG. Code holding BIG must never reach back to any + * SMALL — neither directly nor through a callee — otherwise the AB-BA cycle returns. Forbidden + * callees from inside a {@code synchronized(dispatcherLock)} block: {@link + * RecoveredBufferStoreImpl#addBuffer}, {@link RecoveredBufferStoreImpl#addBufferAfterDisk}, {@link + * RecoveredBufferStoreImpl#incrementPending}, any {@code synchronized(store)} block, and any {@link + * org.apache.flink.runtime.io.network.partition.consumer.ChannelStatePersister} entrypoint. + */ +@Internal +public class FilteredBufferDispatcherImpl + implements FilteredBufferDispatcher, RecoveredBufferStoreCoordinator { + + /** + * Typed as {@link RecoveredBufferStoreImpl} (not the interface) because the producer-side + * mutators ({@code addBuffer}, {@code incrementPending}) are deliberately not on the public + * interface — only this dispatcher calls them. + */ + private final Map storesByChannel; + + private final ChannelStateWriter channelStateWriter; + private final String[] spillDirs; + private final int memorySegmentSize; + private final BufferRequester bufferRequester; + + /** + * Explicit lock object instead of {@code synchronized} methods so every callsite that takes BIG + * is grep-visible. + */ + private final Object dispatcherLock = new Object(); + + private final byte[] cache; + private int cachePosition; + private InputChannelInfo cacheChannel; + + /** + * Lazily initialized by the recovery thread inside {@link #writeToSpillFile}; volatile so task + * threads observing it after {@link #flushed} see a fully-constructed instance. + */ + private volatile FilteredSpillFile spillFile; + + @GuardedBy("dispatcherLock") + private long currentCheckpointId = -1L; + + @GuardedBy("dispatcherLock") + private long lastStoppedCheckpointId = -1L; + + @GuardedBy("dispatcherLock") + private Set waitSet; + + /** + * Phase-2 snapshot Readers pinned at the first {@link #onChannelCheckpointStarted} for the + * in-flight checkpoint. {@code null} when no checkpoint is in progress. + */ + @GuardedBy("dispatcherLock") + private List checkpointSnapshots; + + /** + * Per-channel drain-head captured atomically with each channel's Step 1 ready snapshot. Phase-2 + * skips entries strictly below {@code startPos[X]} for channel X (already covered by Step 1) + * and writes entries at or after as that channel's checkpoint state. + */ + @GuardedBy("dispatcherLock") + private Map checkpointStartPos; + + /** + * Position of the next spill entry the drain bundle will pop from the global FIFO. Volatile + * publication provides cross-channel visibility: any other channel's Step 1 read under its own + * SMALL observes drain progress without acquiring BIG. 
+ */ + private volatile EntryPosition drainHead; + + private volatile boolean flushed; + private volatile boolean closed; + + public FilteredBufferDispatcherImpl( + Map storesByChannel, + ChannelStateWriter channelStateWriter, + String[] spillDirs, + int memorySegmentSize, + BufferRequester bufferRequester) + throws IOException { + if (spillDirs.length == 0) { + throw new IOException("Spill directories must not be empty"); + } + this.storesByChannel = storesByChannel; + this.channelStateWriter = channelStateWriter; + this.spillDirs = spillDirs; + this.memorySegmentSize = memorySegmentSize; + this.bufferRequester = bufferRequester; + this.cache = new byte[memorySegmentSize]; + this.cachePosition = 0; + + for (RecoveredBufferStoreImpl store : storesByChannel.values()) { + synchronized (store) { + store.setCoordinator(this); + } + } + } + + @Override + public void write(byte[] data, int length, InputChannelInfo channelInfo) + throws IOException, InterruptedException { + if (flushed || closed) { + throw new IllegalStateException("Cannot write after " + (closed ? "close" : "flush")); + } + + eagerDrain(); + + if (cacheChannel != null && !cacheChannel.equals(channelInfo) && cachePosition > 0) { + flushCache(); + } + cacheChannel = channelInfo; + + int pos = 0; + while (pos < length) { + int space = memorySegmentSize - cachePosition; + int toCopy = Math.min(space, length - pos); + System.arraycopy(data, pos, cache, cachePosition, toCopy); + cachePosition += toCopy; + pos += toCopy; + + if (cachePosition == memorySegmentSize) { + flushCache(); + cacheChannel = channelInfo; + } + } + } + + @Override + public void flush() throws IOException { + if (flushed || closed) { + return; + } + flushCache(); + if (spillFile != null) { + spillFile.finish(); + // Initial value Step 1 of any channel will observe before the first drainPendingSpill + // bundle commits — never publish an unset drainHead during the live checkpoint window. + drainHead = computeDrainHeadFrom(0); + } + flushed = true; + } + + @Override + public void drainPendingSpill() throws IOException, InterruptedException { + Preconditions.checkState(flushed, "drainPendingSpill requires flush() to be called first"); + if (closed) { + return; + } + if (spillFile == null) { + return; + } + List readers = spillFile.getReaders(); + for (int i = 0; i < readers.size(); i++) { + FilteredSpillFile.Reader reader = readers.get(i); + while (true) { + FilteredSpillFile.Reader.Entry entry = reader.peekNextEntry(); + if (entry == null) { + break; + } + InputChannelInfo ch = entry.getChannelInfo(); + long entryOffset = entry.getOffset(); + int entryLength = entry.getLength(); + + // Buffer allocation may block on the pool and disk reads may miss the page + // cache; keep both outside the store lock so channel checkpoints are not + // serialised behind them. + Buffer buffer = bufferRequester.requestBufferBlocking(ch); + byte[] data = new byte[entryLength]; + reader.readBytesAt(entryOffset, entryLength, data, 0); + + RecoveredBufferStoreImpl store = + Preconditions.checkNotNull( + storesByChannel.get(ch), "No store for channel %s", ch); + synchronized (store.getGateLock()) { + reader.skipNextEntry(); + writeChunkToBuffer(buffer, data, entryLength); + store.addBuffer(buffer); + // drainHead's volatile write happens-after addBuffer's store-lock release — + // preserves "drainHead crossed e ⇒ e is in store_C.readyBuffers" for + // cross-channel readers. 
+ drainHead = computeDrainHeadFrom(i); + } + } + } + } + + @Override + public void close() throws IOException { + if (closed) { + return; + } + + // Phase 1 (abort path only): caller skipped flush(). flushCache reaches store.addBuffer, + // which acquires SMALL — must run outside dispatcherLock to keep the lock-order rule. + // The happy path enters with flushed=true and skips this block. + if (!flushed) { + flushCache(); + if (spillFile != null) { + spillFile.finish(); + } + flushed = true; + } + + // Setting closed=true before deleting the spill file is what closes the + // close-vs-snapshot race: a concurrent onChannelCheckpointStarted either pinned its + // FileChannels before this block ran (POSIX keeps the file alive after unlink) or + // observes closed and returns before opening the file. + synchronized (dispatcherLock) { + closed = true; + if (spillFile != null) { + spillFile.close(); + spillFile = null; + } + resetCheckpointState(); + } + + bufferRequester.releaseExclusiveBuffers(); + } + + @Override + public EntryPosition getCurrentDrainHead() { + EntryPosition head = drainHead; + return head == null ? EntryPosition.END : head; + } + + /** + * On the first callback for a checkpoint id, pins an immutable phase-2 view of every frozen + * Reader and seeds the wait-set with their pending channels. Subsequent callbacks remove their + * channel; the empty wait-set triggers {@link #drainSpillEntriesToCheckpoint}. + * + *
<p>
{@code startPos} is the per-channel drain-head captured atomically with the ready-buffer + * snapshot; phase-2 uses it to skip entries already covered by that channel's Step 1. + */ + @Override + public void onChannelCheckpointStarted( + long checkpointId, InputChannelInfo channelInfo, EntryPosition startPos) { + synchronized (dispatcherLock) { + if (closed) { + return; + } + if (checkpointId < currentCheckpointId) { + return; + } + if (checkpointId <= lastStoppedCheckpointId) { + // ChannelStateWriter for this id is gone; phase-2 drain into it would rely on + // writer.isDone() to silently swallow the data. + return; + } + if (checkpointId > currentCheckpointId) { + currentCheckpointId = checkpointId; + checkpointStartPos = new HashMap<>(); + checkpointSnapshots = new ArrayList<>(); + waitSet = new HashSet<>(); + if (spillFile != null) { + pinSpillSnapshots(); + } + } + if (checkpointStartPos != null) { + checkpointStartPos.put(channelInfo, startPos); + } + if (waitSet != null) { + waitSet.remove(channelInfo); + if (waitSet.isEmpty()) { + drainSpillEntriesToCheckpoint(checkpointId); + } + } + } + } + + /** + * Drops the wait-set tied to a finished/aborted checkpoint and bumps {@code + * lastStoppedCheckpointId} so a late {@link #onChannelCheckpointStarted} for the same id is + * short-circuited as stale. Closes any pinned phase-2 snapshot Readers — otherwise every + * aborted checkpoint leaks one fd per spill file. + */ + @Override + public void onChannelCheckpointStopped(long checkpointId, InputChannelInfo channelInfo) { + synchronized (dispatcherLock) { + if (closed) { + return; + } + if (checkpointId > lastStoppedCheckpointId) { + lastStoppedCheckpointId = checkpointId; + } + if (currentCheckpointId == checkpointId) { + resetCheckpointState(); + } + } + } + + /** + * Drops every pending spill entry belonging to {@code channelInfo} from all live Readers. + * Phase-2 snapshots are intentionally not mutated: the filtering iterator drops snapshot + * entries whose channel has no recorded startPos. Mutating the live snapshot would race the + * executor thread already iterating it. + */ + @Override + public void onChannelReleased(InputChannelInfo channelInfo) { + synchronized (dispatcherLock) { + if (closed) { + return; + } + if (spillFile != null) { + // reader.removeEntriesForChannel mutates the same Reader.entries deque that + // drainPendingSpill (holds SMALL_C, not BIG) pops from, so the deque must be + // a ConcurrentLinkedDeque — that is its load-bearing role, not a decoration. 
+ for (FilteredSpillFile.Reader reader : spillFile.getReaders()) { + reader.removeEntriesForChannel(channelInfo); + } + } + if (waitSet != null && waitSet.remove(channelInfo) && waitSet.isEmpty()) { + drainSpillEntriesToCheckpoint(currentCheckpointId); + } + } + } + + @GuardedBy("dispatcherLock") + private void pinSpillSnapshots() { + List snapshots = new ArrayList<>(); + try { + for (FilteredSpillFile.Reader reader : spillFile.getReaders()) { + Preconditions.checkState( + reader.isFrozen(), + "Reader must be frozen when checkpoint starts; writer.finish() " + + "must be called before checkpoint trigger."); + snapshots.add(reader.snapshot()); + } + } catch (IOException e) { + for (FilteredSpillFile.Reader snap : snapshots) { + closeQuietly(snap); + } + throw new RuntimeException("Failed to snapshot spill readers for checkpoint", e); + } + checkpointSnapshots = snapshots; + for (FilteredSpillFile.Reader snap : snapshots) { + waitSet.addAll(snap.getPendingChannels()); + } + } + + /** + * Hands pinned snapshot Readers off to the {@link ChannelStateWriter}. Ownership transfers to + * the iterator's {@link FilteringDrainChunkIterator#close()} so the FileChannels are released + * even if the writer never advances the iterator (e.g. on abort). + */ + @GuardedBy("dispatcherLock") + private void drainSpillEntriesToCheckpoint(long checkpointId) { + if (checkpointSnapshots == null || checkpointSnapshots.isEmpty()) { + resetCheckpointState(); + return; + } + List snapshots = checkpointSnapshots; + Map startPos = checkpointStartPos; + checkpointSnapshots = null; + checkpointStartPos = null; + waitSet = null; + // addInputDataFromSpill submits to an async writer thread and does not reach back into + // any store SMALL — required for any callee invoked while holding BIG. + channelStateWriter.addInputDataFromSpill( + checkpointId, new FilteringDrainChunkIterator(snapshots, startPos)); + } + + @GuardedBy("dispatcherLock") + private void resetCheckpointState() { + if (checkpointSnapshots != null) { + for (FilteredSpillFile.Reader snap : checkpointSnapshots) { + closeQuietly(snap); + } + checkpointSnapshots = null; + } + checkpointStartPos = null; + waitSet = null; + } + + /** + * {@code fromListIndex} is a list cursor, distinct from {@link + * FilteredSpillFile.Reader#getFileIndex()} which is the globally monotonic file-id. + */ + private EntryPosition computeDrainHeadFrom(int fromListIndex) { + if (spillFile == null) { + return EntryPosition.END; + } + List readers = spillFile.getReaders(); + for (int i = fromListIndex; i < readers.size(); i++) { + FilteredSpillFile.Reader r = readers.get(i); + FilteredSpillFile.Reader.Entry next = r.peekNextEntry(); + if (next != null) { + return new EntryPosition(r.getFileIndex(), next.getOffset()); + } + } + return EntryPosition.END; + } + + /** + * Commits the cache via P1 (direct buffer) or P2 (spill). P1 requires the spill writer idle AND + * a non-blocking buffer available; otherwise spill, which preserves FIFO ordering — once + * anything has been spilled, all subsequent data must also spill. 
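+ *
+ * <p>Decision sketch (simplified, using this class's own members):
+ *
+ * <pre>{@code
+ * if (isSpillIdle() && (buffer = bufferRequester.requestBuffer(channelInfo)) != null) {
+ *     // P1: copy cached bytes into the buffer, add it to the channel's store
+ * } else {
+ *     // P2: append cached bytes to the spill file; FIFO order is preserved
+ * }
+ * }</pre>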
+ */ + private void flushCache() throws IOException { + if (cachePosition == 0) { + cacheChannel = null; + return; + } + + InputChannelInfo channelInfo = cacheChannel; + int bytesToFlush = cachePosition; + cachePosition = 0; + cacheChannel = null; + + if (isSpillIdle()) { + Buffer buffer = bufferRequester.requestBuffer(channelInfo); + if (buffer != null) { + writeChunkToBuffer(buffer, cache, bytesToFlush); + RecoveredBufferStoreImpl store = + Preconditions.checkNotNull( + storesByChannel.get(channelInfo), + "No store for channel %s", + channelInfo); + synchronized (store.getGateLock()) { + store.addBuffer(buffer); + } + return; + } + } + + writeToSpillFile(cache, bytesToFlush, channelInfo); + } + + private static void writeChunkToBuffer(Buffer buffer, byte[] data, int length) { + Preconditions.checkState( + buffer.getMaxCapacity() >= length, + "Buffer capacity %s is smaller than chunk length %s", + buffer.getMaxCapacity(), + length); + buffer.asByteBuf().writeBytes(data, 0, length); + } + + private void writeToSpillFile(byte[] data, int length, InputChannelInfo channelInfo) + throws IOException { + if (spillFile == null) { + spillFile = new FilteredSpillFile(spillDirs, memorySegmentSize); + } + spillFile.writeEntry(data, length, channelInfo); + RecoveredBufferStoreImpl store = + Preconditions.checkNotNull( + storesByChannel.get(channelInfo), "No store for channel %s", channelInfo); + synchronized (store) { + store.incrementPending(); + } + } + + /** + * Eagerly replays spill entries while non-blocking buffers are available. Runs only on the + * {@link #write} path before {@link #flush}, so by construction it cannot race {@link + * #onChannelCheckpointStarted}: physical channels (and thus checkpoint triggers) only exist + * after recovery's {@code finishRecovery()}, which runs after flush. Does not maintain {@code + * drainHead} — that field is initialised at flush time. + */ + private void eagerDrain() throws IOException { + if (spillFile == null) { + return; + } + for (FilteredSpillFile.Reader reader : spillFile.getReaders()) { + while (reader.hasEntries()) { + InputChannelInfo ch = reader.peekNextChannel(); + Buffer buffer = bufferRequester.requestBuffer(ch); + if (buffer == null) { + return; + } + FilteredSpillFile.Chunk chunk = reader.readNext(); + if (chunk == null) { + buffer.recycleBuffer(); + return; + } + writeChunkToBuffer(buffer, chunk.getData(), chunk.getLength()); + RecoveredBufferStoreImpl store = + Preconditions.checkNotNull( + storesByChannel.get(ch), "No store for channel %s", ch); + synchronized (store.getGateLock()) { + store.addBuffer(buffer); + } + } + } + } + + private boolean isSpillIdle() { + return spillFile == null || spillFile.isIdle(); + } + + /** + * Iterates chunks from snapshot Readers, skipping entries below each channel's recorded {@code + * startPos} cutoff (those are covered by Step 1). Each Reader is closed eagerly when exhausted; + * {@link #close()} closes whatever Readers remain. 
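+ *
+ * <p>Hypothetical cutoff example: with {@code startPos[ch] = (fileIndex 1, offset 100)}, a
+ * snapshot entry for {@code ch} at position (0, 400) is skipped (already covered by Step 1),
+ * while entries at (1, 100) or later are emitted into the checkpoint drain.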
+ */ + private static final class FilteringDrainChunkIterator + implements CloseableIterator { + + private final Deque remaining; + private final Map startPos; + + FilteringDrainChunkIterator( + List snapshots, + Map startPos) { + this.remaining = new ArrayDeque<>(snapshots); + this.startPos = startPos; + } + + @Override + public boolean hasNext() { + advanceToIncluded(); + return !remaining.isEmpty(); + } + + @Override + public FilteredSpillFile.Chunk next() { + advanceToIncluded(); + if (remaining.isEmpty()) { + throw new NoSuchElementException(); + } + try { + return remaining.peekFirst().readNext(); + } catch (IOException e) { + throw new UncheckedIOException("Failed to read spill chunk", e); + } + } + + private void advanceToIncluded() { + while (!remaining.isEmpty()) { + FilteredSpillFile.Reader r = remaining.peekFirst(); + if (!r.hasEntries()) { + closeQuietly(remaining.pollFirst()); + continue; + } + FilteredSpillFile.Reader.Entry e = r.peekNextEntry(); + EntryPosition cutoff = startPos.get(e.getChannelInfo()); + EntryPosition entryPos = new EntryPosition(r.getFileIndex(), e.getOffset()); + // No startPos means the channel was released before checkpoint trigger — drop + // every snapshot entry for it. Below-cutoff entries are already covered by + // Step 1 in store_C.readyBuffers. + if (cutoff == null || entryPos.compareTo(cutoff) < 0) { + r.skipNextEntry(); + } else { + return; + } + } + } + + @Override + public void close() { + while (!remaining.isEmpty()) { + closeQuietly(remaining.pollFirst()); + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredSpillFile.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredSpillFile.java new file mode 100644 index 0000000000000..7d22d81ad4459 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FilteredSpillFile.java @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.util.FileUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedDeque; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Spill file for the {@code filterAndRewrite} recovery path. Appends raw bytes to one or more + * physical files; each logical entry is tracked in the corresponding {@link Reader}'s entry deque. + * Readers support replay ({@link Reader#readNext}) and checkpoint snapshot ({@link + * Reader#snapshot}). Rotates to a new file when the current one exceeds 64 MB; each rotation + * freezes the outgoing Reader. Files are created lazily on the first {@link #writeEntry}. + */ +@Internal +public class FilteredSpillFile implements Closeable { + + private static final long FILE_ROTATION_THRESHOLD = 64L * 1024 * 1024; + + private final String[] spillDirs; + private final int memorySegmentSize; + private int currentDirIndex; + private FileChannel currentChannel; + private long currentFileOffset; + private final List readers; + + /** Monotonic file-id counter; never reused. */ + private int nextFileIndex; + + private boolean finished; + + /** + * @param memorySegmentSize max bytes per spill entry — each entry is 1:1 aligned with a network + * buffer of this size, so longer payloads must be split upstream. + */ + public FilteredSpillFile(String[] spillDirs, int memorySegmentSize) { + this.spillDirs = spillDirs; + this.memorySegmentSize = memorySegmentSize; + this.currentDirIndex = 0; + this.currentFileOffset = 0; + this.readers = new ArrayList<>(); + this.nextFileIndex = 0; + this.finished = false; + } + + /** + * Appends {@code len} bytes for {@code channelInfo}, lazily opening the first file and rotating + * when the current file exceeds {@link #FILE_ROTATION_THRESHOLD}. Oversize entries (> {@code + * memorySegmentSize}) fail fast. + */ + public void writeEntry(byte[] data, int len, InputChannelInfo channelInfo) throws IOException { + checkState(!finished, "writeEntry after finish"); + checkArgument( + len <= memorySegmentSize, + "Entry length %s exceeds memorySegmentSize %s", + len, + memorySegmentSize); + if (currentChannel == null) { + openNewFile(); + } else if (currentFileOffset > FILE_ROTATION_THRESHOLD) { + rotateFile(); + } + long entryOffset = currentFileOffset; + FileUtils.writeCompletely(currentChannel, ByteBuffer.wrap(data, 0, len)); + currentFileOffset += len; + currentReader().addEntry(channelInfo, entryOffset, len); + } + + /** Freezes the last Reader. After finish, no more writeEntry calls are accepted. */ + public void finish() { + if (finished) { + return; + } + finished = true; + if (!readers.isEmpty()) { + currentReader().freeze(); + } + } + + /** Closes all Readers and deletes the underlying spill files. 
*/ + @Override + public void close() throws IOException { + finish(); + try { + if (currentChannel != null) { + currentChannel.close(); + currentChannel = null; + } + } finally { + for (Reader r : readers) { + r.close(); + } + for (Reader r : readers) { + try { + Files.deleteIfExists(r.filePath); + } catch (IOException ignored) { + // best-effort cleanup + } + } + } + } + + public boolean isFinished() { + return finished; + } + + /** + * True when no entry is pending on disk. While idle, the dispatcher prefers P1; FIFO ordering + * is preserved because there are no on-disk entries to jump ahead of. + */ + public boolean isIdle() { + for (Reader r : readers) { + if (r.hasEntries()) { + return false; + } + } + return true; + } + + public List getReaders() { + return Collections.unmodifiableList(readers); + } + + private Reader currentReader() { + return readers.get(readers.size() - 1); + } + + private void openNewFile() throws IOException { + String dir = spillDirs[currentDirIndex]; + currentDirIndex = (currentDirIndex + 1) % spillDirs.length; + Path dirPath = Paths.get(dir); + Files.createDirectories(dirPath); + Path currentFilePath = dirPath.resolve("spill-" + UUID.randomUUID() + ".bin"); + currentChannel = + FileChannel.open( + currentFilePath, StandardOpenOption.CREATE_NEW, StandardOpenOption.WRITE); + currentFileOffset = 0; + readers.add(new Reader(currentFilePath, memorySegmentSize, nextFileIndex++)); + } + + private void rotateFile() throws IOException { + currentReader().freeze(); + currentChannel.close(); + currentChannel = null; + openNewFile(); + } + + /** + * Spilled-data chunk returned by {@link Reader#readNext()}. The {@code data} array is the + * Reader's internal buffer, reused between calls; callers must consume bytes before the next + * readNext. + */ + public static final class Chunk { + + private final InputChannelInfo channelInfo; + private final byte[] data; + private final int length; + + public Chunk(InputChannelInfo channelInfo, byte[] data, int length) { + this.channelInfo = channelInfo; + this.data = data; + this.length = length; + } + + public InputChannelInfo getChannelInfo() { + return channelInfo; + } + + /** Internal data buffer; valid bytes are {@code [0, length)}. */ + public byte[] getData() { + return data; + } + + public int getLength() { + return length; + } + } + + /** + * Reads entries from a single spill file. The original Reader is mutated by the recovery thread + * (write path, replay drain) and concurrently by task threads via {@link + * #removeEntriesForChannel} on channel release. To avoid undefined behavior on the entry deque, + * the backing storage is a {@link ConcurrentLinkedDeque} — its weakly consistent iterator and + * atomic poll/peek tolerate the writer-vs-release race that the post-iter_6 dispatcher no + * longer wraps in any monitor. The internal byte buffer is reused across {@link #readNext()} + * calls; callers must consume each {@link Chunk} before calling readNext again. 
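+ *
+ * <p>Sketch of the required consumption pattern (hypothetical caller):
+ *
+ * <pre>{@code
+ * Chunk chunk = reader.readNext();
+ * byte[] copy = Arrays.copyOf(chunk.getData(), chunk.getLength()); // copy before the next call
+ * Chunk next = reader.readNext(); // reuses and overwrites chunk.getData()
+ * }</pre>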
+ */ + public static class Reader implements Closeable { + + private final FileChannel channel; + final Path filePath; // accessed by FilteredSpillFile.close() to delete spill files + private final int memorySegmentSize; + private final int fileIndex; + private final Deque entries = new ConcurrentLinkedDeque<>(); + private volatile boolean frozen = false; + private final byte[] buf; + + Reader(Path filePath, int memorySegmentSize, int fileIndex) throws IOException { + this.filePath = filePath; + this.channel = FileChannel.open(filePath, StandardOpenOption.READ); + this.memorySegmentSize = memorySegmentSize; + this.fileIndex = fileIndex; + // writeEntry rejects oversized payloads, so every entry is guaranteed to fit. + this.buf = new byte[memorySegmentSize]; + } + + public int getFileIndex() { + return fileIndex; + } + + void addEntry(InputChannelInfo channelInfo, long offset, int length) { + checkState(!frozen, "addEntry after freeze"); + entries.addLast(new Entry(channelInfo, offset, length)); + } + + /** Head entry without consuming it; null if empty. */ + public Entry peekNextEntry() { + return entries.peekFirst(); + } + + /** + * Removes the head entry without disk I/O. Use after reading via {@link #readBytesAt} or to + * discard (e.g. phase-2 filter skipping an entry already covered by Step 1). + */ + public Entry skipNextEntry() { + return entries.pollFirst(); + } + + /** + * Reads {@code length} bytes from absolute {@code offset} into {@code dest}. Unlike {@link + * #readNext()} this does not mutate the entry deque, so callers can perform the I/O outside + * any lock that protects the deque. + */ + public void readBytesAt(long offset, int length, byte[] dest, int destOffset) + throws IOException { + ByteBuffer bb = ByteBuffer.wrap(dest, destOffset, length); + long position = offset; + while (bb.hasRemaining()) { + int n = channel.read(bb, position); + if (n < 0) { + throw new IOException( + "Truncated spill file: " + + length + + " bytes @" + + offset + + " in " + + filePath); + } + position += n; + } + } + + void freeze() { + frozen = true; + } + + public boolean isFrozen() { + return frozen; + } + + public boolean hasEntries() { + return !entries.isEmpty(); + } + + /** Channel of the next pending entry; null if empty. */ + public InputChannelInfo peekNextChannel() { + Entry e = entries.peekFirst(); + return e != null ? e.channelInfo : null; + } + + /** + * Reads the next pending entry as a {@link Chunk}; null when there are no more entries. The + * Chunk's data array is the Reader's internal buffer and is overwritten by the next + * readNext. + */ + public Chunk readNext() throws IOException { + Entry entry = entries.pollFirst(); + if (entry == null) { + return null; + } + ByteBuffer bb = ByteBuffer.wrap(buf, 0, entry.length); + long position = entry.offset; + while (bb.hasRemaining()) { + int n = channel.read(bb, position); + if (n < 0) { + throw new IOException( + "Truncated spill file: " + + entry.length + + " bytes @" + + entry.offset + + " in " + + filePath); + } + position += n; + } + return new Chunk(entry.channelInfo, buf, entry.length); + } + + /** + * Independent Reader over the same file with a shallow copy of the current entries, + * pre-frozen. The caller owns and must close the returned Reader. 
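+ *
+ * <p>One way to consume a snapshot so its file handle is always released:
+ *
+ * <pre>{@code
+ * try (Reader snap = reader.snapshot()) {
+ *     while (snap.hasEntries()) {
+ *         Chunk c = snap.readNext();
+ *         // ... consume c before the next readNext ...
+ *     }
+ * }
+ * }</pre>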
+ */ + public Reader snapshot() throws IOException { + checkState(frozen, "snapshot requires frozen Reader"); + Reader snap = new Reader(filePath, memorySegmentSize, fileIndex); + snap.entries.addAll(this.entries); + snap.frozen = true; + return snap; + } + + public Set getPendingChannels() { + Set channels = new HashSet<>(); + for (Entry e : entries) { + channels.add(e.channelInfo); + } + return channels; + } + + /** + * Drops all pending entries for {@code channelInfo}; returns the count. Used when a store + * is released so disk-side bookkeeping is freed eagerly. + */ + public int removeEntriesForChannel(InputChannelInfo channelInfo) { + int removed = 0; + Iterator it = entries.iterator(); + while (it.hasNext()) { + if (it.next().channelInfo.equals(channelInfo)) { + it.remove(); + removed++; + } + } + return removed; + } + + @Override + public void close() throws IOException { + channel.close(); + } + + /** Immutable metadata for a single spilled entry: target channel, offset, length. */ + public static final class Entry { + private final InputChannelInfo channelInfo; + private final long offset; + private final int length; + + Entry(InputChannelInfo channelInfo, long offset, int length) { + this.channelInfo = channelInfo; + this.offset = offset; + this.length = length; + } + + public InputChannelInfo getChannelInfo() { + return channelInfo; + } + + public long getOffset() { + return offset; + } + + public int getLength() { + return length; + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredBufferStoreCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredBufferStoreCoordinator.java new file mode 100644 index 0000000000000..855d04648f1a9 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredBufferStoreCoordinator.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; + +/** + * Cross-channel coordinator notified by per-channel {@link + * org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStore} instances on + * lifecycle events. Centralises bookkeeping that spans multiple channels (checkpoint wait-sets, + * shared on-disk spill state). + * + *
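+ * <p>A minimal sketch of an implementation satisfying the {@link #getCurrentDrainHead}
+ * contract below (the {@code drainHead} field is hypothetical):
+ *
+ * <pre>{@code
+ * private volatile EntryPosition drainHead = EntryPosition.END;
+ *
+ * public EntryPosition getCurrentDrainHead() {
+ *     return drainHead; // plain volatile read: cheap, non-blocking, lock-free
+ * }
+ * }</pre>
+ *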
<p>
The {@code onChannel*} callbacks fire from the Task thread outside the calling + * store's lock, so implementations may freely acquire their own synchronisation. {@link + * #getCurrentDrainHead()} fires inside the calling store's lock so the store can capture a + * consistent (readyBuffers, drainHead) pair atomically — implementations must therefore avoid + * blocking and must not acquire any lock participating in the store-lock cycle (a plain {@code + * volatile} read is the intended implementation). + */ +@Internal +public interface RecoveredBufferStoreCoordinator { + + /** + * Position of the next entry the drain bundle will pop from the global FIFO. Returns {@link + * EntryPosition#END} when no disk entries are pending. Must be cheap and lock-free. + */ + EntryPosition getCurrentDrainHead(); + + /** + * Invoked from {@code RecoveredBufferStore#checkpoint} after the store has snapshotted its + * ready buffers. {@code startPos} is the drain-head captured atomically with that snapshot; + * phase-2 uses it as the per-channel cutoff to split spill entries between this channel's Step + * 1 snapshot and the global checkpoint drain. + */ + void onChannelCheckpointStarted( + long checkpointId, InputChannelInfo channelInfo, EntryPosition startPos); + + /** + * Invoked from {@code RecoveredBufferStore#notifyCheckpointStopped} when the owning channel has + * finished or aborted a checkpoint. Used to drop a wait-set still tied to the stopped + * checkpoint so a later release or late start callback cannot trigger a phase-2 drain into a + * concluded checkpoint. + */ + void onChannelCheckpointStopped(long checkpointId, InputChannelInfo channelInfo); + + /** + * Invoked from {@code RecoveredBufferStore#releaseAll} so the coordinator can drop + * disk-resident spill entries still associated with the released channel. + */ + void onChannelReleased(InputChannelInfo channelInfo); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java index ca01ff37bd369..a5b8bf24eab36 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java @@ -71,6 +71,14 @@ public void close() { */ void recover(Info info, int oldSubtaskIndex, BufferWithContext bufferWithContext) throws IOException, InterruptedException; + + /** + * Triggers post-recovery actions: input channels complete the buffer-filtering future and + * publish {@code EndOfInputChannelStateEvent} via their store; output partitions call {@code + * finishReadRecoveredState} on every checkpointable partition. Idempotent. Must be invoked + * between {@code dispatcher.flush()} and {@code dispatcher.drainPendingSpill()}. + */ + void finishRecovery() throws IOException; } class InputChannelRecoveredStateHandler @@ -82,43 +90,43 @@ class InputChannelRecoveredStateHandler private final Map rescaledChannels = new HashMap<>(); private final Map oldToNewMappings = new HashMap<>(); + /** When non-null, filtering runs during recovery in the channel-state-unspilling thread. */ + @Nullable private final ChannelStateFilteringHandler filteringHandler; + /** - * Optional filtering handler for filtering recovered buffers. When non-null, filtering is - * performed during recovery in the channel-state-unspilling thread. 
+ * When non-null (filtering mode), filtered records are written here instead of directly to + * InputChannel buffers. */ - @Nullable private final ChannelStateFilteringHandler filteringHandler; + @Nullable private final FilteredBufferDispatcher dispatcher; - /** Network buffer memory segment size in bytes. Used to size the reusable pre-filter buffer. */ private final int memorySegmentSize; /** - * Reusable heap memory segment backing the pre-filter buffer in filtering mode. Lazily - * allocated on the first {@link #getPreFilterBuffer} call, reused for every subsequent call, - * and freed in {@link #close()}. - * - *
<p>
Reuse is safe because at most one pre-filter buffer is in flight per task at any moment. - * This invariant is enforced at runtime by {@link #preFilterBufferInUse}. + * Reusable heap segment backing the pre-filter buffer in filtering mode. Lazily allocated on + * first {@link #getPreFilterBuffer}, reused for subsequent calls, freed in {@link #close()}. + * Reuse is safe because at most one pre-filter buffer is in flight per task; the invariant is + * enforced at runtime by {@link #preFilterBufferInUse}. */ @Nullable private MemorySegment preFilterSegment; - /** - * Tracks whether {@link #preFilterSegment} is currently wrapped by a live {@link Buffer} that - * has not yet been recycled. Flipped to {@code true} when a new buffer is issued, and flipped - * back to {@code false} by the custom {@link BufferRecycler} when the buffer is recycled. - */ + /** True while {@link #preFilterSegment} is wrapped by a live, unreclaimed buffer. */ private boolean preFilterBufferInUse; + private boolean recoveryFinished; + InputChannelRecoveredStateHandler( InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping, @Nullable ChannelStateFilteringHandler filteringHandler, - int memorySegmentSize) { + int memorySegmentSize, + @Nullable FilteredBufferDispatcher dispatcher) { this.inputGates = inputGates; this.channelMapping = channelMapping; this.filteringHandler = filteringHandler; checkArgument( memorySegmentSize > 0, "memorySegmentSize must be positive: %s", memorySegmentSize); this.memorySegmentSize = memorySegmentSize; + this.dispatcher = dispatcher; } @Override @@ -127,7 +135,6 @@ public BufferWithContext getBuffer(InputChannelInfo channelInfo) if (filteringHandler != null) { return getPreFilterBuffer(); } - // Non-filtering mode: use existing network buffer pool allocation. RecoveredInputChannel channel = getMappedChannels(channelInfo); Buffer buffer = channel.requestBufferBlocking(); return new BufferWithContext<>(wrap(buffer), buffer); @@ -135,18 +142,8 @@ public BufferWithContext getBuffer(InputChannelInfo channelInfo) /** * Allocates a pre-filter buffer from a reusable heap segment (isolated from the Network Buffer - * Pool) in filtering mode. - * - *
<p>
Memory management: a single {@link MemorySegment} per task is lazily allocated on first - * invocation and reused across every subsequent call. The custom {@link BufferRecycler} does - * not free the segment — it only flips {@link #preFilterBufferInUse} back to {@code false} so - * the next call can reuse it. The segment itself is freed in {@link #close()}. - * - *
<p>
Runtime invariant check: the one-at-a-time invariant on pre-filter buffers is guaranteed - * by Flink's serial recovery loop and the deserializer's ownership contract. This method - * asserts the invariant before issuing a buffer: if a previously issued buffer has not yet been - * recycled, it throws {@link IllegalStateException} so any future regression fails loudly - * instead of silently corrupting memory. + * Pool). Flink's serial recovery loop guarantees at most one is in flight at a time; the + * runtime check fails loudly if that ever regresses. */ private BufferWithContext getPreFilterBuffer() { checkState( @@ -159,7 +156,6 @@ private BufferWithContext getPreFilterBuffer() { } preFilterBufferInUse = true; - // The recycler keeps the segment alive for reuse; only flips the in-use flag. BufferRecycler recycler = segment -> preFilterBufferInUse = false; Buffer buffer = new NetworkBuffer(preFilterSegment, recycler); return new BufferWithContext<>(wrap(buffer), buffer); @@ -191,12 +187,16 @@ public void recover( recoverWithFiltering( channel, channelInfo, oldSubtaskIndex, buffer.retainBuffer()); } else { - channel.onRecoveredStateBuffer( - EventSerializer.toBuffer( - new SubtaskConnectionDescriptor( - oldSubtaskIndex, channelInfo.getInputChannelIdx()), - false)); - channel.onRecoveredStateBuffer(buffer.retainBuffer()); + synchronized (channel.getStore().getGateLock()) { + channel.getStore() + .addBuffer( + EventSerializer.toBuffer( + new SubtaskConnectionDescriptor( + oldSubtaskIndex, + channelInfo.getInputChannelIdx()), + false)); + channel.getStore().addBuffer(buffer.retainBuffer()); + } } } } finally { @@ -211,33 +211,31 @@ private void recoverWithFiltering( Buffer retainedBuffer) throws IOException, InterruptedException { checkState(filteringHandler != null, "filtering handler not set."); - List filteredBuffers = - filteringHandler.filterAndRewrite( - channelInfo.getGateIdx(), - oldSubtaskIndex, - channelInfo.getInputChannelIdx(), - retainedBuffer, - channel::requestBufferBlocking); - - int i = 0; - try { - for (; i < filteredBuffers.size(); i++) { - channel.onRecoveredStateBuffer(filteredBuffers.get(i)); - } - } catch (Throwable t) { - for (int j = i; j < filteredBuffers.size(); j++) { - filteredBuffers.get(j).recycleBuffer(); - } - throw t; - } + checkState(dispatcher != null, "dispatcher not set."); + InputChannelInfo targetChannelInfo = channel.getChannelInfo(); + filteringHandler.filterAndRewrite( + channelInfo.getGateIdx(), + oldSubtaskIndex, + channelInfo.getInputChannelIdx(), + retainedBuffer, + dispatcher, + targetChannelInfo); } @Override - public void close() throws IOException { + public void finishRecovery() throws IOException { + if (recoveryFinished) { + return; + } + recoveryFinished = true; // note that we need to finish all RecoveredInputChannels, not just those with state for (final InputGate inputGate : inputGates) { inputGate.finishReadRecoveredState(); } + } + + @Override + public void close() throws IOException { if (preFilterSegment != null) { preFilterSegment.free(); preFilterSegment = null; @@ -279,6 +277,8 @@ class ResultSubpartitionRecoveredStateHandler private final boolean notifyAndBlockOnCompletion; private final ResultSubpartitionDistributor resultSubpartitionDistributor; + private boolean recoveryFinished; + ResultSubpartitionRecoveredStateHandler( ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion, @@ -286,10 +286,7 @@ class ResultSubpartitionRecoveredStateHandler this.writers = writers; this.resultSubpartitionDistributor = new 
ResultSubpartitionDistributor(channelMapping) { - /** - * Override the getSubpartitionInfo to perform type checking on the - * ResultPartitionWriter. - */ + /** Adds type-checking on the ResultPartitionWriter. */ @Override ResultSubpartitionInfo getSubpartitionInfo( int partitionIndex, int subPartitionIdx) { @@ -352,7 +349,11 @@ private CheckpointedResultPartition getCheckpointedResultPartition(int partition } @Override - public void close() throws IOException { + public void finishRecovery() throws IOException { + if (recoveryFinished) { + return; + } + recoveryFinished = true; for (ResultPartitionWriter writer : writers) { if (writer instanceof CheckpointedResultPartition) { ((CheckpointedResultPartition) writer) @@ -360,4 +361,7 @@ public void close() throws IOException { } } } + + @Override + public void close() throws IOException {} } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java index c52572e52faec..14f0f7f46934c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java @@ -23,8 +23,12 @@ import org.apache.flink.runtime.checkpoint.StateObjectCollection; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStoreImpl; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel; import org.apache.flink.runtime.state.AbstractChannelStateHandle; import org.apache.flink.runtime.state.ChannelStateHelper; import org.apache.flink.runtime.state.StreamStateHandle; @@ -32,7 +36,9 @@ import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -63,19 +69,30 @@ public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) { public void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) throws IOException, InterruptedException { - // Create filtering handler if filtering is needed ChannelStateFilteringHandler filteringHandler = filterContext.isCheckpointingDuringRecoveryEnabled() ? 
ChannelStateFilteringHandler.createFromContext(filterContext, inputGates) : null; + FilteredBufferDispatcher dispatcher = null; + if (filteringHandler != null) { + Map storesByChannel = + createPerChannelStores(inputGates); + if (!storesByChannel.isEmpty()) { + dispatcher = + createFilteredBufferDispatcher(inputGates, storesByChannel, filterContext); + } + } + try (ChannelStateFilteringHandler ignored = filteringHandler; + FilteredBufferDispatcher d = dispatcher; InputChannelRecoveredStateHandler stateHandler = new InputChannelRecoveredStateHandler( inputGates, taskStateSnapshot.getInputRescalingDescriptor(), filteringHandler, - filterContext.getMemorySegmentSize())) { + filterContext.getMemorySegmentSize(), + dispatcher)) { read( stateHandler, groupByDelegate( @@ -92,6 +109,25 @@ public void readInputData(InputGate[] inputGates, RecordFilterContext filterCont !filteringHandler.hasPartialData(), "Not all data has been fully consumed during filtering"); } + + if (d != null) { + // flush() seals Readers; finishRecovery() triggers channel conversion; + // drainPendingSpill() blocks without holding the dispatcher monitor. + d.flush(); + stateHandler.finishRecovery(); + d.drainPendingSpill(); + } else { + // Snapshot recovered channel references BEFORE finishRecovery, because + // finishRecovery completes bufferFilteringCompleteFuture which can race a + // mailbox-driven conversion. Once conversion replaces channels[i] with a + // physical channel, looking up RecoveredInputChannels via inputGates would + // miss the converted ones, and their BufferManager rendezvous would leak. + List recovered = collectRecoveredChannels(inputGates); + stateHandler.finishRecovery(); + for (RecoveredInputChannel ch : recovered) { + ch.markDrainDone(); + } + } } } @@ -108,6 +144,7 @@ public void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlo groupByDelegate( streamSubtaskStates(), ChannelStateHelper::extractUnmergedOutputHandles)); + stateHandler.finishRecovery(); } } @@ -141,6 +178,112 @@ private > void re } } + private List collectRecoveredChannels(InputGate[] inputGates) { + List recovered = new ArrayList<>(); + for (InputGate gate : inputGates) { + for (int i = 0; i < gate.getNumberOfInputChannels(); i++) { + InputChannel ch = gate.getChannel(i); + if (ch instanceof RecoveredInputChannel) { + recovered.add((RecoveredInputChannel) ch); + } + } + } + return recovered; + } + + private Map createPerChannelStores( + InputGate[] inputGates) { + Map storesByChannel = new HashMap<>(); + for (InputGate gate : inputGates) { + for (int i = 0; i < gate.getNumberOfInputChannels(); i++) { + InputChannel ch = gate.getChannel(i); + if (ch instanceof RecoveredInputChannel) { + RecoveredInputChannel recoveredCh = (RecoveredInputChannel) ch; + InputChannelInfo info = recoveredCh.getChannelInfo(); + // Reuse the channel's own store so filtering and non-filtering paths deliver + // to the same instance. + RecoveredBufferStoreImpl store = recoveredCh.getStore(); + storesByChannel.put(info, store); + } + } + } + return storesByChannel; + } + + private FilteredBufferDispatcher createFilteredBufferDispatcher( + InputGate[] inputGates, + Map storesByChannel, + RecordFilterContext filterContext) + throws IOException { + Map channelMap = buildChannelMap(inputGates); + String[] spillDirs = filterContext.getTmpDirectories(); + int memorySegmentSize = filterContext.getMemorySegmentSize(); + + // All channels share the same writer. 
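+ // Sketch note: picking an arbitrary channel is safe here. This method only runs when
+ // createPerChannelStores(...) found at least one RecoveredInputChannel, and every
+ // recovered channel of the task is wired to that same ChannelStateWriter.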
+ RecoveredInputChannel anyChannel = channelMap.values().iterator().next(); + + return new FilteredBufferDispatcherImpl( + storesByChannel, + anyChannel.getChannelStateWriter(), + spillDirs, + memorySegmentSize, + new RecoveredChannelBufferRequester(channelMap)); + } + + /** {@link BufferRequester} routing each channel's request to its own exclusive pool. */ + private static final class RecoveredChannelBufferRequester implements BufferRequester { + + private final Map channelMap; + + RecoveredChannelBufferRequester(Map channelMap) { + this.channelMap = channelMap; + } + + @Override + public Buffer requestBuffer(InputChannelInfo channelInfo) throws IOException { + return lookup(channelInfo).requestBuffer(); + } + + @Override + public Buffer requestBufferBlocking(InputChannelInfo channelInfo) + throws InterruptedException, IOException { + return lookup(channelInfo).requestBufferBlocking(); + } + + @Override + public void releaseExclusiveBuffers() throws IOException { + // Two-flag rendezvous (drainDone + converted): releasing here directly would race + // mailbox-driven convertRecoveredInputChannels and cause pollNext on the not-yet- + // replaced gate slot to hit the released-channel checkState. + for (RecoveredInputChannel ch : channelMap.values()) { + ch.markDrainDone(); + } + } + + private RecoveredInputChannel lookup(InputChannelInfo channelInfo) { + RecoveredInputChannel ch = channelMap.get(channelInfo); + if (ch == null) { + throw new IllegalArgumentException( + "No RecoveredInputChannel for channelInfo: " + channelInfo); + } + return ch; + } + } + + private Map buildChannelMap(InputGate[] inputGates) { + Map channelMap = new HashMap<>(); + for (InputGate gate : inputGates) { + for (int i = 0; i < gate.getNumberOfInputChannels(); i++) { + InputChannel ch = gate.getChannel(i); + if (ch instanceof RecoveredInputChannel) { + RecoveredInputChannel recoveredCh = (RecoveredInputChannel) ch; + channelMap.put(recoveredCh.getChannelInfo(), recoveredCh); + } + } + } + return channelMap; + } + private Stream streamSubtaskStates() { return taskStateSnapshot.getSubtaskStateMappings().stream().map(Map.Entry::getValue); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/ChannelStatePersister.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/ChannelStatePersister.java index b5f4adf721ef0..9a55f5049b8ba 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/ChannelStatePersister.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/ChannelStatePersister.java @@ -32,16 +32,23 @@ import org.slf4j.LoggerFactory; import javax.annotation.Nullable; -import javax.annotation.concurrent.NotThreadSafe; +import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.List; import java.util.OptionalLong; import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; -/** Helper class for persisting channel state via {@link ChannelStateWriter}. */ -@NotThreadSafe +/** + * Helper class for persisting channel state via {@link ChannelStateWriter}. + * + *
<p>
Holds no lock of its own. Every state-touching method requires the caller to hold the bound + * {@link RecoveredBufferStore}'s intrinsic monitor; an {@code assert Thread.holdsLock(store)} + * enforces it under {@code -ea}. This serializes the task thread (start/stop, getNext) with the + * network thread (onBuffer) on the same outer {@code synchronized(store)}. + */ public final class ChannelStatePersister { private static final Logger LOG = LoggerFactory.getLogger(ChannelStatePersister.class); @@ -53,44 +60,73 @@ private enum CheckpointStatus { BARRIER_RECEIVED } + @GuardedBy("store") private CheckpointStatus checkpointStatus = CheckpointStatus.COMPLETED; + @GuardedBy("store") private long lastSeenBarrier = -1L; - /** - * Writer must be initialized before usage. {@link #startPersisting(long, List)} enforces this - * invariant. - */ private final ChannelStateWriter channelStateWriter; - ChannelStatePersister(ChannelStateWriter channelStateWriter, InputChannelInfo channelInfo) { + private final RecoveredBufferStore store; + + ChannelStatePersister( + ChannelStateWriter channelStateWriter, + InputChannelInfo channelInfo, + RecoveredBufferStore store) { this.channelStateWriter = checkNotNull(channelStateWriter); this.channelInfo = checkNotNull(channelInfo); + this.store = checkNotNull(store); } + /** + * Snapshots the recovered store when present, otherwise writes the network inflight buffers via + * {@link ChannelStateWriter#addInputData}. Asserts that the two are mutually exclusive ({@code + * store.isEmpty() || knownBuffers.isEmpty()}). + * + * @param knownBuffers network inflight buffers (empty for LocalInputChannel) + */ + @GuardedBy("store") protected void startPersisting(long barrierId, List knownBuffers) throws CheckpointException { + assert Thread.holdsLock(store); logEvent("startPersisting", barrierId); if (checkpointStatus == CheckpointStatus.BARRIER_RECEIVED && lastSeenBarrier > barrierId) { throw new CheckpointException( String.format( "Barrier for newer checkpoint %d has already been received compared to the requested checkpoint %d", lastSeenBarrier, barrierId), - CheckpointFailureReason - .CHECKPOINT_SUBSUMED); // currently, at most one active unaligned + CheckpointFailureReason.CHECKPOINT_SUBSUMED); } if (lastSeenBarrier < barrierId) { - // Regardless of the current checkpointStatus, if we are notified about a more recent - // checkpoint then we have seen so far, always mark that this more recent barrier is - // pending. - // BARRIER_RECEIVED status can happen if we have seen an older barrier, that probably - // has not yet been processed by the task, but task is now notifying us that checkpoint - // has started for even newer checkpoint. We should spill the knownBuffers and mark that - // we are waiting for that newer barrier to arrive + // Override BARRIER_RECEIVED too: task is announcing a newer checkpoint than the + // barrier we already saw on the wire, so spill knownBuffers under the new id and + // wait for the newer barrier to arrive. checkpointStatus = CheckpointStatus.BARRIER_PENDING; lastSeenBarrier = barrierId; } - if (knownBuffers.size() > 0) { + + final boolean storeEmpty = store.isEmpty(); + final int storeSize = store.size(); + checkState( + storeEmpty || knownBuffers.isEmpty(), + "Invariant violated: store has data (size=%s) AND knownBuffers non-empty (size=%s) at barrier %s. 
" + + "Requires UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM=true so upstream does not " + + "replay output state into receivedBuffers while the recovered store is still draining.", + storeSize, + knownBuffers.size(), + barrierId); + + if (!storeEmpty) { + try { + store.checkpoint(channelStateWriter, barrierId); + } catch (IOException e) { + throw new CheckpointException( + "Failed to checkpoint recovered store", + CheckpointFailureReason.IO_EXCEPTION, + e); + } + } else if (!knownBuffers.isEmpty()) { channelStateWriter.addInputData( barrierId, channelInfo, @@ -99,15 +135,25 @@ protected void startPersisting(long barrierId, List knownBuffers) } } + /** + * Marks {@code id} concluded on this channel and notifies the store. Without the notification + * an aborted checkpoint's wait-set could linger and a later release / late callback would + * trigger a phase-2 drain into a concluded checkpoint. + */ + @GuardedBy("store") protected void stopPersisting(long id) { + assert Thread.holdsLock(store); logEvent("stopPersisting", id); if (id >= lastSeenBarrier) { checkpointStatus = CheckpointStatus.COMPLETED; lastSeenBarrier = id; } + store.notifyCheckpointStopped(id); } + @GuardedBy("store") protected void maybePersist(Buffer buffer) { + assert Thread.holdsLock(store); if (checkpointStatus == CheckpointStatus.BARRIER_PENDING && buffer.isBuffer()) { channelStateWriter.addInputData( lastSeenBarrier, @@ -117,7 +163,9 @@ protected void maybePersist(Buffer buffer) { } } + @GuardedBy("store") protected OptionalLong checkForBarrier(Buffer buffer) throws IOException { + assert Thread.holdsLock(store); AbstractEvent event = parseEvent(buffer); if (event instanceof CheckpointBarrier) { long barrierId = ((CheckpointBarrier) event).getId(); @@ -134,7 +182,8 @@ protected OptionalLong checkForBarrier(Buffer buffer) throws IOException { logEvent("ignoring barrier", barrierId); } } - if (event instanceof EventAnnouncement) { // NOTE: only remote channels + if (event instanceof EventAnnouncement) { + // Only remote channels announce barriers ahead of the data they overtake. EventAnnouncement announcement = (EventAnnouncement) event; if (announcement.getAnnouncedEvent() instanceof CheckpointBarrier) { long barrierId = ((CheckpointBarrier) announcement.getAnnouncedEvent()).getId(); @@ -157,30 +206,28 @@ private void logEvent(String event, long barrierId) { } } - /** - * Parses the buffer as an event and returns the {@link CheckpointBarrier} if the event is - * indeed a barrier or returns null in all other cases. - */ @Nullable protected AbstractEvent parseEvent(Buffer buffer) throws IOException { if (buffer.isBuffer()) { return null; - } else { - AbstractEvent event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader()); - // reset the buffer because it would be deserialized again in SingleInputGate while - // getting next buffer. - // we can further improve to avoid double deserialization in the future. - buffer.setReaderIndex(0); - return event; } + AbstractEvent event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader()); + // SingleInputGate will deserialize the same buffer again when handing it to the task; we + // can drop this double deserialization in the future. 
+ buffer.setReaderIndex(0); + return event; } + @GuardedBy("store") protected boolean hasBarrierReceived() { + assert Thread.holdsLock(store); return checkpointStatus == CheckpointStatus.BARRIER_RECEIVED; } @Override public String toString() { + // Best-effort debug snapshot; reads are unsynchronized so the logger never blocks behind + // callers that already hold (or are about to acquire) the store lock. return "ChannelStatePersister(lastSeenBarrier=" + lastSeenBarrier + " (" diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EmptyRecoveredBufferStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EmptyRecoveredBufferStore.java new file mode 100644 index 0000000000000..161c5e22ff902 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EmptyRecoveredBufferStore.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveredBufferStoreCoordinator; +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import javax.annotation.Nullable; + +/** No-op {@link RecoveredBufferStore} sentinel for channels without recovered data. 
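+ *
+ * <p>Call-shape sketch ({@code hasRecovered} and {@code impl} are placeholders): callers
+ * keep the same {@code synchronized(store)} pattern either way instead of branching on
+ * null:
+ *
+ * <pre>{@code
+ * RecoveredBufferStore store = hasRecovered ? impl : RecoveredBufferStore.EMPTY;
+ * synchronized (store) {
+ *     Buffer b = store.tryTake(); // the sentinel simply yields null
+ * }
+ * }</pre>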
*/ +class EmptyRecoveredBufferStore implements RecoveredBufferStore { + + @Nullable + @Override + public Buffer tryTake() { + return null; + } + + @Override + public Buffer.DataType peekNextDataType() { + return Buffer.DataType.NONE; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public int size() { + return 0; + } + + @Override + public void checkpoint(ChannelStateWriter writer, long checkpointId) {} + + @Override + public void releaseAll() {} + + @Override + public void notifyCheckpointStopped(long checkpointId) {} + + @Override + public void setCoordinator(RecoveredBufferStoreCoordinator coordinator) {} + + @Override + public void setDataAvailableListener(DataAvailableListener listener) {} +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index 661e4b063c75f..4984bae411562 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -43,10 +43,11 @@ import org.slf4j.LoggerFactory; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.ArrayDeque; -import java.util.ArrayList; +import java.util.Collections; import java.util.Deque; import java.util.List; import java.util.Optional; @@ -80,10 +81,12 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit private final Deque toBeConsumedBuffers = new ArrayDeque<>(); + /** Always non-null: callers with no recovered data pass {@link RecoveredBufferStore#EMPTY}. */ + private final RecoveredBufferStore recoveredStore; + /** - * Flag indicating whether there is a pending priority event (e.g., checkpoint barrier) in the - * subpartitionView that should be consumed before toBeConsumedBuffers. This is set by {@link - * #notifyPriorityEvent} and checked in {@link #getNextBuffer()}. + * True when a priority event (e.g. unaligned checkpoint barrier) sits in {@code + * subpartitionView} and must be consumed before recovered data. */ private volatile boolean hasPendingPriorityEvent = false; @@ -99,7 +102,7 @@ public LocalInputChannel( Counter numBytesIn, Counter numBuffersIn, ChannelStateWriter stateWriter, - ArrayDeque initialRecoveredBuffers) { + RecoveredBufferStore recoveredStore) { super( inputGate, @@ -113,31 +116,12 @@ public LocalInputChannel( this.partitionManager = checkNotNull(partitionManager); this.taskEventPublisher = checkNotNull(taskEventPublisher); - this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo()); - - // Migrate recovered buffers from RecoveredInputChannel if provided. - // These buffers have been filtered but not yet consumed by the Task. - if (!initialRecoveredBuffers.isEmpty()) { - final int expectedCount = initialRecoveredBuffers.size(); - // Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel. - int seqNum = Integer.MIN_VALUE; - while (!initialRecoveredBuffers.isEmpty()) { - Buffer buffer = initialRecoveredBuffers.poll(); - // Determine next data type based on the next buffer in the queue - Buffer.DataType nextDataType = - initialRecoveredBuffers.isEmpty() - ? 
Buffer.DataType.NONE - : initialRecoveredBuffers.peek().getDataType(); - // buffersInBacklog is set to 0 as these are recovered buffers - BufferAndBacklog bufferAndBacklog = - new BufferAndBacklog(buffer, 0, nextDataType, seqNum++); - toBeConsumedBuffers.add(bufferAndBacklog); - } - checkState( - toBeConsumedBuffers.size() == expectedCount, - "Buffer migration failed: expected %s buffers but got %s", - expectedCount, - toBeConsumedBuffers.size()); + + this.recoveredStore = checkNotNull(recoveredStore); + this.channelStatePersister = + new ChannelStatePersister(stateWriter, getChannelInfo(), this.recoveredStore); + synchronized (this.recoveredStore) { + this.recoveredStore.setDataAvailableListener(this::notifyChannelNonEmpty); } } @@ -146,19 +130,18 @@ public LocalInputChannel( // ------------------------------------------------------------------------ public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { - // Collect inflight buffers from toBeConsumedBuffers to be persisted. - // These are buffers that have not been consumed yet when the checkpoint barrier arrives. - List inflightBuffers = new ArrayList<>(); - for (BufferAndBacklog bufferAndBacklog : toBeConsumedBuffers) { - if (bufferAndBacklog.buffer().isBuffer()) { - inflightBuffers.add(bufferAndBacklog.buffer().retainBuffer()); - } + // Local channels have no network inflight buffers to snapshot (barriers and data arrive + // together via the local subpartition view). FullyFilledBuffer splits in + // toBeConsumedBuffers are ordinary data fragments and don't belong in channel state. + synchronized (recoveredStore) { + channelStatePersister.startPersisting(barrier.getId(), Collections.emptyList()); } - channelStatePersister.startPersisting(barrier.getId(), inflightBuffers); } public void checkpointStopped(long checkpointId) { - channelStatePersister.stopPersisting(checkpointId); + synchronized (recoveredStore) { + channelStatePersister.stopPersisting(checkpointId); + } } @Override @@ -272,10 +255,18 @@ protected int peekNextBufferSubpartitionIdInternal() throws IOException { public Optional getNextBuffer() throws IOException { checkError(); - if (!toBeConsumedBuffers.isEmpty()) { + final boolean stillRecovering; + synchronized (recoveredStore) { + stillRecovering = !recoveredStore.isEmpty(); + } + if (stillRecovering) { return getNextRecoveredBuffer(); } + if (!toBeConsumedBuffers.isEmpty()) { + return getNextSplitBuffer(); + } + ResultSubpartitionView subpartitionView = this.subpartitionView; if (subpartitionView == null) { // There is a possible race condition between writing a EndOfPartitionEvent (1) and @@ -335,14 +326,9 @@ public Optional getNextBuffer() throws IOException { return getBufferAndAvailability(next); } - /** - * Consumes the next buffer from toBeConsumedBuffers (recovered buffers), handling pending - * priority events and dynamic availability detection for the last recovered buffer. - */ private Optional getNextRecoveredBuffer() throws IOException { - // If there is a pending priority event (e.g., unaligned checkpoint barrier), fetch it - // from subpartitionView first, skipping toBeConsumedBuffers. This ensures priority - // events are processed immediately even when there are pending recovered buffers. + // Pending priority event bypasses the FIFO recovery-first rule so unaligned barriers can + // be processed immediately. 
if (hasPendingPriorityEvent) { checkState(subpartitionView != null, "No subpartition view available"); BufferAndBacklog next = subpartitionView.getNextBuffer(); @@ -351,18 +337,15 @@ private Optional getNextRecoveredBuffer() throws IOExcept "Expected priority event, but got %s", next == null ? "null" : next.buffer().getDataType()); - // Check for barrier to update channel state persister. - // Note: maybePersist is not needed for barriers as they are not regular data buffers. - channelStatePersister.checkForBarrier(next.buffer()); - Buffer.DataType expectedNextDataType = next.getNextDataType(); - if (!expectedNextDataType.hasPriority()) { - // Reset hasPendingPriorityEvent to false if no more priority event - hasPendingPriorityEvent = false; - if (!toBeConsumedBuffers.isEmpty()) { - // Correct nextDataType: if toBeConsumedBuffers is not empty, the actual next - // element to consume is from toBeConsumedBuffers, not from subpartitionView - expectedNextDataType = toBeConsumedBuffers.peek().buffer().getDataType(); + synchronized (recoveredStore) { + channelStatePersister.checkForBarrier(next.buffer()); + if (!expectedNextDataType.hasPriority()) { + hasPendingPriorityEvent = false; + // recoveredStore (if non-empty) is FIFO ahead of subpartitionView. + if (!recoveredStore.isEmpty()) { + expectedNextDataType = peekNextDataType(); + } } } @@ -374,26 +357,52 @@ private Optional getNextRecoveredBuffer() throws IOExcept next.getSequenceNumber())); } - BufferAndBacklog next = toBeConsumedBuffers.removeFirst(); + // tryTake + peekNextDataType under one lock so the consumer never observes a torn + // (post-take, pre-peek) view. + final Buffer next; + Buffer.DataType nextDataType; + synchronized (recoveredStore) { + next = recoveredStore.tryTake(); + if (next == null) { + return Optional.empty(); + } + nextDataType = peekNextDataType(); + } + int sequenceNumber = Integer.MIN_VALUE; - // If this is the last recovered buffer and nextDataType is NONE, - // dynamically check if subpartitionView has data available. - // The last buffer's nextDataType was preset to NONE during construction, - // but subpartitionView may already have data available. - if (toBeConsumedBuffers.isEmpty() - && next.getNextDataType() == Buffer.DataType.NONE - && subpartitionView != null) { + // subpartitionView availability check is outside the store lock: it acquires producer- + // side locks, and the consumer already holds gate → store, so adding store → + // subpartition would close an AB-BA cycle with the producer's notify path + // (subpartition → gate.notifyChannelNonEmpty → inputChannelsWithData). + if (nextDataType == Buffer.DataType.NONE && subpartitionView != null) { ResultSubpartitionView.AvailabilityWithBacklog availability = subpartitionView.getAvailabilityAndBacklog(true); if (availability.isAvailable()) { - next = - new BufferAndBacklog( - next.buffer(), - availability.getBacklog(), - Buffer.DataType.DATA_BUFFER, - next.getSequenceNumber()); + nextDataType = Buffer.DataType.DATA_BUFFER; } } + + BufferAndBacklog bufferAndBacklog = + new BufferAndBacklog(next, 0, nextDataType, sequenceNumber); + return getBufferAndAvailability(bufferAndBacklog); + } + + /** + * Data type of the next consumer-visible buffer behind the store. Caller MUST hold {@code + * recoveredStore}. The {@link #subpartitionView} tier is deliberately not consulted — doing so + * under the store lock would form an AB-BA cycle with the producer's notify path. 
+ */ + @GuardedBy("recoveredStore") + private Buffer.DataType peekNextDataType() { + assert Thread.holdsLock(recoveredStore); + if (!recoveredStore.isEmpty()) { + return recoveredStore.peekNextDataType(); + } + return Buffer.DataType.NONE; + } + + private Optional getNextSplitBuffer() throws IOException { + BufferAndBacklog next = toBeConsumedBuffers.removeFirst(); return getBufferAndAvailability(next); } @@ -410,8 +419,10 @@ private Optional getBufferAndAvailability(BufferAndBacklo numBytesIn.inc(buffer.readableBytes()); numBuffersIn.inc(); - channelStatePersister.checkForBarrier(buffer); - channelStatePersister.maybePersist(buffer); + synchronized (recoveredStore) { + channelStatePersister.checkForBarrier(buffer); + channelStatePersister.maybePersist(buffer); + } NetworkActionsLogger.traceInput( "LocalInputChannel#getNextBuffer", buffer, @@ -434,8 +445,6 @@ public void notifyDataAvailable(ResultSubpartitionView view) { @Override public void notifyPriorityEvent(int prioritySequenceNumber) { - // Set flag so that getNextBuffer() knows to fetch priority event from subpartitionView - // before consuming toBeConsumedBuffers. hasPendingPriorityEvent = true; super.notifyPriorityEvent(prioritySequenceNumber); } @@ -512,8 +521,10 @@ void releaseAllResources() throws IOException { subpartitionView = null; } - // Release any remaining buffers in toBeConsumedBuffers to avoid memory leak. - // These may be recovered buffers or partial buffers from FullyFilledBuffer. + // EMPTY.releaseAll() is a no-op, so no null check needed. + recoveredStore.releaseAll(); + + // Release partial buffers from FullyFilledBuffer splits. for (BufferAndBacklog bufferAndBacklog : toBeConsumedBuffers) { bufferAndBacklog.buffer().recycleBuffer(); } @@ -534,14 +545,16 @@ void announceBufferSize(int newBufferSize) { @Override int getBuffersInUseCount() { ResultSubpartitionView view = this.subpartitionView; - return toBeConsumedBuffers.size() + (view == null ? 0 : view.getNumberOfQueuedBuffers()); + return recoveredStore.size() + + toBeConsumedBuffers.size() + + (view == null ? 
0 : view.getNumberOfQueuedBuffers()); } @Override public int unsynchronizedGetNumberOfQueuedBuffers() { ResultSubpartitionView view = subpartitionView; - int count = toBeConsumedBuffers.size(); + int count = recoveredStore.size() + toBeConsumedBuffers.size(); if (view != null) { count += view.unsynchronizedGetNumberOfQueuedBuffers(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java index bdde2244f38ef..23bdb0c4d8bd1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java @@ -19,14 +19,11 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.runtime.io.network.TaskEventPublisher; -import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; -import java.util.ArrayDeque; - import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -64,7 +61,7 @@ public class LocalRecoveredInputChannel extends RecoveredInputChannel { } @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + protected InputChannel toInputChannelInternal(RecoveredBufferStoreImpl recoveredStore) { return new LocalInputChannel( inputGate, getChannelIndex(), @@ -77,6 +74,6 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer numBytesIn, numBuffersIn, channelStateWriter, - remainingBuffers); + recoveredStore); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStore.java new file mode 100644 index 0000000000000..e1878b6355f8f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStore.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveredBufferStoreCoordinator; +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import javax.annotation.Nullable; + +import java.io.IOException; + +/** + * Per-channel store for recovered buffers during unaligned checkpoint recovery. Buffers are either + * in-memory (ready for consumption) or on disk (pending spill entries). + * + *
<p><b>Locking contract</b> + * + * <p>
The store's intrinsic monitor IS the channel-private lock. Callers MUST hold {@code + * synchronized(store)} when invoking the consumer-side queries ({@link #tryTake}, {@link + * #peekNextDataType}, {@link #isEmpty}) and the setters; implementations enforce this with an + * internal {@code assert Thread.holdsLock(this)} that fires under {@code -ea}. + * + *
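+ * <p>Consumer-side call shape this contract implies (sketch, mirroring the
+ * LocalInputChannel usage elsewhere in this change):
+ *
+ * <pre>{@code
+ * synchronized (store) {
+ *     Buffer next = store.tryTake();                    // may be null
+ *     Buffer.DataType after = store.peekNextDataType(); // consistent with the take
+ * }
+ * }</pre>
+ *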
<p>
{@link #size()} is exempt and lock-free: a metric/bookkeeping read that tolerates a slightly + * stale value. Callers that need a consistent {@code isEmpty + size} pair must take the lock around + * both reads themselves. + * + *
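+ * <p>Sketch of the consistent pair read described above:
+ *
+ * <pre>{@code
+ * final boolean empty;
+ * final int n;
+ * synchronized (store) {
+ *     empty = store.isEmpty();
+ *     n = store.size(); // same lock scope, so the pair cannot tear
+ * }
+ * }</pre>
+ *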
<p>
The lifecycle methods ({@link #checkpoint}, {@link #releaseAll}, {@link + * #notifyCheckpointStopped}) self-manage their store-level locking and fire any coordinator + * callback outside the lock to avoid deadlock with the coordinator. + * + *
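+ * <p>The implementation pattern this implies (sketch; {@code releaseAll} in
+ * RecoveredBufferStoreImpl follows it):
+ *
+ * <pre>{@code
+ * RecoveredBufferStoreCoordinator c;
+ * synchronized (this) {
+ *     // ... mutate store state ...
+ *     c = coordinator;                  // capture under the lock
+ * }
+ * if (c != null) {
+ *     c.onChannelReleased(channelInfo); // fire after exiting the monitor
+ * }
+ * }</pre>
+ *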
<p>
Use {@link #EMPTY} as a sentinel when no recovered data is present, rather than holding {@code + * null}; callers still wrap calls in {@code synchronized(store)} to keep the same call shape + * regardless of which implementation backs the channel. + */ +@Internal +public interface RecoveredBufferStore { + + /** Singleton no-op store used when there is no recovered data for a channel. */ + RecoveredBufferStore EMPTY = new EmptyRecoveredBufferStore(); + + /** Next buffer from the store; null if no ready buffer is available. */ + @Nullable + Buffer tryTake(); + + /** Data type of the next ready buffer, or {@link Buffer.DataType#NONE} if empty. */ + Buffer.DataType peekNextDataType(); + + /** True if the ready queue is empty and no pending spill entries exist. */ + boolean isEmpty(); + + /** + * Total buffers held for the bound channel: ready buffers in memory plus pending entries on + * disk. Used for in-use / backlog accounting. + */ + int size(); + + /** + * Snapshots ready buffers into {@code writer}, then notifies the registered coordinator + * (outside the store lock to avoid deadlock). + */ + void checkpoint(ChannelStateWriter writer, long checkpointId) throws IOException; + + /** + * Releases all buffers and clears state. Notifies the coordinator (if any) so it can drop + * still-pending spill entries for this channel. + */ + void releaseAll(); + + /** + * Forwards a checkpoint-stopped notification (completion or abort) to the coordinator. The + * store keeps no per-checkpoint state; cross-channel bookkeeping (wait-set) lives there. + */ + void notifyCheckpointStopped(long checkpointId); + + /** Registers the cross-channel coordinator. Pass non-null when one exists. */ + void setCoordinator(RecoveredBufferStoreCoordinator coordinator); + + /** + * Listener fired when a buffer is added to a previously empty ready queue. The typical + * recipient is the owning InputChannel, which uses it to wake up the Task thread. + */ + void setDataAvailableListener(DataAvailableListener listener); + + @FunctionalInterface + interface DataAvailableListener { + + void onDataAvailable(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStoreImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStoreImpl.java new file mode 100644 index 0000000000000..01a0e5261c24d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStoreImpl.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.EntryPosition; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveredBufferStoreCoordinator; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.util.CloseableIterator; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Per-channel store for recovered buffers. Buffers are either ready (in-memory) or pending (on + * disk, tracked by count only — the actual entries are owned by FilteredBufferDispatcher). + * + *
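+ * <p>Sketch of the pending/deferred bookkeeping defined below (lock preconditions
+ * elided; {@code eoicsEvent} and {@code drainedBuffer} are illustrative names):
+ *
+ * <pre>{@code
+ * store.incrementPending();             // dispatcher spilled one buffer to disk
+ * store.addBufferAfterDisk(eoicsEvent); // parked in deferredBuffers while pending > 0
+ * store.addBuffer(drainedBuffer);       // pending drops to 0: event promoted behind it
+ * }</pre>
+ *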
<p><b>Locking</b> + * + * <p>
Two monitors guard the race-relevant paths: the owning gate's lock (returned by {@code + * SingleInputGate#getGateLock()}, henceforth gate lock) and this store's intrinsic monitor. + * The acquisition order is always {@code gate → store}; the same order is enforced by the consumer + * (task read), the producer (drain / EOICS publish) and conversion ({@code + * SingleInputGate#convertRecoveredInputChannels}). This is what closes the FLINK-39519 + * stale-enqueue race: producer add+fire and conversion's listener replace + {@code channels[i]} + * swap are serialised through the same gate lock. + * + *
<p>Per-method contracts (each one's preconditions are restated as {@code assert}s at the entry): + * + * <ul>
  • Producer mutators ({@link #addBuffer}, {@link #addBufferAfterDisk}): caller holds + * the gate lock; the store self-manages its own monitor and fires the data-available listener + * inline. Firing inside the store monitor is safe because the gate lock is held by the + * caller, so {@code queueChannel} re-acquires it as a recursive intrinsic-monitor entry — no + * AB-BA cycle. + *
  • Race-path readers / setter ({@link #tryTake}, {@link #peekNextDataType}, {@link + * #setDataAvailableListener}): {@code @GuardedBy("this")} — caller wraps the call in {@code + * synchronized(store)} so compound operations such as {@code tryTake() + peekNextDataType()} + * observe a consistent snapshot — and additionally holds the gate lock so the {@code gate → + * store} order is explicit. + *
  • Store-only readers / mutators ({@link #isEmpty}, {@link #setCoordinator}, {@link + * #incrementPending}): {@code @GuardedBy("this")} only. {@link #isEmpty} is consulted from + * {@code RemoteInputChannel#onSenderBacklog} on the netty event loop where the gate lock is + * not held; {@link #setCoordinator} / {@link #incrementPending} run before the recovery + * flush, off the race path. + *
  • Lifecycle / coordinator ({@link #checkpoint}, {@link #releaseAll}, {@link + * #notifyCheckpointStopped}): self-manage the store monitor and fire any captured coordinator + * callback after exiting the lock. Independent of the producer/consumer hot path. + *
  • {@link #size()} is a deliberate lock-free best-effort read. + *
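+ * <p>Why the inline listener fire in the producer bullet cannot deadlock (standalone
+ * sketch; the listener body stands in for {@code queueChannel}): intrinsic monitors are
+ * reentrant, so re-taking the already-held gate lock is a recursive entry:
+ *
+ * <pre>{@code
+ * synchronized (gateLock) {           // producer path, order gate -> store
+ *     synchronized (store) {
+ *         listener.onDataAvailable(); // may take gateLock again: recursive, safe
+ *     }
+ * }
+ * }</pre>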
+ */ +@Internal +public class RecoveredBufferStoreImpl implements RecoveredBufferStore { + + private final InputChannelInfo channelInfo; + + private final Object gateLock; + + @GuardedBy("this") + private final ArrayDeque readyBuffers = new ArrayDeque<>(); + + @GuardedBy("this") + private int pendingCount = 0; + + /** + * Holds buffers whose contract is "everything before me has been delivered" (e.g. {@link + * EndOfInputChannelStateEvent}) until {@link #pendingCount} reaches zero, then promotes them + * into {@link #readyBuffers}. Excluded from {@link #size()} / {@link #isEmpty()}: while + * something is deferred, {@code pendingCount > 0} already keeps the store non-empty. + */ + @GuardedBy("this") + private final ArrayDeque deferredBuffers = new ArrayDeque<>(); + + private volatile boolean released = false; + + @GuardedBy("this") + private DataAvailableListener dataAvailableListener; + + @GuardedBy("this") + private RecoveredBufferStoreCoordinator coordinator; + + public RecoveredBufferStoreImpl(InputChannelInfo channelInfo, Object gateLock) { + this.channelInfo = checkNotNull(channelInfo); + this.gateLock = checkNotNull(gateLock); + } + + /** + * Test-only: ties {@code gateLock} to the store itself so {@code synchronized(store)} alone + * satisfies both preconditions. + */ + @VisibleForTesting + public RecoveredBufferStoreImpl(InputChannelInfo channelInfo) { + this.channelInfo = checkNotNull(channelInfo); + this.gateLock = this; + } + + public InputChannelInfo getChannelInfo() { + return channelInfo; + } + + public Object getGateLock() { + return gateLock; + } + + @Nullable + @Override + @GuardedBy("this") + public Buffer tryTake() { + assert Thread.holdsLock(this); + assert Thread.holdsLock(gateLock); + return readyBuffers.poll(); + } + + @Override + @GuardedBy("this") + public Buffer.DataType peekNextDataType() { + assert Thread.holdsLock(this); + assert Thread.holdsLock(gateLock); + Buffer peeked = readyBuffers.peek(); + return peeked != null ? peeked.getDataType() : Buffer.DataType.NONE; + } + + @Override + @GuardedBy("this") + public boolean isEmpty() { + assert Thread.holdsLock(this); + return readyBuffers.isEmpty() && pendingCount == 0; + } + + @Override + public int size() { + return readyBuffers.size() + pendingCount; + } + + @Override + public void checkpoint(ChannelStateWriter writer, long checkpointId) throws IOException { + RecoveredBufferStoreCoordinator c; + EntryPosition startPos; + synchronized (this) { + c = coordinator; + startPos = c != null ? c.getCurrentDrainHead() : EntryPosition.END; + if (!readyBuffers.isEmpty()) { + // Skip non-data entries (notably EndOfInputChannelStateEvent): they would be + // rejected by ChannelStateWriteRequest#checkBufferIsBuffer and kill the writer. 
+ List retained = new ArrayList<>(readyBuffers.size()); + for (Buffer buffer : readyBuffers) { + if (buffer.isBuffer()) { + retained.add(buffer.retainBuffer()); + } + } + if (!retained.isEmpty()) { + writer.addInputData( + checkpointId, + channelInfo, + ChannelStateWriter.SEQUENCE_NUMBER_RESTORED, + CloseableIterator.fromList(retained, Buffer::recycleBuffer)); + } + } + } + + if (c != null) { + c.onChannelCheckpointStarted(checkpointId, channelInfo, startPos); + } + } + + @Override + public void releaseAll() { + RecoveredBufferStoreCoordinator c; + synchronized (this) { + released = true; + for (Buffer buffer : readyBuffers) { + buffer.recycleBuffer(); + } + readyBuffers.clear(); + for (Buffer buffer : deferredBuffers) { + buffer.recycleBuffer(); + } + deferredBuffers.clear(); + pendingCount = 0; + c = coordinator; + } + + if (c != null) { + c.onChannelReleased(channelInfo); + } + } + + @Override + public void notifyCheckpointStopped(long checkpointId) { + RecoveredBufferStoreCoordinator c; + synchronized (this) { + c = coordinator; + } + if (c != null) { + c.onChannelCheckpointStopped(checkpointId, channelInfo); + } + } + + @Override + @GuardedBy("this") + public void setCoordinator(RecoveredBufferStoreCoordinator coordinator) { + assert Thread.holdsLock(this); + this.coordinator = coordinator; + } + + @Override + @GuardedBy("this") + public void setDataAvailableListener(DataAvailableListener listener) { + assert Thread.holdsLock(this); + assert Thread.holdsLock(gateLock); + this.dataAvailableListener = listener; + } + + public void addBuffer(Buffer buffer) { + assert Thread.holdsLock(gateLock); + DataAvailableListener listenerToFire; + synchronized (this) { + if (released) { + buffer.recycleBuffer(); + return; + } + boolean wasEmpty = readyBuffers.isEmpty(); + readyBuffers.add(buffer); + if (pendingCount > 0) { + pendingCount--; + if (pendingCount == 0 && !deferredBuffers.isEmpty()) { + while (!deferredBuffers.isEmpty()) { + readyBuffers.add(deferredBuffers.pollFirst()); + } + } + } + listenerToFire = wasEmpty ? dataAvailableListener : null; + } + if (listenerToFire != null) { + listenerToFire.onDataAvailable(); + } + } + + /** Counterpart of {@link #addBuffer}: increments the pending-on-disk counter for a spill. */ + @GuardedBy("this") + public void incrementPending() { + assert Thread.holdsLock(this); + pendingCount++; + } + + /** + * Adds a buffer that becomes consumer-visible only after all on-disk entries have been drained. + * Used to publish {@code EndOfInputChannelStateEvent} so it always lands after the last + * recovered data buffer — without routing the event through the spill path. + */ + public void addBufferAfterDisk(Buffer buffer) { + assert Thread.holdsLock(gateLock); + DataAvailableListener listenerToFire; + synchronized (this) { + if (released) { + buffer.recycleBuffer(); + return; + } + if (pendingCount == 0) { + boolean wasEmpty = readyBuffers.isEmpty(); + readyBuffers.add(buffer); + listenerToFire = wasEmpty ? 
dataAvailableListener : null; + } else { + deferredBuffers.add(buffer); + listenerToFire = null; + } + } + if (listenerToFire != null) { + listenerToFire.onDataAvailable(); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java index d9b7885815bd1..71dbd5a160ade 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java @@ -19,8 +19,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; @@ -29,25 +27,21 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; -import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; -import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; import org.apache.flink.runtime.io.network.partition.ChannelStateHolder; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; -import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.ArrayDeque; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -58,7 +52,7 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan private static final Logger LOG = LoggerFactory.getLogger(RecoveredInputChannel.class); - private final ArrayDeque receivedBuffers = new ArrayDeque<>(); + private final RecoveredBufferStoreImpl store; private final CompletableFuture stateConsumedFuture = new CompletableFuture<>(); protected final BufferManager bufferManager; @@ -69,8 +63,7 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan */ private final CompletableFuture bufferFilteringCompleteFuture = new CompletableFuture<>(); - @GuardedBy("receivedBuffers") - private boolean isReleased; + private final AtomicBoolean isReleased = new AtomicBoolean(false); protected ChannelStateWriter channelStateWriter; @@ -85,6 +78,9 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan private long lastStoppedCheckpointId = -1; + private volatile boolean drainDone = false; + private volatile boolean storeTransferred = false; + RecoveredInputChannel( SingleInputGate inputGate, int channelIndex, @@ -107,6 +103,16 @@ public abstract class RecoveredInputChannel extends 
InputChannel implements Chan bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0); this.networkBuffersPerChannel = networkBuffersPerChannel; + this.store = new RecoveredBufferStoreImpl(getChannelInfo(), inputGate.getGateLock()); + synchronized (inputGate.getGateLock()) { + synchronized (store) { + store.setDataAvailableListener(this::notifyChannelNonEmpty); + } + } + } + + public RecoveredBufferStoreImpl getStore() { + return store; } @Override @@ -115,6 +121,11 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) { this.channelStateWriter = checkNotNull(channelStateWriter); } + /** Must be called after {@link #setChannelStateWriter}. */ + public ChannelStateWriter getChannelStateWriter() { + return checkNotNull(channelStateWriter, "ChannelStateWriter has not been set yet"); + } + public final InputChannel toInputChannel() throws IOException { Preconditions.checkState( bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); @@ -123,15 +134,7 @@ public final InputChannel toInputChannel() throws IOException { stateConsumedFuture.isDone(), "recovered state is not fully consumed"); } - // Extract remaining buffers before conversion. - // These buffers have been filtered but not yet consumed by the Task. - final ArrayDeque remainingBuffers; - synchronized (receivedBuffers) { - remainingBuffers = new ArrayDeque<>(receivedBuffers); - receivedBuffers.clear(); - } - - final InputChannel inputChannel = toInputChannelInternal(remainingBuffers); + final InputChannel inputChannel = toInputChannelInternal(store); inputChannel.checkpointStopped(lastStoppedCheckpointId); return inputChannel; } @@ -142,13 +145,10 @@ public void checkpointStopped(long checkpointId) { } /** - * Creates the physical InputChannel from this recovered channel. - * - * @param remainingBuffers buffers that have been filtered but not yet consumed by the Task. - * These buffers will be migrated to the new physical channel. - * @return the physical InputChannel (LocalInputChannel or RemoteInputChannel) + * Creates the physical InputChannel; the store reference is transferred for continued + * consumption. */ - protected abstract InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) + protected abstract InputChannel toInputChannelInternal(RecoveredBufferStoreImpl recoveredStore) throws IOException; /** @@ -163,67 +163,33 @@ CompletableFuture getStateConsumedFuture() { return stateConsumedFuture; } - public void onRecoveredStateBuffer(Buffer buffer) { - boolean recycleBuffer = true; - NetworkActionsLogger.traceRecover( - "InputChannelRecoveredStateHandler#recover", - buffer, - inputGate.getOwningTaskName(), - channelInfo); - try { - final boolean wasEmpty; - synchronized (receivedBuffers) { - // Similar to notifyBufferAvailable(), make sure that we never add a buffer - // after releaseAllResources() released all buffers from receivedBuffers. - if (isReleased) { - wasEmpty = false; - } else { - wasEmpty = receivedBuffers.isEmpty(); - receivedBuffers.add(buffer); - recycleBuffer = false; - } - } - - if (wasEmpty) { - notifyChannelNonEmpty(); - } - } finally { - if (recycleBuffer) { - buffer.recycleBuffer(); - } - } + /** + * Publishes the {@link EndOfInputChannelStateEvent} (deferred behind any pending spill) and + * completes {@link #bufferFilteringCompleteFuture}. Caller must hold the gate lock; floating + * buffers are released separately via {@link #releaseRecoveryFloatingBuffers()}. 
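+     *
+     * <p>Illustrative ordering (a sketch only; {@code eoics} and {@code replayed} are assumed
+     * names for the event buffer and a replayed spill entry; the gate lock is held where the
+     * store asserts it):
+     *
+     * <pre>{@code
+     * synchronized (store) { store.incrementPending(); } // one filtered entry went to disk
+     * store.addBufferAfterDisk(eoics);                   // parked: pendingCount > 0
+     * store.addBuffer(replayed);                         // pendingCount -> 0, eoics promoted
+     * // the consumer drains: replayed, then eoics
+     * }</pre>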
+ */ + public void finishReadRecoveredState() throws IOException { + assert Thread.holdsLock(inputGate.getGateLock()); + store.addBufferAfterDisk( + EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); + bufferFilteringCompleteFuture.complete(null); + LOG.debug("{}/{} finished recovering input.", inputGate.getOwningTaskName(), channelInfo); } - public void finishReadRecoveredState() throws IOException { - // Adding the event and completing the future must be atomic under receivedBuffers lock. - // Without this, either ordering has a race: - // - event first: task thread consumes EndOfInputChannelStateEvent, which completes - // stateConsumedFuture. When checkpointing during recovery is disabled, - // stateConsumedFuture triggers requestPartitions -> toInputChannel(), which - // fails because bufferFilteringCompleteFuture is not yet done. - // - future first: toInputChannel() extracts buffers before the event is added, - // losing the EndOfInputChannelStateEvent. - // Both toInputChannel() and getNextRecoveredStateBuffer() synchronize on - // receivedBuffers, so holding the same lock here guarantees - // bufferFilteringCompleteFuture is always done before stateConsumedFuture. - synchronized (receivedBuffers) { - onRecoveredStateBuffer( - EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); - bufferFilteringCompleteFuture.complete(null); - } + /** Buffer pool release; intentionally invoked outside the gate lock. */ + public void releaseRecoveryFloatingBuffers() throws IOException { bufferManager.releaseFloatingBuffers(); - LOG.debug("{}/{} finished recovering input.", inputGate.getOwningTaskName(), channelInfo); } @Nullable private BufferAndAvailability getNextRecoveredStateBuffer() throws IOException { + checkState(!isReleased.get(), "Trying to read from released RecoveredInputChannel"); + // tryTake + peekNextDataType under one lock so the consumer never observes a torn view. final Buffer next; final Buffer.DataType nextDataType; - - synchronized (receivedBuffers) { - checkState(!isReleased, "Trying to read from released RecoveredInputChannel"); - next = receivedBuffers.poll(); - nextDataType = peekDataTypeUnsafe(); + synchronized (store) { + next = store.tryTake(); + nextDataType = next != null ? store.peekNextDataType() : Buffer.DataType.NONE; } if (next == null) { @@ -259,18 +225,9 @@ public Optional getNextBuffer() throws IOException { return Optional.ofNullable(getNextRecoveredStateBuffer()); } - private Buffer.DataType peekDataTypeUnsafe() { - assert Thread.holdsLock(receivedBuffers); - - final Buffer first = receivedBuffers.peek(); - return first != null ? first.getDataType() : Buffer.DataType.NONE; - } - @Override int getBuffersInUseCount() { - synchronized (receivedBuffers) { - return receivedBuffers.size(); - } + return store.size(); } @Override @@ -302,34 +259,50 @@ void sendTaskEvent(TaskEvent event) { @Override boolean isReleased() { - synchronized (receivedBuffers) { - return isReleased; - } + return isReleased.get(); } void releaseAllResources() throws IOException { - ArrayDeque releasedBuffers = new ArrayDeque<>(); - boolean shouldRelease = false; - - synchronized (receivedBuffers) { - if (!isReleased) { - isReleased = true; - shouldRelease = true; - releasedBuffers.addAll(receivedBuffers); - receivedBuffers.clear(); + if (isReleased.compareAndSet(false, true)) { + // Abort path: gate.close races requestLock with convertRecoveredInputChannels, so + // storeTransferred=false means conversion never ran and the store is still ours. 
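+             // Releasing the store on this abort path also recycles any deferred buffers
+             // (such as a parked EndOfInputChannelStateEvent), so nothing leaks.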
+ // After conversion the physical channel owns the store and releases it itself. + if (!storeTransferred) { + store.releaseAll(); } + bufferManager.releaseAllBuffers(new ArrayDeque<>()); + } + } + + /** Signalled when {@code FilteredBufferDispatcher#close} finishes drain. */ + public void markDrainDone() throws IOException { + drainDone = true; + if (storeTransferred) { + releaseAllResources(); } + } - if (shouldRelease) { - bufferManager.releaseAllBuffers(releasedBuffers); + /** Signalled after the gate slot has been replaced by the physical channel. */ + public void markStoreTransferred() throws IOException { + storeTransferred = true; + if (drainDone) { + releaseAllResources(); } } @VisibleForTesting protected int getNumberOfQueuedBuffers() { - synchronized (receivedBuffers) { - return receivedBuffers.size(); + return store.size(); + } + + /** Non-blocking; returns {@code null} if the pool is exhausted. */ + @Nullable + public Buffer requestBuffer() throws IOException { + if (!exclusiveBuffersAssigned) { + bufferManager.requestExclusiveBuffers(networkBuffersPerChannel); + exclusiveBuffersAssigned = true; } + return bufferManager.requestBuffer(); } public Buffer requestBufferBlocking() throws InterruptedException, IOException { @@ -338,26 +311,7 @@ public Buffer requestBufferBlocking() throws InterruptedException, IOException { bufferManager.requestExclusiveBuffers(networkBuffersPerChannel); exclusiveBuffersAssigned = true; } - if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { - // When checkpoint-during-recovery is not enabled, the original blocking allocation - // is used as-is — no heap buffer fallback, no behavior change from the legacy path. - return bufferManager.requestBufferBlocking(); - } - // Use heap buffer fallback to avoid deadlock during filtering recovery: the filtering - // thread first requests buffers to read state (pre-filter), then requests more buffers - // to write filtered output (post-filter). If pre-filter buffers exhaust the pool, - // post-filter allocation blocks, stalling the thread so pre-filter buffers can never - // be consumed and released — the thread deadlocks itself. Heap buffers bypass the pool - // so post-filter writes always proceed. Both call sites (getBuffer and filterAndRewrite) - // go through this method, so the fallback applies uniformly. - // TODO: replace heap fallback with disk spilling to bound memory usage in FLINK-38544. - Buffer buffer = bufferManager.requestBuffer(); - if (buffer != null) { - return buffer; - } - MemorySegment memorySegment = - MemorySegmentFactory.allocateUnpooledSegment(MemoryManager.DEFAULT_PAGE_SIZE); - return new NetworkBuffer(memorySegment, FreeingBufferRecycler.INSTANCE); + return bufferManager.requestBufferBlocking(); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 66a7d50014067..9740c415eb908 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -87,9 +87,10 @@ public class RemoteInputChannel extends InputChannel { private final ConnectionManager connectionManager; /** - * The received buffers. Received buffers are enqueued by the network I/O thread and the queue - * is consumed by the receiving task thread. 
+ * Buffers enqueued by the network I/O thread and consumed by the task thread. Guarded by the + * {@link #recoveredStore} monitor (the unified channel-private lock). */ + @GuardedBy("recoveredStore") private final PrioritizedDeque receivedBuffers = new PrioritizedDeque<>(); /** @@ -115,14 +116,25 @@ public class RemoteInputChannel extends InputChannel { private final BufferManager bufferManager; - @GuardedBy("receivedBuffers") + @GuardedBy("recoveredStore") private int lastBarrierSequenceNumber = NONE; - @GuardedBy("receivedBuffers") + @GuardedBy("recoveredStore") private long lastBarrierId = NONE; private final ChannelStatePersister channelStatePersister; + /** Always non-null: callers with no recovered data pass {@link RecoveredBufferStore#EMPTY}. */ + private final RecoveredBufferStore recoveredStore; + + /** + * Invariant under lock: {@code hasPendingPriorityEvent <=> + * receivedBuffers.getNumPriorityElements() > 0}. A priority element bypasses the FIFO + * recovery-first rule. + */ + @GuardedBy("recoveredStore") + private boolean hasPendingPriorityEvent = false; + private long totalQueueSizeInBytes; public RemoteInputChannel( @@ -139,7 +151,7 @@ public RemoteInputChannel( Counter numBytesIn, Counter numBuffersIn, ChannelStateWriter stateWriter, - ArrayDeque initialRecoveredBuffers) { + RecoveredBufferStore recoveredStore) { super( inputGate, @@ -157,29 +169,12 @@ public RemoteInputChannel( this.connectionId = checkNotNull(connectionId); this.connectionManager = checkNotNull(connectionManager); this.bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0); - this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo()); - - // Migrate recovered buffers from RecoveredInputChannel if provided. - // These buffers have been filtered but not yet consumed by the Task. - if (!initialRecoveredBuffers.isEmpty()) { - final int expectedCount = initialRecoveredBuffers.size(); - // Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel. - int seqNum = Integer.MIN_VALUE; - for (Buffer buffer : initialRecoveredBuffers) { - // subpartitionId is set to 0 for recovered buffers. This is correct because: - // 1) For single-subpartition channels, the only valid subpartition is 0. - // 2) For multi-subpartition channels (consumedSubpartitionIndexSet.size() > 1), - // RecoveryMetadata events embedded in the recovered buffer sequence track - // the actual subpartition context for proper routing. 
- SequenceBuffer sequenceBuffer = new SequenceBuffer(buffer, seqNum++, 0); - receivedBuffers.add(sequenceBuffer); - totalQueueSizeInBytes += buffer.getSize(); - } - checkState( - receivedBuffers.size() == expectedCount, - "Buffer migration failed: expected %s buffers but got %s", - expectedCount, - receivedBuffers.size()); + + this.recoveredStore = checkNotNull(recoveredStore); + this.channelStatePersister = + new ChannelStatePersister(stateWriter, getChannelInfo(), this.recoveredStore); + synchronized (this.recoveredStore) { + this.recoveredStore.setDataAvailableListener(this::notifyChannelNonEmpty); } } @@ -263,8 +258,8 @@ protected boolean increaseBackoff() { @Override protected int peekNextBufferSubpartitionIdInternal() throws IOException { - synchronized (receivedBuffers) { - checkReadability(); + synchronized (recoveredStore) { + checkPartitionRequestQueueInitialized(); final SequenceBuffer next = receivedBuffers.peek(); @@ -278,24 +273,56 @@ protected int peekNextBufferSubpartitionIdInternal() throws IOException { @Override public Optional getNextBuffer() throws IOException { - final SequenceBuffer next; + // Single critical section so "is recovery done", "is there a priority event", "what is + // the next data type" cannot be torn against the underlying queues. Splitting would let + // producers slip data into receivedBuffers (or drain the store) between segments and + // surface a stale moreAvailable that hides queued buffers from the gate. + final Buffer recoveredBuffer; + final SequenceBuffer fromReceivedBuffers; final DataType nextDataType; - synchronized (receivedBuffers) { - checkReadability(); - - next = receivedBuffers.poll(); - - if (next != null) { - totalQueueSizeInBytes -= next.buffer.getSize(); + synchronized (recoveredStore) { + if (!recoveredStore.isEmpty()) { + if (hasPendingPriorityEvent) { + fromReceivedBuffers = pollPendingPriorityEvent(); + if (fromReceivedBuffers == null) { + // Invariant should keep the flag aligned with priority count; defensive + // yield mirrors pre-refactor behavior. + return Optional.empty(); + } + nextDataType = peekNextDataType(); + recoveredBuffer = null; + } else { + recoveredBuffer = recoveredStore.tryTake(); + if (recoveredBuffer == null) { + // readyBuffers empty but pendingCount > 0: drain listener wakes us. + return Optional.empty(); + } + nextDataType = peekNextDataType(); + fromReceivedBuffers = null; + } + } else { + checkPartitionRequestQueueInitialized(); + fromReceivedBuffers = receivedBuffers.poll(); + if (fromReceivedBuffers != null) { + totalQueueSizeInBytes -= fromReceivedBuffers.buffer.getSize(); + if (receivedBuffers.getNumPriorityElements() == 0) { + hasPendingPriorityEvent = false; + } + } + nextDataType = peekNextDataType(); + recoveredBuffer = null; } - nextDataType = - receivedBuffers.peek() != null - ? 
receivedBuffers.peek().buffer.getDataType() - : DataType.NONE; } - if (next == null) { + if (recoveredBuffer != null) { + numBytesIn.inc(recoveredBuffer.getSize()); + numBuffersIn.inc(); + return Optional.of( + new BufferAndAvailability(recoveredBuffer, nextDataType, 0, Integer.MIN_VALUE)); + } + + if (fromReceivedBuffers == null) { if (isReleased.get()) { throw new CancelTaskException( "Queried for a buffer after channel has been released."); @@ -305,15 +332,66 @@ public Optional getNextBuffer() throws IOException { NetworkActionsLogger.traceInput( "RemoteInputChannel#getNextBuffer", - next.buffer, + fromReceivedBuffers.buffer, inputGate.getOwningTaskName(), channelInfo, channelStatePersister, - next.sequenceNumber); - numBytesIn.inc(next.buffer.getSize()); + fromReceivedBuffers.sequenceNumber); + numBytesIn.inc(fromReceivedBuffers.buffer.getSize()); numBuffersIn.inc(); return Optional.of( - new BufferAndAvailability(next.buffer, nextDataType, 0, next.sequenceNumber)); + new BufferAndAvailability( + fromReceivedBuffers.buffer, + nextDataType, + 0, + fromReceivedBuffers.sequenceNumber)); + } + + /** + * Data type of the next buffer the consumer will see across {@link #recoveredStore} and {@link + * #receivedBuffers}, respecting the priority bypass. Caller MUST hold {@code recoveredStore}. + */ + @GuardedBy("recoveredStore") + private DataType peekNextDataType() { + assert Thread.holdsLock(recoveredStore); + if (hasPendingPriorityEvent) { + SequenceBuffer peeked = receivedBuffers.peek(); + return peeked != null ? peeked.buffer.getDataType() : DataType.NONE; + } + if (!recoveredStore.isEmpty()) { + // NONE while only pendingCount > 0; drain listener wakes the consumer when + // readyBuffers becomes non-empty. + return recoveredStore.peekNextDataType(); + } + SequenceBuffer peeked = receivedBuffers.peek(); + return peeked != null ? peeked.buffer.getDataType() : DataType.NONE; + } + + /** + * Polls the priority head of {@link #receivedBuffers}, skipping {@link #recoveredStore} and + * clearing {@link #hasPendingPriorityEvent} when the last priority drains. Caller MUST hold + * {@code recoveredStore}. + */ + @GuardedBy("recoveredStore") + @Nullable + private SequenceBuffer pollPendingPriorityEvent() throws IOException { + assert Thread.holdsLock(recoveredStore); + if (!hasPendingPriorityEvent) { + return null; + } + checkPartitionRequestQueueInitialized(); + + SequenceBuffer next = receivedBuffers.poll(); + checkState( + next != null && next.buffer.getDataType().hasPriority(), + "Expected priority event, but got %s", + next == null ? "null" : next.buffer.getDataType()); + totalQueueSizeInBytes -= next.buffer.getSize(); + + if (receivedBuffers.getNumPriorityElements() == 0) { + hasPendingPriorityEvent = false; + } + return next; } // ------------------------------------------------------------------------ @@ -344,8 +422,12 @@ public boolean isReleased() { void releaseAllResources() throws IOException { if (isReleased.compareAndSet(false, true)) { + // EMPTY.releaseAll() is a no-op, so no null check needed. + recoveredStore.releaseAll(); + final ArrayDeque releasedBuffers; - synchronized (receivedBuffers) { + synchronized (recoveredStore) { + hasPendingPriorityEvent = false; releasedBuffers = receivedBuffers.stream() .map(sb -> sb.buffer) @@ -366,7 +448,12 @@ void releaseAllResources() throws IOException { @Override int getBuffersInUseCount() { - return getNumberOfQueuedBuffers() + // Single snapshot under the shared monitor. 
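+         // Reading each queue separately could tear against a concurrent store drain or
+         // network enqueue; one critical section yields the same view getNextBuffer() acts on.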
+ int channelBacklog; + synchronized (recoveredStore) { + channelBacklog = recoveredStore.size() + receivedBuffers.size(); + } + return channelBacklog + Math.max(0, bufferManager.getNumberOfRequiredBuffers() - initialCredit); } @@ -430,7 +517,10 @@ boolean isWaitingForFloatingBuffers() { @VisibleForTesting public Buffer getNextReceivedBuffer() { - final SequenceBuffer sequenceBuffer = receivedBuffers.poll(); + final SequenceBuffer sequenceBuffer; + synchronized (recoveredStore) { + sequenceBuffer = receivedBuffers.poll(); + } return sequenceBuffer != null ? sequenceBuffer.buffer : null; } @@ -521,14 +611,14 @@ public int getAndResetUnannouncedCredit() { * @return Buffers queued for processing. */ public int getNumberOfQueuedBuffers() { - synchronized (receivedBuffers) { + synchronized (recoveredStore) { return receivedBuffers.size(); } } @Override public int unsynchronizedGetNumberOfQueuedBuffers() { - return Math.max(0, receivedBuffers.size()); + return recoveredStore.size() + Math.max(0, receivedBuffers.size()); } @Override @@ -577,9 +667,19 @@ public Buffer requestBuffer() { * is less than backlog + initialCredit, it will request floating buffers from the buffer * manager, and then notify unannounced credits to the producer. * + *
<p>
No-op while the recovered store is non-empty: credit is gated during recovery so upstream + * cannot send new data. + * * @param backlog The number of unsent buffers in the producer's sub partition. */ public void onSenderBacklog(int backlog) throws IOException { + final boolean stillRecovering; + synchronized (recoveredStore) { + stillRecovering = !recoveredStore.isEmpty(); + } + if (stillRecovering) { + return; + } notifyBufferAvailable(bufferManager.requestFloatingBuffers(backlog + initialCredit)); } @@ -604,7 +704,7 @@ public void onBuffer(Buffer buffer, int sequenceNumber, int backlog, int subpart final boolean wasEmpty; boolean firstPriorityEvent = false; - synchronized (receivedBuffers) { + synchronized (recoveredStore) { NetworkActionsLogger.traceInput( "RemoteInputChannel#onBuffer", buffer, @@ -634,6 +734,9 @@ public void onBuffer(Buffer buffer, int sequenceNumber, int backlog, int subpart firstPriorityEvent = addPriorityBuffer(announce(sequenceBuffer)); } } + if (firstPriorityEvent) { + hasPendingPriorityEvent = true; + } totalQueueSizeInBytes += buffer.getSize(); final OptionalLong barrierId = channelStatePersister.checkForBarrier(sequenceBuffer.buffer); @@ -708,34 +811,31 @@ private void checkAnnouncedOnlyOnce(SequenceBuffer sequenceBuffer) { } /** - * Spills all queued buffers on checkpoint start. If barrier has already been received (and - * reordered), spill only the overtaken buffers. + * Spills all queued buffers on checkpoint start. If the barrier has already been received (and + * reordered), spill only the overtaken buffers. The entire body runs under the store lock so + * {@code checkpointStatus} transitions inside {@code startPersisting} cannot tear against the + * network thread's {@code maybePersist}. */ public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { - synchronized (receivedBuffers) { + synchronized (recoveredStore) { if (barrier.getId() < lastBarrierId) { throw new CheckpointException( String.format( "Sequence number for checkpoint %d is not known (it was likely been overwritten by a newer checkpoint %d)", barrier.getId(), lastBarrierId), - CheckpointFailureReason - .CHECKPOINT_SUBSUMED); // currently, at most one active unaligned - // checkpoint is possible + CheckpointFailureReason.CHECKPOINT_SUBSUMED); } else if (barrier.getId() > lastBarrierId) { - // This channel has received some obsolete barrier, older compared to the - // checkpointId - // which we are processing right now, and we should ignore that obsoleted checkpoint - // barrier sequence number. + // Older barrier already in flight — drop its sequence number; the newer + // checkpoint we just received takes over. 
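+                 // At most one active unaligned checkpoint is possible, so dropping the
+                 // stale sequence number is safe.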
resetLastBarrier(); } - channelStatePersister.startPersisting( barrier.getId(), getInflightBuffersUnsafe(barrier.getId())); } } public void checkpointStopped(long checkpointId) { - synchronized (receivedBuffers) { + synchronized (recoveredStore) { channelStatePersister.stopPersisting(checkpointId); if (lastBarrierId == checkpointId) { resetLastBarrier(); @@ -745,7 +845,7 @@ public void checkpointStopped(long checkpointId) { @VisibleForTesting List getInflightBuffers(long checkpointId) { - synchronized (receivedBuffers) { + synchronized (recoveredStore) { return getInflightBuffersUnsafe(checkpointId); } } @@ -753,7 +853,7 @@ List getInflightBuffers(long checkpointId) { @Override public void convertToPriorityEvent(int sequenceNumber) throws IOException { boolean firstPriorityEvent; - synchronized (receivedBuffers) { + synchronized (recoveredStore) { checkState(channelStatePersister.hasBarrierReceived()); int numPriorityElementsBeforeRemoval = receivedBuffers.getNumPriorityElements(); SequenceBuffer toPrioritize = @@ -782,6 +882,9 @@ public void convertToPriorityEvent(int sequenceNumber) throws IOException { addPriorityBuffer( toPrioritize); // note that only position of the element is changed // converting the event itself would require switching the controller sooner + if (firstPriorityEvent) { + hasPendingPriorityEvent = true; + } } if (firstPriorityEvent) { notifyPriorityEventForce(); // forcibly notify about the priority event @@ -799,7 +902,7 @@ private void notifyPriorityEventForce() { * events. */ private List getInflightBuffersUnsafe(long checkpointId) { - assert Thread.holdsLock(receivedBuffers); + assert Thread.holdsLock(recoveredStore); checkState(checkpointId == lastBarrierId || lastBarrierId == NONE); @@ -879,7 +982,7 @@ private boolean shouldBeSpilled(int sequenceNumber) { public void onEmptyBuffer(int sequenceNumber, int backlog) throws IOException { boolean success = false; - synchronized (receivedBuffers) { + synchronized (recoveredStore) { if (!isReleased.get()) { if (expectedSequenceNumber == sequenceNumber) { expectedSequenceNumber++; @@ -903,20 +1006,6 @@ public void onError(Throwable cause) { setError(cause); } - /** - * When receivedBuffers contains migrated buffers from RecoveredInputChannel, they can be read - * before requestSubpartitions(). In that case only check for errors. Once migrated buffers are - * drained, require full client initialization check. 
- */ - private void checkReadability() throws IOException { - assert Thread.holdsLock(receivedBuffers); - if (receivedBuffers.isEmpty()) { - checkPartitionRequestQueueInitialized(); - } else { - checkError(); - } - } - private void checkPartitionRequestQueueInitialized() throws IOException { checkError(); checkState( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java index 2cfff6f5e7972..99e63c0320d17 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java @@ -20,13 +20,11 @@ import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; -import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import java.io.IOException; -import java.util.ArrayDeque; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -68,7 +66,7 @@ public class RemoteRecoveredInputChannel extends RecoveredInputChannel { } @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) + protected InputChannel toInputChannelInternal(RecoveredBufferStoreImpl recoveredStore) throws IOException { RemoteInputChannel remoteInputChannel = new RemoteInputChannel( @@ -85,7 +83,7 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer numBytesIn, numBuffersIn, channelStateWriter, - remainingBuffers); + recoveredStore); remoteInputChannel.setup(); return remoteInputChannel; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 438efa2f58bd5..b93611e2f9b9f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.io.network.partition.consumer; +import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.core.memory.MemorySegment; @@ -160,6 +161,15 @@ public class SingleInputGate extends IndexedInputGate { /** Channels, which notified this input gate about available data. */ private final PrioritizedDeque inputChannelsWithData = new PrioritizedDeque<>(); + /** + * Returns this gate's intrinsic monitor — see {@link RecoveredBufferStoreImpl} for the {@code + * gate → store} contract that uses it. + */ + @Internal + public Object getGateLock() { + return inputChannelsWithData; + } + /** * Field guaranteeing uniqueness for inputChannelsWithData queue. Both of those fields should be * unified onto one. @@ -407,18 +417,21 @@ public void convertRecoveredInputChannels() { continue; } try { - // Phase 1: Convert channel and release resources outside the lock. 
- // These calls may acquire the receivedBuffers lock internally, so they - // run outside inputChannelsWithData lock to maintain a consistent lock - // order with onRecoveredStateBuffer() which acquires receivedBuffers - // first and then inputChannelsWithData. - InputChannel realInputChannel = - ((RecoveredInputChannel) inputChannel).toInputChannel(); - inputChannel.releaseAllResources(); - int buffersInUseCount = realInputChannel.getBuffersInUseCount(); - - // Phase 2: Atomically update data structures under the lock. + // toInputChannel + channels[i] swap must be atomic against producer add+fire + // — see RecoveredBufferStoreImpl class javadoc. markStoreTransferred is fired + // outside the gate lock so the (drainDone + storeTransferred) rendezvous's + // BufferManager teardown stays off the gate lock's dependency graph. + // + // The recovered store and its BufferManager are released later (the physical + // channel takes ownership of the store; BufferManager segments are returned + // by BufferRequester#releaseExclusiveBuffers from + // FilteredBufferDispatcher#close). synchronized (inputChannelsWithData) { + InputChannel realInputChannel = + ((RecoveredInputChannel) inputChannel).toInputChannel(); + + int buffersInUseCount = realInputChannel.getBuffersInUseCount(); + if (inputChannelsWithData.contains(inputChannel)) { inputChannelsWithData.getAndRemove(ch -> ch == inputChannel); } @@ -434,12 +447,22 @@ public void convertRecoveredInputChannels() { enqueuedInputChannelsWithData.set(realInputChannel.getChannelIndex()); } } + ((RecoveredInputChannel) inputChannel).markStoreTransferred(); } catch (Throwable t) { inputChannel.setError(t); return; } } } + + try (GateNotificationHelper notification = + new GateNotificationHelper(this, inputChannelsWithData)) { + synchronized (inputChannelsWithData) { + if (!inputChannelsWithData.isEmpty()) { + notification.notifyDataAvailable(); + } + } + } } private void internalRequestPartitions() { @@ -455,11 +478,21 @@ private void internalRequestPartitions() { @Override public void finishReadRecoveredState() throws IOException { - for (final InputChannel channel : channels) { - if (channel instanceof RecoveredInputChannel) { - ((RecoveredInputChannel) channel).finishReadRecoveredState(); + // Pass 1 under the gate lock: EOICS publish + future completion. Pass 2 lock-free: + // floating-buffer release stays off the gate lock's dependency graph. 
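+         // Recycling floating buffers may call back into the buffer pool and its listeners,
+         // which is why pass 2 deliberately runs without the gate lock.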
+ List recovered = new ArrayList<>(); + synchronized (inputChannelsWithData) { + for (final InputChannel channel : channels) { + if (channel instanceof RecoveredInputChannel) { + RecoveredInputChannel rc = (RecoveredInputChannel) channel; + rc.finishReadRecoveredState(); + recovered.add(rc); + } } } + for (RecoveredInputChannel rc : recovered) { + rc.releaseRecoveryFloatingBuffers(); + } } // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java index 15182cedadb9f..fd85a4325e9c2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java @@ -35,7 +35,6 @@ import javax.annotation.Nullable; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Optional; import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY; @@ -185,7 +184,7 @@ public RemoteInputChannel toRemoteInputChannel( metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter(), channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter, - new ArrayDeque<>()); + RecoveredBufferStore.EMPTY); } public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID) { @@ -201,7 +200,7 @@ public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter, - new ArrayDeque<>()); + RecoveredBufferStore.EMPTY); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java index c2c3efd4b3450..b4091b5f5a6a8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.state.filesystem.FsCheckpointStreamFactory; import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory.MemoryCheckpointOutputStream; import org.apache.flink.testutils.junit.utils.TempDirUtils; +import org.apache.flink.util.CloseableIterator; import org.apache.flink.util.function.RunnableWithException; import org.junit.jupiter.api.Test; @@ -410,6 +411,101 @@ void testRecordingOffsets() throws Exception { assertThat(offsetCounts).isEmpty(); } + @Test + void testSpillWriteFormatCompatibility() throws Exception { + // Verify that data written via writeInputFromSpill produces the same byte format + // as data written via the buffer-based path: [4-byte length][data bytes]. 
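+         // For the 100-byte payload below, both paths must produce the serializer header
+         // followed by the big-endian length 0x00000064 and then the 100 data bytes
+         // (DataOutputStream#writeInt is big-endian).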
+ byte[] testData = getData(100); + + // Build expected output via buffer-based path + ByteArrayOutputStream bufferRawStream = new ByteArrayOutputStream(); + DataOutputStream bufferDataStream = new DataOutputStream(bufferRawStream); + ChannelStateSerializerImpl serializer = new ChannelStateSerializerImpl(); + serializer.writeHeader(bufferDataStream); + MemorySegment segment = wrap(testData); + NetworkBuffer buffer = + new NetworkBuffer( + segment, + FreeingBufferRecycler.INSTANCE, + Buffer.DataType.DATA_BUFFER, + segment.size()); + serializer.writeData(bufferDataStream, buffer); + bufferDataStream.flush(); + byte[] bufferPathOutput = bufferRawStream.toByteArray(); + + // Build expected output via spill path: [4-byte length][data bytes] + ByteArrayOutputStream spillRawStream = new ByteArrayOutputStream(); + DataOutputStream spillDataStream = new DataOutputStream(spillRawStream); + serializer.writeHeader(spillDataStream); + spillDataStream.writeInt(testData.length); + spillDataStream.write(testData); + spillDataStream.flush(); + byte[] spillPathOutput = spillRawStream.toByteArray(); + + assertThat(spillPathOutput).isEqualTo(bufferPathOutput); + } + + @Test + void testSpillWriteViaWriter() throws Exception { + // Verify end-to-end that writeInputFromSpill produces the correct state handle + // with proper offsets and state size. + byte[] testData = getData(100); + InputChannelInfo channelInfo = new InputChannelInfo(0, 0); + + // Write via buffer-based path + ChannelStateWriteResult bufferResult = new ChannelStateWriteResult(); + ChannelStateCheckpointWriter bufferWriter = + createWriter(bufferResult, new MemoryCheckpointOutputStream(4096)); + write(bufferWriter, channelInfo, testData); + bufferWriter.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX); + bufferWriter.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX); + + // Write via spill path using a single chunk + ChannelStateWriteResult spillResult = new ChannelStateWriteResult(); + ChannelStateCheckpointWriter spillWriter = + createWriter(spillResult, new MemoryCheckpointOutputStream(4096)); + FilteredSpillFile.Chunk chunk = + new FilteredSpillFile.Chunk(channelInfo, testData, testData.length); + spillWriter.writeInputFromSpill( + JOB_VERTEX_ID, SUBTASK_INDEX, CloseableIterator.ofElements(ignored -> {}, chunk)); + spillWriter.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX); + spillWriter.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX); + + // Compare state handles: offsets and state sizes must match + for (InputChannelStateHandle bufferHandle : bufferResult.inputChannelStateHandles.get()) { + for (InputChannelStateHandle spillHandle : spillResult.inputChannelStateHandles.get()) { + assertThat(spillHandle.getOffsets()).isEqualTo(bufferHandle.getOffsets()); + assertThat(spillHandle.getStateSize()).isEqualTo(bufferHandle.getStateSize()); + } + } + } + + @Test + void testSpillWriteRecordsOffsets() throws Exception { + // Verify that writeInputFromSpill correctly records offsets and state size, + // consistent with the buffer-based write. 
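+         // Expected layout per chunk: [serializer header][4-byte length][payload]; the offset
+         // and state-size assertions below check exactly that.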
+ int numBytesPerWrite = 50; + InputChannelInfo channelInfo = new InputChannelInfo(0, 0); + + ChannelStateWriteResult result = new ChannelStateWriteResult(); + ChannelStateCheckpointWriter writer = + createWriter(result, new MemoryCheckpointOutputStream(4096)); + + byte[] data = getData(numBytesPerWrite); + FilteredSpillFile.Chunk chunk = new FilteredSpillFile.Chunk(channelInfo, data, data.length); + writer.writeInputFromSpill( + JOB_VERTEX_ID, SUBTASK_INDEX, CloseableIterator.ofElements(ignored -> {}, chunk)); + writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX); + writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX); + + for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) { + int headerSize = Integer.BYTES; + assertThat(handle.getOffsets()).isEqualTo(singletonList((long) headerSize)); + int lengthSize = Integer.BYTES; + assertThat(handle.getStateSize()).isEqualTo(headerSize + lengthSize + numBytesPerWrite); + } + } + private byte[] getData(int len) { byte[] bytes = new byte[len]; random.nextBytes(bytes); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateChunkReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateChunkReaderTest.java index 5d58011a3d97f..09edee0004536 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateChunkReaderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateChunkReaderTest.java @@ -136,6 +136,9 @@ public void recover( bufferWithContext.close(); } + @Override + public void finishRecovery() {} + @Override public void close() throws Exception {} } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java index fb931946bbb24..986d1602bd668 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.CheckpointStorage; import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage; +import org.apache.flink.util.CloseableIterator; import org.apache.flink.util.function.BiConsumerWithException; import org.junit.jupiter.api.Test; @@ -280,6 +281,37 @@ void testNoAddDataAfterClose() throws IOException { .hasCauseInstanceOf(IllegalStateException.class); } + @Test + void testAddInputDataFromSpill() throws Exception { + executeCallbackWithSyncWorker( + (writer, worker) -> { + callStart(writer); + FilteredSpillFile.Chunk chunk = + new FilteredSpillFile.Chunk( + new InputChannelInfo(1, 1), new byte[] {1, 2, 3, 4, 5}, 5); + writer.addInputDataFromSpill( + CHECKPOINT_ID, CloseableIterator.ofElements(ignored -> {}, chunk)); + worker.processAllRequests(); + ChannelStateWriteResult result = writer.getAndRemoveWriteResult(CHECKPOINT_ID); + assertThat(result.isDone()).isFalse(); + }); + } + + @Test + void testAddInputDataFromSpillAfterClose() throws IOException { + ChannelStateWriterImpl writer = openWriter(); + callStart(writer); + writer.close(); + FilteredSpillFile.Chunk chunk = + new FilteredSpillFile.Chunk(new InputChannelInfo(1, 1), new byte[] {1, 2, 3}, 3); + assertThatThrownBy( + () -> + writer.addInputDataFromSpill( + 
CHECKPOINT_ID, + CloseableIterator.ofElements(ignored -> {}, chunk))) + .hasCauseInstanceOf(IllegalStateException.class); + } + private NetworkBuffer getBuffer() { return new NetworkBuffer( MemorySegmentFactory.allocateUnpooledSegment(123, null), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/EntryPositionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/EntryPositionTest.java new file mode 100644 index 0000000000000..0f740b913cf24 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/EntryPositionTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link EntryPosition}. */ +class EntryPositionTest { + + @Test + void testCompareToOrdersByFileIndexFirst() { + // Two positions in the same file: offset breaks the tie. + EntryPosition a = new EntryPosition(0, 0L); + EntryPosition b = new EntryPosition(0, 100L); + assertThat(a.compareTo(b)).isNegative(); + assertThat(b.compareTo(a)).isPositive(); + + // Different files: file index dominates regardless of offset. 
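+         // e.g. (file 0, offset Long.MAX_VALUE - 1) still sorts before (file 1, offset 0):
+         // the order is lexicographic on (fileIndex, offset).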
+ EntryPosition c = new EntryPosition(1, 0L); + assertThat(b.compareTo(c)).isNegative(); + EntryPosition d = new EntryPosition(0, Long.MAX_VALUE - 1); + EntryPosition e = new EntryPosition(1, 0L); + assertThat(d.compareTo(e)).isNegative(); + } + + @Test + void testEndIsGreaterThanEveryRealPosition() { + EntryPosition firstEntry = new EntryPosition(0, 0L); + EntryPosition lateEntry = new EntryPosition(Integer.MAX_VALUE - 1, Long.MAX_VALUE - 1); + assertThat(firstEntry.compareTo(EntryPosition.END)).isNegative(); + assertThat(lateEntry.compareTo(EntryPosition.END)).isNegative(); + assertThat(EntryPosition.END.compareTo(EntryPosition.END)).isZero(); + } + + @Test + void testEqualsAndHashCode() { + EntryPosition a = new EntryPosition(2, 100L); + EntryPosition b = new EntryPosition(2, 100L); + EntryPosition c = new EntryPosition(2, 101L); + EntryPosition d = new EntryPosition(3, 100L); + + assertThat(a).isEqualTo(b).hasSameHashCodeAs(b); + assertThat(a).isNotEqualTo(c); + assertThat(a).isNotEqualTo(d); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcherTest.java new file mode 100644 index 0000000000000..206f7a2d62949 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FilteredBufferDispatcherTest.java @@ -0,0 +1,1794 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStoreImpl; +import org.apache.flink.util.CloseableIterator; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; + +/** Tests for {@link FilteredBufferDispatcherImpl}. */ +class FilteredBufferDispatcherTest { + + private static final int SEGMENT_SIZE = 64; + + @TempDir Path tempDir; + + private InputChannelInfo ch0; + private InputChannelInfo ch1; + private RecoveredBufferStoreImpl store0; + private RecoveredBufferStoreImpl store1; + private Map stores; + private String[] spillDirs; + + @BeforeEach + void setUp() { + ch0 = new InputChannelInfo(0, 0); + ch1 = new InputChannelInfo(0, 1); + store0 = new RecoveredBufferStoreImpl(ch0); + store1 = new RecoveredBufferStoreImpl(ch1); + stores = new HashMap<>(); + stores.put(ch0, store0); + stores.put(ch1, store1); + spillDirs = new String[] {tempDir.toString()}; + } + + @AfterEach + void tearDown() {} + + /** Records setCoordinator calls without Mockito. 
*/ + private static class TrackingBufferStore extends RecoveredBufferStoreImpl { + private RecoveredBufferStoreCoordinator registeredCoordinator; + private int setCoordinatorCount = 0; + + TrackingBufferStore(InputChannelInfo channelInfo) { + super(channelInfo); + } + + @Override + public synchronized void setCoordinator(RecoveredBufferStoreCoordinator coordinator) { + super.setCoordinator(coordinator); + this.registeredCoordinator = coordinator; + this.setCoordinatorCount++; + } + } + + private Buffer createBuffer() { + return new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(SEGMENT_SIZE), + FreeingBufferRecycler.INSTANCE); + } + + private Queue createBufferPool(int count) { + Queue pool = new LinkedList<>(); + for (int i = 0; i < count; i++) { + pool.add(createBuffer()); + } + return pool; + } + + private byte[] createTestData(int length, byte fillValue) { + byte[] data = new byte[length]; + Arrays.fill(data, fillValue); + return data; + } + + private List drainStore(RecoveredBufferStoreImpl store) { + List result = new ArrayList<>(); + while (true) { + Buffer buf; + synchronized (store) { + buf = store.tryTake(); + } + if (buf == null) { + break; + } + byte[] data = new byte[buf.getSize()]; + buf.getMemorySegment().get(0, data, 0, buf.getSize()); + buf.recycleBuffer(); + result.add(data); + } + return result; + } + + private byte[] concat(List chunks) { + int totalLen = chunks.stream().mapToInt(a -> a.length).sum(); + byte[] result = new byte[totalLen]; + int pos = 0; + for (byte[] chunk : chunks) { + System.arraycopy(chunk, 0, result, pos, chunk.length); + pos += chunk.length; + } + return result; + } + + /** Buffer always available, no disk. All data flows to stores via buffers. */ + @Test + void testP1MemoryPath() throws Exception { + Queue pool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + byte[] data = createTestData(SEGMENT_SIZE, (byte) 0xAA); + writer.write(data, data.length, ch0); + writer.flush(); + writer.close(); + + List buffers = drainStore(store0); + byte[] actual = concat(buffers); + assertThat(actual).isEqualTo(data); + assertEmpty(store0); + } + + /** Buffer supplier always returns null. Data goes to disk, replayed on close. */ + @Test + void testP2SpillPath() throws Exception { + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] data = createTestData(SEGMENT_SIZE, (byte) 0xBB); + writer.write(data, data.length, ch0); + writer.flush(); + + assertThat(tryTake(store0)).isNull(); + + writer.drainPendingSpill(); + writer.close(); + + List buffers = drainStore(store0); + byte[] actual = concat(buffers); + assertThat(actual).isEqualTo(data); + assertEmpty(store0); + } + + /** First write spills, then buffer becomes available. P3 replays from disk. */ + @Test + void testP3ReplayPath() throws Exception { + Queue pool = new LinkedList<>(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + byte[] data1 = createTestData(SEGMENT_SIZE, (byte) 0x11); + writer.write(data1, data1.length, ch0); + + // Buffers added — next write triggers P3 eager drain. 
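+         // (P3 = replay-from-disk: once buffers are back, the next write first drains spilled
+         // entries into the stores; see testP3EagerDrain below.)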
+ pool.addAll(createBufferPool(5)); + + byte[] data2 = createTestData(SEGMENT_SIZE, (byte) 0x22); + writer.write(data2, data2.length, ch1); + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + List buf0 = drainStore(store0); + assertThat(concat(buf0)).isEqualTo(data1); + + List buf1 = drainStore(store1); + assertThat(concat(buf1)).isEqualTo(data2); + } + + /** + * Multiple channels' data goes to disk. Replay order matches FIFO write order across channels. + */ + @Test + void testP3FIFOOrdering() throws Exception { + Queue drainPool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] d0 = createTestData(SEGMENT_SIZE, (byte) 0x10); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0x20); + writer.write(d0, d0.length, ch0); + writer.write(d1, d1.length, ch1); + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + assertThat(concat(drainStore(store0))).isEqualTo(d0); + assertThat(concat(drainStore(store1))).isEqualTo(d1); + } + + /** + * Multiple entries on disk, multiple buffers available. P3 loops until no buffer or disk empty. + */ + @Test + void testP3EagerDrain() throws Exception { + Queue pool = new LinkedList<>(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + for (int i = 0; i < 3; i++) { + byte[] d = createTestData(SEGMENT_SIZE, (byte) (0x30 + i)); + writer.write(d, d.length, ch0); + } + + pool.addAll(createBufferPool(10)); + + byte[] d3 = createTestData(SEGMENT_SIZE, (byte) 0x40); + writer.write(d3, d3.length, ch1); + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + List results = drainStore(store0); + assertThat(results).hasSize(3); + for (int i = 0; i < 3; i++) { + assertThat(results.get(i)).isEqualTo(createTestData(SEGMENT_SIZE, (byte) (0x30 + i))); + } + + assertThat(concat(drainStore(store1))).isEqualTo(d3); + } + + /** + * Start with buffer, buffer fills, no new buffer available. Remaining data goes to file. Cannot + * upgrade back to buffer within one writeToBackend call. + */ + @Test + void testBackendDowngradeOnly() throws Exception { + Queue pool = createBufferPool(1); + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool, drainPool)); + + // First SEGMENT_SIZE goes to buffer, rest to disk. + byte[] data = createTestData(SEGMENT_SIZE * 2, (byte) 0x44); + writer.write(data, data.length, ch0); + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + List results = drainStore(store0); + byte[] actual = concat(results); + assertThat(actual).isEqualTo(data); + } + + /** Data starts in buffer, spans to file when buffer full. 
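+     *
+     * <p>Expected split for this test's inputs, sketched (layout inferred from the inputs, not
+     * from the implementation):
+     * <pre>{@code
+     * write(part1, 32, ch0) -> buffer[0..32)                       // fits the single pooled buffer
+     * write(part2, 64, ch0) -> buffer[32..64) + spill file[0..32)  // spans into the file
+     * }</pre>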
*/ + @Test + void testCrossBackendRecordSpanning() throws Exception { + Queue pool = createBufferPool(1); + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool, drainPool)); + + byte[] part1 = createTestData(SEGMENT_SIZE / 2, (byte) 0x55); + writer.write(part1, part1.length, ch0); + + byte[] part2 = createTestData(SEGMENT_SIZE, (byte) 0x66); + writer.write(part2, part2.length, ch0); + + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + List results = drainStore(store0); + byte[] actual = concat(results); + byte[] expected = new byte[part1.length + part2.length]; + System.arraycopy(part1, 0, expected, 0, part1.length); + System.arraycopy(part2, 0, expected, part1.length, part2.length); + assertThat(actual).isEqualTo(expected); + } + + /** Write to channel A, then B. Verify flush between transitions. */ + @Test + void testChannelChangeDetection() throws Exception { + Queue pool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + byte[] d0 = createTestData(SEGMENT_SIZE / 2, (byte) 0x77); + writer.write(d0, d0.length, ch0); + + // Channel switch should flush ch0's partial buffer. + byte[] d1 = createTestData(SEGMENT_SIZE / 2, (byte) 0x88); + writer.write(d1, d1.length, ch1); + + writer.flush(); + writer.close(); + + List results0 = drainStore(store0); + assertThat(concat(results0)).isEqualTo(d0); + + List results1 = drainStore(store1); + assertThat(concat(results1)).isEqualTo(d1); + } + + /** Multiple channels share one spill file. */ + @Test + void testSingleFilePerTask() throws Exception { + Queue drainPool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] d0 = createTestData(SEGMENT_SIZE, (byte) 0x99); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0xAA); + writer.write(d0, d0.length, ch0); + writer.write(d1, d1.length, ch1); + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + assertThat(concat(drainStore(store0))).isEqualTo(d0); + assertThat(concat(drainStore(store1))).isEqualTo(d1); + } + + /** Spill, partial replay, verify tracking state. */ + @Test + void testCursorBasedTracking() throws Exception { + Queue pool = new LinkedList<>(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0x01); + byte[] d2 = createTestData(SEGMENT_SIZE, (byte) 0x02); + writer.write(d1, d1.length, ch0); + writer.write(d2, d2.length, ch0); + + // Only 1 buffer — partial replay: only 1 of 2 entries drains. + pool.add(createBuffer()); + + byte[] d3 = createTestData(SEGMENT_SIZE, (byte) 0x03); + writer.write(d3, d3.length, ch1); + + pool.addAll(createBufferPool(5)); + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + List results0 = drainStore(store0); + assertThat(results0).hasSize(2); + assertThat(results0.get(0)).isEqualTo(d1); + assertThat(results0.get(1)).isEqualTo(d2); + } + + /** drainPendingSpill() drains all disk data; close() is then a no-op resource release. 
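+     *
+     * <p>The call order this test relies on, sketched:
+     * <pre>{@code
+     * writer.flush();             // seal pending data into spill entries
+     * writer.drainPendingSpill(); // deliver every spilled entry to the stores
+     * writer.close();             // nothing left to deliver; only releases resources
+     * }</pre>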
*/ + @Test + void testDrainPendingSpillUntilEmpty() throws Exception { + Queue drainPool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + for (int i = 0; i < 5; i++) { + byte[] data = createTestData(SEGMENT_SIZE, (byte) i); + writer.write(data, data.length, ch0); + } + writer.flush(); + + assertThat(tryTake(store0)).isNull(); + + writer.drainPendingSpill(); + + List results = drainStore(store0); + assertThat(results).hasSize(5); + for (int i = 0; i < 5; i++) { + assertThat(results.get(i)).isEqualTo(createTestData(SEGMENT_SIZE, (byte) i)); + } + + writer.close(); + assertThat(tryTake(store0)).isNull(); + } + + /** close() twice doesn't throw; drainPendingSpill() after close() is a no-op. */ + @Test + void testCloseIdempotent() throws Exception { + Queue pool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + writer.close(); + } + + /** After close(), spill files deleted. */ + @Test + void testCloseCleanup() throws Exception { + Queue drainPool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] data = createTestData(SEGMENT_SIZE * 3, (byte) 0xCC); + writer.write(data, data.length, ch0); + writer.flush(); + + try (Stream files = + Files.list(tempDir).filter(p -> p.getFileName().toString().startsWith("spill-"))) { + assertThat(files.count()).isGreaterThan(0); + } + + writer.drainPendingSpill(); + writer.close(); + + try (Stream files = + Files.list(tempDir).filter(p -> p.getFileName().toString().startsWith("spill-"))) { + assertThat(files.count()).isEqualTo(0); + } + } + + /** write() after close() throws IllegalStateException. */ + @Test + void testWriteAfterClose() throws Exception { + Queue pool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + byte[] data = createTestData(10, (byte) 0xDD); + assertThatThrownBy(() -> writer.write(data, data.length, ch0)) + .isInstanceOf(IllegalStateException.class); + } + + /** write() after flush() throws IllegalStateException. */ + @Test + void testWriteAfterFlush() throws Exception { + Queue pool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + writer.flush(); + + byte[] data = createTestData(10, (byte) 0xEE); + assertThatThrownBy(() -> writer.write(data, data.length, ch0)) + .isInstanceOf(IllegalStateException.class); + } + + /** Empty spill dirs throws IOException. */ + @Test + void testSpillDirectorySource() { + assertThatThrownBy( + () -> + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + new String[0], + SEGMENT_SIZE, + TestBufferPool.empty())) + .isInstanceOf(IOException.class); + } + + /** Enough data for many spill entries. All replayed correctly. 
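+     *
+     * <p>Stands in for true 64 MB file rotation (exercised separately in {@link
+     * FilteredSpillFileTest#testFileRotation}) by writing many small entries instead.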
*/ + @Test + void testLargeDataMultiRotation() throws Exception { + // FilteredSpillFile rotates at 64MB; truly testing rotation needs >192MB which is too + // much for a unit test. Verify the mechanism with many entries instead. + int entryCount = 100; + int segmentSize = 256; + Queue drainPool = new LinkedList<>(); + for (int i = 0; i < entryCount + 10; i++) { + drainPool.add( + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(segmentSize), + FreeingBufferRecycler.INSTANCE)); + } + + String[] dirs = new String[] {tempDir.toString()}; + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + dirs, + segmentSize, + TestBufferPool.drainOnly(drainPool)); + + byte[][] expectedData = new byte[entryCount][]; + for (int i = 0; i < entryCount; i++) { + expectedData[i] = new byte[segmentSize]; + Arrays.fill(expectedData[i], (byte) (i & 0xFF)); + writer.write(expectedData[i], expectedData[i].length, ch0); + } + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + List results = drainStore(store0); + assertThat(results).hasSize(entryCount); + for (int i = 0; i < entryCount; i++) { + assertThat(results.get(i)).isEqualTo(expectedData[i]); + } + } + + /** FilteredBufferDispatcher.write(data, length, channelInfo) is the unified write interface. */ + @Test + void testUnifiedWriteInterface() throws Exception { + Queue pool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + byte[] d0 = createTestData(32, (byte) 0xF0); + byte[] d1 = createTestData(32, (byte) 0xF1); + + writer.write(d0, d0.length, ch0); + writer.write(d1, d1.length, ch1); + writer.flush(); + writer.close(); + + assertThat(concat(drainStore(store0))).isEqualTo(d0); + assertThat(concat(drainStore(store1))).isEqualTo(d1); + } + + /** SpillEntry aligns with memorySegmentSize, 1:1 with buffer. */ + @Test + void testBufferAlignedEntryReplay() throws Exception { + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + // 3 * SEGMENT_SIZE bytes → exactly 3 spill entries, 1:1 with buffers. + for (int i = 0; i < 3; i++) { + byte[] data = createTestData(SEGMENT_SIZE, (byte) (0xA0 + i)); + writer.write(data, data.length, ch0); + } + writer.flush(); + writer.drainPendingSpill(); + writer.close(); + + List results = drainStore(store0); + assertThat(results).hasSize(3); + for (int i = 0; i < 3; i++) { + assertThat(results.get(i)).hasSize(SEGMENT_SIZE); + assertThat(results.get(i)).isEqualTo(createTestData(SEGMENT_SIZE, (byte) (0xA0 + i))); + } + } + + /** + * After construction, each store has its coordinator registered to the + * FilteredBufferDispatcherImpl instance. 
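+     *
+     * <p>A minimal sketch of the wiring the constructor is assumed to perform (loop shape
+     * illustrative; only the observable setCoordinator calls are asserted below):
+     * <pre>{@code
+     * for (RecoveredBufferStoreImpl store : stores.values()) {
+     *     store.setCoordinator(this); // exactly once per store
+     * }
+     * }</pre>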
+ */ + @Test + void testCoordinatorRegisteredOnConstruction() throws Exception { + TrackingBufferStore trackStore0 = new TrackingBufferStore(ch0); + TrackingBufferStore trackStore1 = new TrackingBufferStore(ch1); + Map trackStores = new HashMap<>(); + trackStores.put(ch0, trackStore0); + trackStores.put(ch1, trackStore1); + + Queue pool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + trackStores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + new TestBufferPool(pool)); + + assertThat(trackStore0.setCoordinatorCount).isEqualTo(1); + assertThat(trackStore0.registeredCoordinator).isSameAs(writer); + assertThat(trackStore1.setCoordinatorCount).isEqualTo(1); + assertThat(trackStore1.registeredCoordinator).isSameAs(writer); + } + + /** + * First onChannelCheckpointStarted call for a checkpointId scans spillEntryQueue and builds the + * correct wait-set; subsequent calls for the same checkpointId remove channels. + */ + @Test + void testWaitSetBuiltOnFirstCallback() throws Exception { + // Spill one entry per channel so both appear in the wait-set. + Queue drainPool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] d0 = createTestData(SEGMENT_SIZE, (byte) 0x10); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0x20); + writer.write(d0, d0.length, ch0); + writer.write(d1, d1.length, ch1); + writer.flush(); + + // After flush: 2 spill entries in queue (ch0, ch1). + // First callback for checkpoint 1: wait-set = {ch0, ch1}; then ch0 is removed. + writer.onChannelCheckpointStarted(1L, ch0, writer.getCurrentDrainHead()); + // Second callback for same checkpoint: ch1 is removed → wait-set is now empty. + writer.onChannelCheckpointStarted(1L, ch1, writer.getCurrentDrainHead()); + // No exception; wait-set reached empty — state machine operated correctly. + + writer.drainPendingSpill(); + writer.close(); + } + + /** New checkpointId causes wait-set to be rebuilt from current spillEntryQueue. */ + @Test + void testWaitSetRebuiltOnNewCheckpointId() throws Exception { + Queue drainPool = createBufferPool(10); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + // Spill entries for both channels. + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x30), SEGMENT_SIZE, ch0); + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x40), SEGMENT_SIZE, ch1); + writer.flush(); + + // Checkpoint 1: consume both callbacks (wait-set empties). + writer.onChannelCheckpointStarted(1L, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted(1L, ch1, writer.getCurrentDrainHead()); + + // Checkpoint 2: new id → wait-set should be rebuilt from the *remaining* spillEntryQueue. + // At this point the queue still holds the 2 entries (they are drained only on close). + // Both channels should again appear in the wait-set. + writer.onChannelCheckpointStarted(2L, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted(2L, ch1, writer.getCurrentDrainHead()); + // No exception; both rebuilt and removed successfully. + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * Channel not present in spillEntryQueue is not in wait-set; removing it is a no-op. 
Uses a + * fresh store map so only ch0 has a spill entry; ch1 callback is a no-op. + */ + @Test + void testCallbackForChannelWithNoPendingEntryIsNoOp() throws Exception { + // Use a fresh store map to avoid interference from previous tests + RecoveredBufferStoreImpl freshStore0 = new RecoveredBufferStoreImpl(ch0); + RecoveredBufferStoreImpl freshStore1 = new RecoveredBufferStoreImpl(ch1); + Map freshStores = new HashMap<>(); + freshStores.put(ch0, freshStore0); + freshStores.put(ch1, freshStore1); + + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + freshStores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + // Only ch0 spills; ch1 has no entries in the queue. + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x50), SEGMENT_SIZE, ch0); + writer.flush(); + + // ch1 callback: not in wait-set → no-op remove, wait-set stays non-empty. + writer.onChannelCheckpointStarted(42L, ch1, writer.getCurrentDrainHead()); + // ch0 callback: removed from wait-set → empty. + writer.onChannelCheckpointStarted(42L, ch0, writer.getCurrentDrainHead()); + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * Duplicate callback for the same channel in the same checkpoint is idempotent (Set.remove on + * an already-absent element is a no-op). + */ + @Test + void testDuplicateCallbackForSameChannelIsIdempotent() throws Exception { + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x70), SEGMENT_SIZE, ch0); + writer.flush(); + + // ch0 removed on first call; second call is a no-op (already absent from set). + writer.onChannelCheckpointStarted(10L, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted( + 10L, ch0, writer.getCurrentDrainHead()); // idempotent — no exception + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * A stale callback (checkpointId < currentCheckpointId) must be ignored: it must not modify + * the current checkpoint's wait-set and must not trigger phase 2. We verify this by observing + * that after the stale callback arrives, the current checkpoint still converges normally when + * its own callbacks arrive. + */ + @Test + void testStaleCheckpointCallbackIsIgnored() throws Exception { + Queue drainPool = createBufferPool(5); + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + // Two entries so both channels appear in the wait-set. + writer.write(createTestData(SEGMENT_SIZE, (byte) 0xA1), SEGMENT_SIZE, ch0); + writer.write(createTestData(SEGMENT_SIZE, (byte) 0xA2), SEGMENT_SIZE, ch1); + writer.flush(); + + // Move to a newer checkpoint (id=20) and deliver one of its two channel callbacks. + writer.onChannelCheckpointStarted(20L, ch0, writer.getCurrentDrainHead()); + // wait-set still contains ch1; phase 2 not yet triggered. + assertThat(recordingWriter.inputDataCalls).isEmpty(); + + // A stale callback for an older checkpoint (id=10) arrives. It must be ignored — not + // alter the wait-set for checkpoint 20, and must not trigger phase 2. 
+ writer.onChannelCheckpointStarted(10L, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted(10L, ch1, writer.getCurrentDrainHead()); + assertThat(recordingWriter.inputDataCalls).isEmpty(); + + // Now deliver the remaining callback for checkpoint 20 — wait-set empties and phase 2 + // snapshots the entries into the ChannelStateWriter. + writer.onChannelCheckpointStarted(20L, ch1, writer.getCurrentDrainHead()); + assertThat(recordingWriter.inputDataCalls).hasSize(2); + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * wait-set reaching empty triggers phase2 {@code drainSpillEntriesToCheckpoint}: the frozen + * readers are snapshotted and streamed to the ChannelStateWriter, but the original readers and + * store state are left intact. drainPendingSpill() then still delivers every entry to the store + * via network buffers. + */ + @Test + void testWaitSetEmptyTriggersPhase2SnapshotThenDrainDeliversBuffers() throws Exception { + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] payload = createTestData(SEGMENT_SIZE, (byte) 0x80); + writer.write(payload, SEGMENT_SIZE, ch0); + writer.flush(); + + // phase2: snapshots frozen readers into the ChannelStateWriter (NO_OP here); the original + // reader and store state are untouched. + writer.onChannelCheckpointStarted(99L, ch0, writer.getCurrentDrainHead()); + + // drainPendingSpill() consumes the original reader and delivers buffers to the store. + writer.drainPendingSpill(); + writer.close(); + + Buffer delivered = tryTake(store0); + assertThat(delivered).isNotNull(); + byte[] actual = new byte[delivered.getSize()]; + delivered.getMemorySegment().get(0, actual, 0, delivered.getSize()); + delivered.recycleBuffer(); + assertThat(actual).isEqualTo(payload); + } + + /** + * Captures addInputDataFromSpill calls for assertions; drains the chunk iterator synchronously + * so tests can assert on actual bytes and channel info. + */ + private static class RecordingChannelStateWriter + extends ChannelStateWriter.NoOpChannelStateWriter { + + static class Call { + final long checkpointId; + final InputChannelInfo info; + final int dataLength; + final byte[] capturedBytes; + + Call(long checkpointId, InputChannelInfo info, int dataLength, byte[] capturedBytes) { + this.checkpointId = checkpointId; + this.info = info; + this.dataLength = dataLength; + this.capturedBytes = capturedBytes; + } + } + + final List inputDataCalls = new ArrayList<>(); + + @Override + public void addInputDataFromSpill( + long checkpointId, CloseableIterator chunks) { + try { + while (chunks.hasNext()) { + FilteredSpillFile.Chunk chunk = chunks.next(); + byte[] bytes = new byte[chunk.getLength()]; + System.arraycopy(chunk.getData(), 0, bytes, 0, chunk.getLength()); + inputDataCalls.add( + new Call( + checkpointId, + chunk.getChannelInfo(), + chunk.getLength(), + bytes)); + } + chunks.close(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + /** + * phase2 writes all spill entries to ChannelStateWriter via streaming addInputData. Verifies + * checkpointId, channelInfo, seqNum=SEQUENCE_NUMBER_RESTORED, and byte content. 
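+     *
+     * <p>Expected recording for this test's inputs, sketched:
+     * <pre>{@code
+     * calls[0] = {checkpointId=42, info=ch0, length=SEGMENT_SIZE, bytes=0xA1 repeated}
+     * calls[1] = {checkpointId=42, info=ch1, length=SEGMENT_SIZE, bytes=0xA2 repeated}
+     * }</pre>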
+ */ + @Test + void testPhase2WritesDiskDataThroughStreamingApi() throws Exception { + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + + // drainOnly: no buffers for the write path (forces everything to spill), but the blocking + // drain path gets buffers so close()'s drain loop can still deliver the snapshotted entries + // to the stores. phase 2 is a backup — close() drain is the task-facing delivery path. + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] d0 = createTestData(SEGMENT_SIZE, (byte) 0xA1); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0xA2); + writer.write(d0, d0.length, ch0); + writer.write(d1, d1.length, ch1); + writer.flush(); + + // Trigger phase2: all channels report in, wait-set empties on second callback + long checkpointId = 42L; + writer.onChannelCheckpointStarted(checkpointId, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted(checkpointId, ch1, writer.getCurrentDrainHead()); + + // Two entries must have been streamed to ChannelStateWriter + assertThat(recordingWriter.inputDataCalls).hasSize(2); + + RecordingChannelStateWriter.Call call0 = recordingWriter.inputDataCalls.get(0); + assertThat(call0.checkpointId).isEqualTo(checkpointId); + assertThat(call0.info).isEqualTo(ch0); + assertThat(call0.dataLength).isEqualTo(SEGMENT_SIZE); + assertThat(call0.capturedBytes).isEqualTo(d0); + + RecordingChannelStateWriter.Call call1 = recordingWriter.inputDataCalls.get(1); + assertThat(call1.checkpointId).isEqualTo(checkpointId); + assertThat(call1.info).isEqualTo(ch1); + assertThat(call1.dataLength).isEqualTo(SEGMENT_SIZE); + assertThat(call1.capturedBytes).isEqualTo(d1); + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * phase 2 is a snapshot-only backup: it must NOT decrement {@code store.pendingCount}. The + * original reader and the store state are left untouched so close()'s drain loop still has + * these entries to deliver to the task via network buffers. 
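+     *
+     * <p>The invariant under test, sketched ({@code pendingCount} is named by the contract
+     * above; its exact shape is an implementation detail):
+     * <pre>{@code
+     * after flush  : pendingCount(ch0) == 2, pendingCount(ch1) == 1
+     * after phase 2: unchanged (snapshot only)
+     * after drain  : 0 for both (delivery is what decrements)
+     * }</pre>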
+ */ + @Test + void testPhase2DoesNotTouchStorePendingCount() throws Exception { + Queue drainPool = createBufferPool(3); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + // Spill 2 entries for ch0, 1 for ch1 — each exactly SEGMENT_SIZE so they auto-seal + writer.write(createTestData(SEGMENT_SIZE, (byte) 0xB1), SEGMENT_SIZE, ch0); + writer.write(createTestData(SEGMENT_SIZE, (byte) 0xB2), SEGMENT_SIZE, ch1); + writer.write(createTestData(SEGMENT_SIZE, (byte) 0xB3), SEGMENT_SIZE, ch0); + writer.flush(); + + // Before phase 2: both stores non-empty + assertNotEmpty(store0); + assertNotEmpty(store1); + + // Phase 2: all callbacks arrive — entries are copied to checkpoint, but pendingCount stays + long checkpointId = 7L; + writer.onChannelCheckpointStarted(checkpointId, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted(checkpointId, ch1, writer.getCurrentDrainHead()); + + // pendingCount untouched — stores still report non-empty + assertNotEmpty(store0); + assertNotEmpty(store1); + + // drainPendingSpill() delivers all entries to stores; only then do the counts go to zero + writer.drainPendingSpill(); + writer.close(); + // Drain the ready buffers so isEmpty() reflects pendingCount only + drainReady(store0); + drainReady(store1); + assertEmpty(store0); + assertEmpty(store1); + } + + /** + * After phase 2 snapshots entries into the checkpoint, drainPendingSpill() must still deliver + * every entry to the stores (phase 2 is a backup, not an ownership transfer). + */ + @Test + void testDrainStillDeliversEntriesAfterPhase2() throws Exception { + Queue drainPool = createBufferPool(1); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] payload = createTestData(SEGMENT_SIZE, (byte) 0xC1); + writer.write(payload, SEGMENT_SIZE, ch0); + writer.flush(); + + writer.onChannelCheckpointStarted(55L, ch0, writer.getCurrentDrainHead()); + + // drainPendingSpill() consumes the still-pending entry and delivers it to the store + writer.drainPendingSpill(); + writer.close(); + + Buffer delivered = tryTake(store0); + assertThat(delivered).isNotNull(); + byte[] actual = new byte[delivered.getSize()]; + delivered.getMemorySegment().get(0, actual, 0, delivered.getSize()); + delivered.recycleBuffer(); + assertThat(actual).isEqualTo(payload); + } + + /** + * Two independent consumers: phase 2 writes every entry into the checkpoint via + * ChannelStateWriter, and drainPendingSpill() additionally delivers every entry to the stores. + * Both streams see the full data — the on-disk bytes are read twice via independent + * FileChannels. 
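+     *
+     * <p>Dual-consumer shape, sketched:
+     * <pre>{@code
+     * spill file --(snapshot reader)--> ChannelStateWriter     // checkpoint backup
+     * spill file --(drain reader)-----> RecoveredBufferStore   // task-facing delivery
+     * }</pre>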
+ */ + @Test + void testPhase2AndDrainBothReceiveAllEntries() throws Exception { + Queue drainPool = createBufferPool(2); + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] payload0 = createTestData(SEGMENT_SIZE, (byte) 0xD1); + byte[] payload1 = createTestData(SEGMENT_SIZE, (byte) 0xD2); + writer.write(payload0, SEGMENT_SIZE, ch0); + writer.write(payload1, SEGMENT_SIZE, ch1); + writer.flush(); + + // Phase 2: both entries captured into ChannelStateWriter (checkpoint backup) + long checkpointId = 100L; + writer.onChannelCheckpointStarted(checkpointId, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted(checkpointId, ch1, writer.getCurrentDrainHead()); + assertThat(recordingWriter.inputDataCalls).hasSize(2); + + // drainPendingSpill(): both entries additionally delivered to the stores (task-facing + // pipeline) + writer.drainPendingSpill(); + writer.close(); + + Buffer buf0 = tryTake(store0); + Buffer buf1 = tryTake(store1); + assertThat(buf0).isNotNull(); + assertThat(buf1).isNotNull(); + byte[] got0 = new byte[buf0.getSize()]; + byte[] got1 = new byte[buf1.getSize()]; + buf0.getMemorySegment().get(0, got0, 0, buf0.getSize()); + buf1.getMemorySegment().get(0, got1, 0, buf1.getSize()); + buf0.recycleBuffer(); + buf1.recycleBuffer(); + assertThat(got0).isEqualTo(payload0); + assertThat(got1).isEqualTo(payload1); + } + + /** + * After a store is released, its pending disk entries must be dropped from every Reader so the + * dispatcher's subsequent close() drain does not try to deliver bytes for the gone channel. + */ + @Test + void testReleaseAllRemovesChannelDiskEntriesEagerly() throws Exception { + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] d0 = createTestData(SEGMENT_SIZE, (byte) 0x51); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0x52); + writer.write(d0, d0.length, ch0); + writer.write(d1, d1.length, ch1); + writer.flush(); + + // Release store0 before draining — this should propagate to the dispatcher, which drops + // all ch0 entries from the Readers. + store0.releaseAll(); + + writer.drainPendingSpill(); + writer.close(); + + // store1 still receives its data; store0 must stay empty since ch0 entries were dropped. + assertThat(tryTake(store0)).isNull(); + assertThat(concat(drainStore(store1))).isEqualTo(d1); + } + + /** + * When a channel is released while an in-flight checkpoint wait-set still contains it, the + * dispatcher must remove it from the wait-set so the wait-set can still converge to empty and + * phase-2 drain is not blocked. + */ + @Test + void testReleaseAllConvergesInFlightCheckpointWaitSet() throws Exception { + Queue drainPool = createBufferPool(5); + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x61), SEGMENT_SIZE, ch0); + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x62), SEGMENT_SIZE, ch1); + writer.flush(); + + // ch0 reports in. 
Wait-set still contains ch1 so phase 2 must not have fired yet. + writer.onChannelCheckpointStarted(30L, ch0, writer.getCurrentDrainHead()); + assertThat(recordingWriter.inputDataCalls).isEmpty(); + + // ch1 is released before its checkpoint callback ever arrives. The dispatcher removes it + // from the wait-set, which now empties and triggers phase 2. + store1.releaseAll(); + + // Only ch0's entry made it into the checkpoint backup; ch1's entries were dropped on + // release. + assertThat(recordingWriter.inputDataCalls).hasSize(1); + assertThat(recordingWriter.inputDataCalls.get(0).info).isEqualTo(ch0); + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * After a checkpoint is aborted (i.e. all stores called notifyCheckpointStopped), a subsequent + * channel release must NOT trigger a phase-2 drain into the stopped checkpoint — the writer for + * that id is gone and any drain would either be wasted work or rely on the writer's isDone() + * guard to silently swallow the data. + */ + @Test + void testReleaseAfterStoppedCheckpointDoesNotDrainStoppedCheckpoint() throws Exception { + Queue drainPool = createBufferPool(5); + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x71), SEGMENT_SIZE, ch0); + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x72), SEGMENT_SIZE, ch1); + writer.flush(); + + // Checkpoint 50 starts on ch0; wait-set still contains ch1. + writer.onChannelCheckpointStarted(50L, ch0, writer.getCurrentDrainHead()); + assertThat(recordingWriter.inputDataCalls).isEmpty(); + + // The task aborts checkpoint 50 — every channel's persister fires notifyCheckpointStopped. + store0.notifyCheckpointStopped(50L); + store1.notifyCheckpointStopped(50L); + + // Now ch1 is released. Without the stopped-checkpoint short-circuit, the wait-set would + // empty and the dispatcher would drain to checkpoint 50; with the fix, no drain fires. + store1.releaseAll(); + + assertThat(recordingWriter.inputDataCalls).isEmpty(); + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * A late {@code onChannelCheckpointStarted} for a checkpoint that has already been stopped must + * be ignored as stale, even if a new checkpoint has not yet started. + */ + @Test + void testLateCheckpointStartedAfterStoppedIsIgnored() throws Exception { + Queue drainPool = createBufferPool(5); + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x81), SEGMENT_SIZE, ch0); + writer.flush(); + + // Stop checkpoint 60 before anyone reports in. + store0.notifyCheckpointStopped(60L); + store1.notifyCheckpointStopped(60L); + + // A late onChannelCheckpointStarted(60, ...) shows up. It must be short-circuited. 
+ writer.onChannelCheckpointStarted(60L, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted(60L, ch1, writer.getCurrentDrainHead()); + + assertThat(recordingWriter.inputDataCalls).isEmpty(); + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * A new checkpoint started AFTER a stop notification must still progress normally — the + * stopped-id short-circuit only skips the exact stopped id, not all subsequent checkpoints. + */ + @Test + void testCheckpointAfterStoppedStillProgresses() throws Exception { + Queue drainPool = createBufferPool(5); + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x91), SEGMENT_SIZE, ch0); + writer.write(createTestData(SEGMENT_SIZE, (byte) 0x92), SEGMENT_SIZE, ch1); + writer.flush(); + + // Abort checkpoint 70. + store0.notifyCheckpointStopped(70L); + store1.notifyCheckpointStopped(70L); + + // Checkpoint 71 begins; both channels report in and phase-2 fires for 71. + writer.onChannelCheckpointStarted(71L, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStarted(71L, ch1, writer.getCurrentDrainHead()); + + assertThat(recordingWriter.inputDataCalls).hasSize(2); + assertThat(recordingWriter.inputDataCalls.get(0).checkpointId).isEqualTo(71L); + assertThat(recordingWriter.inputDataCalls.get(1).checkpointId).isEqualTo(71L); + + writer.drainPendingSpill(); + writer.close(); + } + + /** + * A BufferRequester whose blocking path parks indefinitely until interrupted. Used to verify + * that callers that should NOT invoke requestBufferBlocking (e.g. close()) are not blocked. + */ + private static final class BlockingForeverBufferRequester implements BufferRequester { + + /** Signals that a blocking request is in flight. */ + final CountDownLatch blockingStarted = new CountDownLatch(1); + + /** Unblocks the blocking requester when a real buffer should be delivered. */ + final SynchronousQueue releaseQueue = new SynchronousQueue<>(); + + @Override + public Buffer requestBuffer(InputChannelInfo channelInfo) { + return null; // write path always spills + } + + @Override + public Buffer requestBufferBlocking(InputChannelInfo channelInfo) + throws InterruptedException { + blockingStarted.countDown(); + // Park until a buffer arrives through releaseQueue or the thread is interrupted. + return releaseQueue.take(); + } + + @Override + public void releaseExclusiveBuffers() { + // No-op: this test fixture has no per-channel buffer manager to tear down. + } + } + + /** + * AT-CABT: abort path — skip drainPendingSpill, call close() directly. close() must return + * promptly (it must not invoke requestBufferBlocking) and delete all spill files. + */ + @Test + void testCloseWithoutDrainReleasesResources() throws Exception { + // requestBufferBlocking parks indefinitely — if close() mistakenly calls it the test hangs. + BlockingForeverBufferRequester requester = new BlockingForeverBufferRequester(); + + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, ChannelStateWriter.NO_OP, spillDirs, SEGMENT_SIZE, requester); + + byte[] data = createTestData(SEGMENT_SIZE, (byte) 0xAB); + writer.write(data, data.length, ch0); + writer.flush(); + + // Verify spill file exists before close. 
+ try (Stream files = + Files.list(tempDir).filter(p -> p.getFileName().toString().startsWith("spill-"))) { + assertThat(files.count()).isGreaterThan(0); + } + + // close() skips drainPendingSpill entirely: must complete within 5 s (not block on buffer). + assertTimeoutPreemptively( + Duration.ofSeconds(5), + () -> assertDoesNotThrow(writer::close), + "close() blocked — it incorrectly called requestBufferBlocking"); + + // Spill files must be deleted on close. + try (Stream files = + Files.list(tempDir).filter(p -> p.getFileName().toString().startsWith("spill-"))) { + assertThat(files.count()).isEqualTo(0); + } + + // The written bytes were intentionally dropped (abort semantics). Store must remain empty. + assertThat(tryTake(store0)).isNull(); + } + + /** + * AT-INTR: drainPendingSpill() is interruptible. When the drain thread blocks on + * requestBufferBlocking and the thread is interrupted, drainPendingSpill() must propagate + * InterruptedException. A subsequent close() must still release resources. + */ + @Test + void testDrainPendingSpillInterruptible() throws Exception { + BlockingForeverBufferRequester requester = new BlockingForeverBufferRequester(); + + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, ChannelStateWriter.NO_OP, spillDirs, SEGMENT_SIZE, requester); + + byte[] data = createTestData(SEGMENT_SIZE, (byte) 0xBC); + writer.write(data, data.length, ch0); + writer.flush(); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future drainFuture = + executor.submit( + () -> { + try { + writer.drainPendingSpill(); + return null; // unexpected: should have thrown + } catch (InterruptedException e) { + // restore flag for good measure + Thread.currentThread().interrupt(); + return e; + } catch (Exception e) { + return e; + } + }); + + // Wait until drainPendingSpill() is actually blocked. + assertThat(requester.blockingStarted.await(5, TimeUnit.SECONDS)) + .as("drain thread did not enter blocking state in time") + .isTrue(); + + // Interrupt the drain thread. + executor.shutdownNow(); + + Exception thrown = drainFuture.get(5, TimeUnit.SECONDS); + assertThat(thrown) + .as("drainPendingSpill() must throw InterruptedException on interrupt") + .isInstanceOf(InterruptedException.class); + + // Even after an interrupted drain, close() must release resources without throwing. + assertTimeoutPreemptively( + Duration.ofSeconds(5), + () -> assertDoesNotThrow(writer::close), + "close() blocked after interrupted drain"); + + // Spill files must be cleaned up. + try (Stream files = + Files.list(tempDir).filter(p -> p.getFileName().toString().startsWith("spill-"))) { + assertThat(files.count()).isEqualTo(0); + } + } + + /** + * Phase 2 must respect each channel's recorded {@code startPos}: entries with position strictly + * less than the channel's startPos are skipped (their channel's Step 1 already captured them + * via readyBuffers), while entries at or beyond startPos are emitted to the channel state. Two + * channels with different startPos values let us verify the filter is applied per-channel and + * not globally. 
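+     *
+     * <p>The filter, restated as a predicate (terminology from the description above):
+     * <pre>{@code
+     * emit(entry) <=> position(entry) >= startPos(entry.channelInfo)
+     * }</pre>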
+ */ + @Test + void testPhase2FilterByPerChannelStartPos() throws Exception { + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(createBufferPool(0))); + + // Layout in the spill file: ch0 entries at offsets 0 and 2*SEGMENT_SIZE, ch1 entry at + // offset SEGMENT_SIZE. Total 3 entries, FIFO order. + byte[] d0a = createTestData(SEGMENT_SIZE, (byte) 0x10); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0x20); + byte[] d0b = createTestData(SEGMENT_SIZE, (byte) 0x11); + writer.write(d0a, SEGMENT_SIZE, ch0); + writer.write(d1, SEGMENT_SIZE, ch1); + writer.write(d0b, SEGMENT_SIZE, ch0); + writer.flush(); + + // ch0's barrier passed before any drain — its startPos is the head of the file (include + // every ch0 entry). ch1's barrier passed after the first ch0 entry was logically drained — + // its startPos points at the third entry (file 0, offset 2*SEGMENT_SIZE), so its single + // entry at offset SEGMENT_SIZE must be skipped (covered by Step 1 ready snapshot). + EntryPosition ch0StartPos = new EntryPosition(0, 0); + EntryPosition ch1StartPos = new EntryPosition(0, 2L * SEGMENT_SIZE); + writer.onChannelCheckpointStarted(7L, ch0, ch0StartPos); + writer.onChannelCheckpointStarted(7L, ch1, ch1StartPos); + + writer.close(); + + // ch0 receives both of its entries; ch1 receives nothing in phase 2. + List ch0Bytes = + recordingWriter.inputDataCalls.stream() + .filter(c -> c.info.equals(ch0)) + .map(c -> c.capturedBytes) + .collect(java.util.stream.Collectors.toList()); + List ch1Bytes = + recordingWriter.inputDataCalls.stream() + .filter(c -> c.info.equals(ch1)) + .map(c -> c.capturedBytes) + .collect(java.util.stream.Collectors.toList()); + assertThat(ch0Bytes).hasSize(2); + assertThat(ch0Bytes.get(0)).isEqualTo(d0a); + assertThat(ch0Bytes.get(1)).isEqualTo(d0b); + assertThat(ch1Bytes).isEmpty(); + } + + /** + * The phase-2 snapshot is captured at the first {@code onChannelCheckpointStarted} + * call, not at wait-set convergence. This guards the original race: between the first and last + * trigger, drainPendingSpill could pop entries off the original Reader; the pinned snapshot + * preserves the pre-drain view so phase-2 sees every entry that was on disk when the first + * channel passed its barrier. + * + *
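+     * <p>Timeline, sketched:
+     * <pre>{@code
+     * t0: first onChannelCheckpointStarted -> pin a snapshot of the Reader's current view
+     * t1: drainPendingSpill                -> original Reader emptied
+     * t2: last trigger, wait-set empties   -> phase 2 reads the pinned snapshot
+     * }</pre>
+     *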
<p>
Scenario: write three entries for two channels, fire ch0's trigger to capture the + * snapshot, drain everything (which empties the original Reader's deque), then fire ch1's + * trigger to converge the wait-set. Phase 2 must still report all three entries because the + * snapshot was taken before drain ran. + */ + @Test + void testPhase2SnapshotPinnedAtFirstTrigger() throws Exception { + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] d0 = createTestData(SEGMENT_SIZE, (byte) 0xC0); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0xC1); + byte[] d2 = createTestData(SEGMENT_SIZE, (byte) 0xC2); + writer.write(d0, SEGMENT_SIZE, ch0); + writer.write(d1, SEGMENT_SIZE, ch1); + writer.write(d2, SEGMENT_SIZE, ch0); + writer.flush(); + + // First trigger arrives BEFORE drain. The snapshot pins the full disk view here. + EntryPosition initialDrainHead = writer.getCurrentDrainHead(); + writer.onChannelCheckpointStarted(11L, ch0, initialDrainHead); + + // drainPendingSpill empties the ORIGINAL Reader. Phase 2 must read from the pinned + // snapshot, not the post-drain Reader, otherwise we lose every entry. + writer.drainPendingSpill(); + + // Now ch1's trigger converges the wait-set, firing phase-2. + writer.onChannelCheckpointStarted(11L, ch1, initialDrainHead); + writer.close(); + + // All three entries must show up in phase-2 with their original channel info. + List calls = recordingWriter.inputDataCalls; + assertThat(calls).hasSize(3); + assertThat(calls.get(0).info).isEqualTo(ch0); + assertThat(calls.get(0).capturedBytes).isEqualTo(d0); + assertThat(calls.get(1).info).isEqualTo(ch1); + assertThat(calls.get(1).capturedBytes).isEqualTo(d1); + assertThat(calls.get(2).info).isEqualTo(ch0); + assertThat(calls.get(2).capturedBytes).isEqualTo(d2); + } + + /** + * {@code onChannelCheckpointStopped} must close pinned snapshot Readers and clear the in- + * progress checkpoint state, otherwise every aborted checkpoint leaks one {@code FileChannel} + * per spill file and a stale per-channel startPos map can poison the next checkpoint's filter. + * Following an abort, a fresh checkpoint must produce a complete phase-2 from a brand-new + * snapshot. + */ + @Test + void testCheckpointStoppedReleasesSnapshotsAndStateForNextCheckpoint() throws Exception { + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(createBufferPool(0))); + + byte[] d0 = createTestData(SEGMENT_SIZE, (byte) 0xD0); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0xD1); + writer.write(d0, SEGMENT_SIZE, ch0); + writer.write(d1, SEGMENT_SIZE, ch1); + writer.flush(); + + // Start checkpoint 5; only ch0 reports in, then both channels stop the checkpoint + // (e.g. abort path) before ch1 reports. + writer.onChannelCheckpointStarted(5L, ch0, writer.getCurrentDrainHead()); + writer.onChannelCheckpointStopped(5L, ch0); + writer.onChannelCheckpointStopped(5L, ch1); + // Phase 2 for ckpt 5 was never submitted because the wait-set never converged before stop. 
+ assertThat(recordingWriter.inputDataCalls).isEmpty(); + + // A subsequent checkpoint 6 starts cleanly and converges normally — must include both + // entries despite the dangling state from the aborted checkpoint 5. + EntryPosition startPos = writer.getCurrentDrainHead(); + writer.onChannelCheckpointStarted(6L, ch0, startPos); + writer.onChannelCheckpointStarted(6L, ch1, startPos); + writer.close(); + + assertThat(recordingWriter.inputDataCalls).hasSize(2); + assertThat(recordingWriter.inputDataCalls.get(0).checkpointId).isEqualTo(6L); + assertThat(recordingWriter.inputDataCalls.get(1).checkpointId).isEqualTo(6L); + } + + /** + * The {@code drainHead} field must advance only after each drain bundle's {@code addBuffer}, so + * an external observer reading {@code getCurrentDrainHead()} can rely on the invariant + * "drainHead crossed e ⇒ e is in store.readyBuffers". This test exercises the public + * observable: at flush() drainHead points at the first entry; after each drain bundle commits, + * drainHead advances to the next entry; after the last entry drainHead reaches END. + */ + @Test + void testDrainHeadAdvancesAfterEachAddBuffer() throws Exception { + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + ChannelStateWriter.NO_OP, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] d0 = createTestData(SEGMENT_SIZE, (byte) 0xE0); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0xE1); + writer.write(d0, SEGMENT_SIZE, ch0); + writer.write(d1, SEGMENT_SIZE, ch1); + writer.flush(); + + // After flush, drainHead points at the first entry of file 0. + EntryPosition headAfterFlush = writer.getCurrentDrainHead(); + assertThat(headAfterFlush.getFileIndex()).isEqualTo(0); + assertThat(headAfterFlush.getOffset()).isEqualTo(0L); + + writer.drainPendingSpill(); + + // After drain consumes everything, drainHead reaches the END sentinel. + assertThat(writer.getCurrentDrainHead()).isEqualTo(EntryPosition.END); + + writer.close(); + } + + /** + * Before the {@link #drainPendingSpill()} bundle was made atomic, a phase-2 snapshot taken in + * the gap between {@code reader.skipNextEntry()} (entry gone from disk-side bookkeeping) and + * {@code store.addBuffer()} (entry not yet in the channel's readyBuffers) would lose the entry: + * Step 1 captures no buffer, phase-2 sees no entry. This test exercises the invariant by + * injecting an {@code onChannelCheckpointStarted} call between two drain bundles and asserting + * that every spilled entry appears either in the channel's readyBuffers (Step 1) or in phase-2 + * capture — never neither. 
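+     *
+     * <p>The window being guarded, sketched as an interleaving (call names from the
+     * description above):
+     * <pre>{@code
+     * drain thread:      reader.skipNextEntry();  // entry leaves disk-side bookkeeping
+     * checkpoint thread: snapshot taken here      // would see neither disk entry nor buffer
+     * drain thread:      store.addBuffer(buffer); // entry reappears in readyBuffers
+     * }</pre>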
+ */ + @Test + void testNoEntryLostBetweenDrainAndCheckpointTrigger() throws Exception { + RecordingChannelStateWriter recordingWriter = new RecordingChannelStateWriter(); + Queue drainPool = createBufferPool(5); + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, + recordingWriter, + spillDirs, + SEGMENT_SIZE, + TestBufferPool.drainOnly(drainPool)); + + byte[] d0a = createTestData(SEGMENT_SIZE, (byte) 0x60); + byte[] d0b = createTestData(SEGMENT_SIZE, (byte) 0x61); + byte[] d0c = createTestData(SEGMENT_SIZE, (byte) 0x62); + byte[] d1 = createTestData(SEGMENT_SIZE, (byte) 0x70); + writer.write(d0a, SEGMENT_SIZE, ch0); + writer.write(d1, SEGMENT_SIZE, ch1); + writer.write(d0b, SEGMENT_SIZE, ch0); + writer.write(d0c, SEGMENT_SIZE, ch0); + writer.flush(); + + // ch0's barrier arrives BEFORE drain runs — every ch0 entry is "in flight" for ckpt 9. + EntryPosition ch0StartPos = writer.getCurrentDrainHead(); + writer.onChannelCheckpointStarted(9L, ch0, ch0StartPos); + + // Drain runs to completion: entries are popped from the original reader and addBuffered + // to their stores. The phase-2 snapshot was pinned before drain so it still owns the full + // disk view; per-channel filtering decides whether each snapshot entry is emitted. + writer.drainPendingSpill(); + + // ch1 reports in AFTER drain — its startPos is END (everything already drained for ch1's + // perspective). Phase-2 must skip every ch1 snapshot entry (Step 1 captured them via + // store readyBuffers) and emit the full ch0 set. + EntryPosition ch1StartPos = writer.getCurrentDrainHead(); + writer.onChannelCheckpointStarted(9L, ch1, ch1StartPos); + writer.close(); + + // Aggregate every ch0 byte the dispatcher told the world about: phase-2 emits + buffers + // the store actually received via drain. The set must equal the original three ch0 writes + // — no entry can fall through both paths. + List phase2Ch0 = + recordingWriter.inputDataCalls.stream() + .filter(c -> c.info.equals(ch0)) + .map(c -> c.capturedBytes) + .collect(java.util.stream.Collectors.toList()); + List readyCh0 = drainStore(store0); + // ch0 trigger fired before drain → every ch0 spill entry has p_e >= startPos_ch0, + // so phase 2 emits all three. Drain afterwards still adds them to readyBuffers, but the + // race fix only guarantees coverage (no loss); a buffer being delivered post-trigger is + // expected and Task will consume it as ckpt N+1's input. + assertThat(phase2Ch0).hasSize(3); + assertThat(phase2Ch0).containsExactly(d0a, d0b, d0c); + // For ch1: ch1's barrier was after drain, so its single entry is captured by Step 1 + // (readyBuffers) and skipped in phase 2. + List phase2Ch1 = + recordingWriter.inputDataCalls.stream() + .filter(c -> c.info.equals(ch1)) + .map(c -> c.capturedBytes) + .collect(java.util.stream.Collectors.toList()); + List readyCh1 = drainStore(store1); + assertThat(phase2Ch1).isEmpty(); + assertThat(readyCh1).hasSize(1); + assertThat(readyCh1.get(0)).isEqualTo(d1); + // Sanity: the ready set for ch0 contains every original byte too (drain delivered them + // post-trigger), but the test's correctness hinges on the phase-2 set being complete. + assertThat(readyCh0).hasSize(3); + } + + /** + * AT-LOCK: FLINK-39519 deadlock regression. drainPendingSpill() must NOT hold the dispatcher + * monitor while blocking on requestBufferBlocking. Concurrent onChannelCheckpointStopped (which + * acquires the dispatcher monitor) must complete promptly while drain is blocked. + * + *
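+     * <p>The deadlock shape being guarded against (monitor name illustrative):
+     * <pre>{@code
+     * drain thread:    synchronized (dispatcher) { requestBufferBlocking(...); } // parks, holds monitor
+     * callback thread: synchronized (dispatcher) { ... }                         // blocks forever
+     * }</pre>
+     *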
<p>
Before the close/drain split, drainSpillThroughBuffers() ran inside a synchronized block, + * so onChannelCheckpointStopped would deadlock waiting for the monitor held by drain. This test + * reproduces that scenario and asserts the callback completes in time. + */ + @Test + void testDrainPendingSpillReleasesMonitorForCheckpointStopped() throws Exception { + BlockingForeverBufferRequester requester = new BlockingForeverBufferRequester(); + + FilteredBufferDispatcherImpl writer = + new FilteredBufferDispatcherImpl( + stores, ChannelStateWriter.NO_OP, spillDirs, SEGMENT_SIZE, requester); + + // Set up checkpoint wait-set state so onChannelCheckpointStopped follows the full path. + // notifyCheckpointStopped on the stores will call writer.onChannelCheckpointStopped. + long checkpointId = 42L; + + byte[] data = createTestData(SEGMENT_SIZE, (byte) 0xCD); + writer.write(data, data.length, ch0); + writer.flush(); + + // Build the wait-set for checkpoint 42 by having ch0 report in. + writer.onChannelCheckpointStarted(checkpointId, ch0, writer.getCurrentDrainHead()); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future drainFuture = + executor.submit( + () -> { + try { + writer.drainPendingSpill(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (Exception ignored) { + } + }); + + // Wait until drainPendingSpill() is blocked inside requestBufferBlocking. + assertThat(requester.blockingStarted.await(5, TimeUnit.SECONDS)) + .as("drain thread did not enter blocking state in time") + .isTrue(); + + // onChannelCheckpointStopped acquires the dispatcher monitor. If drain held the monitor + // this call would deadlock; it must return within 1 s. + assertTimeoutPreemptively( + Duration.ofSeconds(1), + () -> { + store0.notifyCheckpointStopped(checkpointId); + store1.notifyCheckpointStopped(checkpointId); + }, + "onChannelCheckpointStopped deadlocked — drainPendingSpill held the dispatcher monitor"); + + // Interrupt the drain thread so the test exits cleanly. + executor.shutdownNow(); + drainFuture.get(5, TimeUnit.SECONDS); + + writer.close(); + } + + // -- helpers -- + + private static void assertEmpty(RecoveredBufferStoreImpl store) { + synchronized (store) { + assertThat(store.isEmpty()).isTrue(); + } + } + + private static void assertNotEmpty(RecoveredBufferStoreImpl store) { + synchronized (store) { + assertThat(store.isEmpty()).isFalse(); + } + } + + private static Buffer tryTake(RecoveredBufferStoreImpl store) { + synchronized (store) { + return store.tryTake(); + } + } + + private static void drainReady(RecoveredBufferStoreImpl store) { + synchronized (store) { + Buffer b; + while ((b = store.tryTake()) != null) { + b.recycleBuffer(); + } + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FilteredSpillFileTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FilteredSpillFileTest.java new file mode 100644 index 0000000000000..2ab470f6c8c79 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FilteredSpillFileTest.java @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.runtime.memory.MemoryManager; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link FilteredSpillFile}, {@link FilteredSpillFile.Reader}, and {@link + * FilteredSpillFile.Chunk}. + */ +class FilteredSpillFileTest { + + @TempDir private Path temporaryFolder; + + private static final int MEMORY_SEGMENT_SIZE = MemoryManager.DEFAULT_PAGE_SIZE; + + private static final InputChannelInfo CHANNEL_0 = new InputChannelInfo(0, 0); + private static final InputChannelInfo CHANNEL_1 = new InputChannelInfo(0, 1); + + /** Write a single entry and read it back via readNext; verify bytes match. */ + @Test + void testSingleEntryRoundTrip() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + Random random = new Random(42); + byte[] data = new byte[1024]; + random.nextBytes(data); + + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + writer.writeEntry(data, data.length, CHANNEL_0); + writer.finish(); + + FilteredSpillFile.Reader reader = writer.getReaders().get(0); + assertThat(reader.hasEntries()).isTrue(); + FilteredSpillFile.Chunk chunk = reader.readNext(); + assertThat(chunk).isNotNull(); + assertThat(chunk.getChannelInfo()).isEqualTo(CHANNEL_0); + assertThat(chunk.getLength()).isEqualTo(data.length); + assertThat(chunk.getData()).startsWith(data); + assertThat(reader.hasEntries()).isFalse(); + } + } + + /** Write multiple entries across channels; verify readNext returns them in order. 
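+ * <p>(Ordering note, inferred from this test's expectations: the spill file acts as a single + * FIFO log across channels rather than per-channel queues; CHANNEL_0's entry comes back first + * because it was written first, not because of its channel.)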
*/ + @Test + void testMultipleEntriesInOrder() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + byte[] d0 = new byte[] {1, 2, 3, 4}; + byte[] d1 = new byte[] {5, 6, 7, 8}; + + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + writer.writeEntry(d0, d0.length, CHANNEL_0); + writer.writeEntry(d1, d1.length, CHANNEL_1); + writer.finish(); + + FilteredSpillFile.Reader reader = writer.getReaders().get(0); + FilteredSpillFile.Chunk c0 = reader.readNext(); + assertThat(c0.getChannelInfo()).isEqualTo(CHANNEL_0); + assertThat(c0.getLength()).isEqualTo(d0.length); + byte[] actual0 = new byte[d0.length]; + System.arraycopy(c0.getData(), 0, actual0, 0, d0.length); + assertThat(actual0).isEqualTo(d0); + + FilteredSpillFile.Chunk c1 = reader.readNext(); + assertThat(c1.getChannelInfo()).isEqualTo(CHANNEL_1); + assertThat(c1.getLength()).isEqualTo(d1.length); + byte[] actual1 = new byte[d1.length]; + System.arraycopy(c1.getData(), 0, actual1, 0, d1.length); + assertThat(actual1).isEqualTo(d1); + + assertThat(reader.readNext()).isNull(); + } + } + + /** + * Write more than 64MB to trigger file rotation; verify multiple Readers are created and data + * is correct across files. + */ + @Test + void testFileRotation() throws Exception { + Path dir1 = Files.createDirectory(temporaryFolder.resolve("dir1")); + Path dir2 = Files.createDirectory(temporaryFolder.resolve("dir2")); + String[] spillDirs = {dir1.toString(), dir2.toString()}; + + // 64MB / DEFAULT_PAGE_SIZE + extra to force at least one rotation + int numEntries = (int) (64L * 1024 * 1024 / MEMORY_SEGMENT_SIZE) + 10; + byte[][] chunks = new byte[numEntries][MEMORY_SEGMENT_SIZE]; + Random random = new Random(42); + for (byte[] chunk : chunks) { + random.nextBytes(chunk); + } + + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + for (int i = 0; i < numEntries; i++) { + writer.writeEntry(chunks[i], MEMORY_SEGMENT_SIZE, CHANNEL_0); + } + writer.finish(); + + assertThat(writer.getReaders().size()).isGreaterThan(1); + + int idx = 0; + for (FilteredSpillFile.Reader reader : writer.getReaders()) { + while (reader.hasEntries()) { + FilteredSpillFile.Chunk chunk = reader.readNext(); + byte[] actual = new byte[chunk.getLength()]; + System.arraycopy(chunk.getData(), 0, actual, 0, chunk.getLength()); + assertThat(actual).isEqualTo(chunks[idx++]); + } + } + assertThat(idx).isEqualTo(numEntries); + } + } + + /** Writer.close() finishes and releases resources; writeEntry after close throws. */ + @Test + void testCloseReleasesResources() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE); + writer.writeEntry(new byte[] {1, 2, 3}, 3, CHANNEL_0); + writer.close(); + + assertThatThrownBy(() -> writer.writeEntry(new byte[] {4}, 1, CHANNEL_0)) + .isInstanceOf(IllegalStateException.class); + } + + /** Truncated file causes readNext to throw IOException. 
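+ * <p>(Why truncation is detectable, assuming the reader trusts each entry's length prefix: the + * half-length file ends before the announced payload does, so the read hits EOF mid-entry and + * surfaces an IOException instead of silently returning short data.)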
*/ + @Test + void testTruncatedFileThrows() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + byte[] data = new byte[1024]; + new Random(42).nextBytes(data); + + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + writer.writeEntry(data, data.length, CHANNEL_0); + writer.finish(); + + // Truncate the spill file to half + Path filePath = writer.getReaders().get(0).filePath; + try (RandomAccessFile raf = new RandomAccessFile(filePath.toFile(), "rw")) { + raf.setLength(data.length / 2); + } + + assertThatThrownBy(() -> writer.getReaders().get(0).readNext()) + .isInstanceOf(IOException.class); + } + } + + /** snapshot() creates an independent Reader with the same entries; pre-frozen. */ + @Test + void testSnapshot() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + byte[] data = new byte[256]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) i; + } + + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + writer.writeEntry(data, data.length, CHANNEL_0); + writer.finish(); + + FilteredSpillFile.Reader original = writer.getReaders().get(0); + assertThat(original.isFrozen()).isTrue(); + + FilteredSpillFile.Reader snap = original.snapshot(); + try { + assertThat(snap.isFrozen()).isTrue(); + assertThat(snap.hasEntries()).isTrue(); + + FilteredSpillFile.Chunk chunk = snap.readNext(); + assertThat(chunk.getLength()).isEqualTo(data.length); + byte[] actual = new byte[data.length]; + System.arraycopy(chunk.getData(), 0, actual, 0, data.length); + assertThat(actual).isEqualTo(data); + + // Original still has entries (snapshot is independent) + assertThat(original.hasEntries()).isTrue(); + } finally { + snap.close(); + } + } + } + + /** addEntry after freeze throws IllegalStateException. */ + @Test + void testAddEntryAfterFreezeThrows() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + writer.writeEntry(new byte[] {1}, 1, CHANNEL_0); + writer.finish(); + // Reader is frozen by finish(); addEntry via a new writeEntry after finish should throw + assertThatThrownBy(() -> writer.writeEntry(new byte[] {2}, 1, CHANNEL_0)) + .isInstanceOf(IllegalStateException.class); + } + } + + /** peekNextChannel returns the channel of the next entry without consuming it. */ + @Test + void testPeekNextChannel() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + writer.writeEntry(new byte[] {1, 2}, 2, CHANNEL_0); + writer.writeEntry(new byte[] {3, 4}, 2, CHANNEL_1); + writer.finish(); + + FilteredSpillFile.Reader reader = writer.getReaders().get(0); + assertThat(reader.peekNextChannel()).isEqualTo(CHANNEL_0); + reader.readNext(); + assertThat(reader.peekNextChannel()).isEqualTo(CHANNEL_1); + reader.readNext(); + assertThat(reader.peekNextChannel()).isNull(); + } + } + + /** getPendingChannels returns all channels with pending entries. 
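+ * <p>(The pending set shrinks as readNext consumes the last remaining entry of each channel, + * which is exactly what the three assertions below step through.)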
*/ + @Test + void testGetPendingChannels() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + writer.writeEntry(new byte[] {1}, 1, CHANNEL_0); + writer.writeEntry(new byte[] {2}, 1, CHANNEL_1); + writer.finish(); + + FilteredSpillFile.Reader reader = writer.getReaders().get(0); + assertThat(reader.getPendingChannels()).containsExactlyInAnyOrder(CHANNEL_0, CHANNEL_1); + + reader.readNext(); // consume CHANNEL_0 + assertThat(reader.getPendingChannels()).containsExactly(CHANNEL_1); + + reader.readNext(); // consume CHANNEL_1 + assertThat(reader.getPendingChannels()).isEmpty(); + } + } + + /** isIdle() returns true before any writeEntry call, false after. */ + @Test + void testIsIdle() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + assertThat(writer.isIdle()).isTrue(); + writer.writeEntry(new byte[] {1}, 1, CHANNEL_0); + assertThat(writer.isIdle()).isFalse(); + } + } + + /** + * isIdle() flips dynamically with the reader entry count: empty at start, non-idle after any + * write, stays non-idle while entries remain even if some are partially drained, idle again + * only after all entries have been consumed. Must behave consistently across multiple write / + * drain rounds. + */ + @Test + void testIsIdleFlipsAcrossWriteDrainRounds() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + try (FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE)) { + assertThat(writer.isIdle()).isTrue(); + + writer.writeEntry(new byte[] {1}, 1, CHANNEL_0); + writer.writeEntry(new byte[] {2}, 1, CHANNEL_1); + writer.writeEntry(new byte[] {3}, 1, CHANNEL_0); + assertThat(writer.isIdle()).isFalse(); + + FilteredSpillFile.Reader reader = writer.getReaders().get(0); + + reader.readNext(); + assertThat(writer.isIdle()).isFalse(); + + reader.readNext(); + assertThat(writer.isIdle()).isFalse(); + + reader.readNext(); + assertThat(reader.hasEntries()).isFalse(); + assertThat(writer.isIdle()).isTrue(); + + writer.writeEntry(new byte[] {4}, 1, CHANNEL_1); + writer.writeEntry(new byte[] {5}, 1, CHANNEL_0); + assertThat(writer.isIdle()).isFalse(); + + reader.readNext(); + assertThat(writer.isIdle()).isFalse(); + + reader.readNext(); + assertThat(reader.hasEntries()).isFalse(); + assertThat(writer.isIdle()).isTrue(); + } + } + + /** close() deletes all spill files on disk. 
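+ * <p>(Note that close() is exercised here without a prior finish(), so deletion must not + * depend on the readers having been frozen first.)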
*/ + @Test + void testCloseDeletesAllFiles() throws Exception { + String[] spillDirs = {temporaryFolder.toString()}; + FilteredSpillFile writer = new FilteredSpillFile(spillDirs, MEMORY_SEGMENT_SIZE); + writer.writeEntry(new byte[64], 64, CHANNEL_0); + List filePaths = new ArrayList<>(); + for (FilteredSpillFile.Reader r : writer.getReaders()) { + filePaths.add(r.filePath); + } + for (Path p : filePaths) { + assertThat(Files.exists(p)).isTrue(); + } + + writer.close(); + + for (Path p : filePaths) { + assertThat(Files.exists(p)).isFalse(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java index 85b4fd1d48ef1..dae910d4052d9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStoreImpl; import org.apache.flink.runtime.plugable.DeserializationDelegate; import org.apache.flink.streaming.runtime.io.recovery.RecordFilter; import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel; @@ -35,12 +36,12 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import java.io.IOException; +import java.nio.file.Path; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -53,6 +54,9 @@ class GateFilterHandlerBufferOwnershipTest { private static final int BUFFER_SIZE = 1024; private static final SubtaskConnectionDescriptor KEY = new SubtaskConnectionDescriptor(0, 0); + private static final InputChannelInfo TARGET_CHANNEL = new InputChannelInfo(0, 0); + + @TempDir private Path temporaryFolder; @Test void testSourceBufferRecycledOnSuccess() throws Exception { @@ -60,13 +64,14 @@ void testSourceBufferRecycledOnSuccess() throws Exception { createHandler(RecordFilter.acceptAll()); Buffer sourceBuffer = createBufferWithRecords(1L, 2L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + FilteredBufferDispatcher outputWriter = createTestOutputWriter(); + handler.filterAndRewrite(0, 0, sourceBuffer, outputWriter, TARGET_CHANNEL); // sourceBuffer should be recycled by the deserializer after consumption assertThat(sourceBuffer.isRecycled()).isTrue(); - // Clean up result buffers - result.forEach(Buffer::recycleBuffer); + outputWriter.flush(); + outputWriter.close(); } @Test @@ -75,79 +80,44 @@ void testSourceBufferRecycledWhenAllRecordsFilteredOut() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(rejectAll); Buffer sourceBuffer = createBufferWithRecords(1L, 2L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + FilteredBufferDispatcher outputWriter = createTestOutputWriter(); + handler.filterAndRewrite(0, 0, sourceBuffer, 
outputWriter, TARGET_CHANNEL); - assertThat(result).isEmpty(); // sourceBuffer should still be recycled even though no output was produced assertThat(sourceBuffer.isRecycled()).isTrue(); + + outputWriter.flush(); + outputWriter.close(); } @Test - void testSourceBufferRecycledOnInvalidVirtualChannel() { + void testSourceBufferRecycledOnInvalidVirtualChannel() throws Exception { // Create handler with KEY=(0,0) but call with (1,1) to trigger IllegalStateException ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); Buffer sourceBuffer = createBufferWithRecords(1L); + FilteredBufferDispatcher outputWriter = createTestOutputWriter(); assertThatThrownBy( - () -> handler.filterAndRewrite(1, 1, sourceBuffer, this::createEmptyBuffer)) + () -> + handler.filterAndRewrite( + 1, 1, sourceBuffer, outputWriter, TARGET_CHANNEL)) .isInstanceOf(IllegalStateException.class); // sourceBuffer must be recycled even when lookup fails before setNextBuffer assertThat(sourceBuffer.isRecycled()).isTrue(); - } - - @Test - void testResultBuffersAndCurrentBufferRecycledOnSerializationError() throws Exception { - // Use a small buffer so that records span multiple buffers. The supplier fails on the - // second request, after the first output buffer has been filled and added to resultBuffers. - AtomicInteger bufferRequestCount = new AtomicInteger(0); - ChannelStateFilteringHandler.BufferSupplier failingSupplier = - () -> { - if (bufferRequestCount.incrementAndGet() > 1) { - throw new IOException("Simulated buffer allocation failure"); - } - return createEmptyBuffer(13); - }; - - ChannelStateFilteringHandler.GateFilterHandler handler = - createHandler(RecordFilter.acceptAll()); - - Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); - - // The exception should propagate; no buffer leak (no IllegalReferenceCountException - // from double-recycle). - assertThatThrownBy(() -> handler.filterAndRewrite(0, 0, sourceBuffer, failingSupplier)) - .isInstanceOf(IOException.class) - .hasMessage("Simulated buffer allocation failure"); - // sourceBuffer ownership was transferred to the deserializer via setNextBuffer(). - // The deserializer may still hold it if it hasn't fully consumed the buffer before the - // error. Calling clear() triggers the cleanup chain: - // GateFilterHandler#clear() -> VirtualChannel#clear() -> deserializer.clear() - handler.clear(); - assertThat(sourceBuffer.isRecycled()).isTrue(); + outputWriter.close(); } /** - * Tests the production cleanup path: when filterAndRewrite throws mid-processing, the - * deserializer may still hold sourceBuffer. In production, ChannelStateFilteringHandler is used - * in a try-with-resources block (see {@code SequentialChannelStateReaderImpl#readInputData}), - * so its close() is guaranteed to be called, which triggers clear() on all GateFilterHandlers - * and their deserializers. This test simulates that exact pattern. + * Tests that the production cleanup path works: when ChannelStateFilteringHandler is used in a + * try-with-resources block, its close() clears all GateFilterHandlers and their deserializers, + * ensuring buffers held by the deserializer are recycled. 
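+ * <p>(The FailingOutputWriter defined at the bottom of this class stands in for the buffer + * allocation failure that the removed BufferSupplier-based variant of this test simulated.)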
*/ @Test void testCloseRecyclesDeserializerHeldBufferAfterError() throws Exception { - AtomicInteger bufferRequestCount = new AtomicInteger(0); - ChannelStateFilteringHandler.BufferSupplier failingSupplier = - () -> { - if (bufferRequestCount.incrementAndGet() > 1) { - throw new IOException("Simulated buffer allocation failure"); - } - return createEmptyBuffer(13); - }; - ChannelStateFilteringHandler.GateFilterHandler gateHandler = createHandler(RecordFilter.acceptAll()); // Wrap in ChannelStateFilteringHandler, the production-level owner @@ -155,18 +125,22 @@ void testCloseRecyclesDeserializerHeldBufferAfterError() throws Exception { new ChannelStateFilteringHandler( new ChannelStateFilteringHandler.GateFilterHandler[] {gateHandler}); + // Create a large buffer that will require multiple writes Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); + // Create a dispatcher that throws on the second write call + FilteredBufferDispatcher failingWriter = new FailingOutputWriter(); + // Simulate the production try-with-resources pattern assertThatThrownBy( () -> { try (ChannelStateFilteringHandler ignored = filteringHandler) { filteringHandler.filterAndRewrite( - 0, 0, 0, sourceBuffer, failingSupplier); + 0, 0, 0, sourceBuffer, failingWriter, TARGET_CHANNEL); } }) .isInstanceOf(IOException.class) - .hasMessage("Simulated buffer allocation failure"); + .hasMessage("Simulated write failure"); // After close(), the entire cleanup chain has fired: // ChannelStateFilteringHandler.close() -> GateFilterHandler.clear() @@ -174,10 +148,6 @@ void testCloseRecyclesDeserializerHeldBufferAfterError() throws Exception { assertThat(sourceBuffer.isRecycled()).isTrue(); } - // ------------------------------------------------------------------------------------------- - // Helper methods - // ------------------------------------------------------------------------------------------- - private ChannelStateFilteringHandler.GateFilterHandler createHandler( RecordFilter filter) { RecordDeserializer> deserializer = @@ -193,6 +163,35 @@ private ChannelStateFilteringHandler.GateFilterHandler createHandler( return new ChannelStateFilteringHandler.GateFilterHandler<>(channels, serializer); } + private FilteredBufferDispatcher createTestOutputWriter() throws IOException { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(TARGET_CHANNEL); + Map storesByChannel = new HashMap<>(); + storesByChannel.put(TARGET_CHANNEL, store); + + String[] spillDirs = new String[] {temporaryFolder.toString()}; + BufferRequester newBufferPerRequest = + new BufferRequester() { + @Override + public Buffer requestBuffer(InputChannelInfo channelInfo) { + return createEmptyBuffer(); + } + + @Override + public Buffer requestBufferBlocking(InputChannelInfo channelInfo) { + return createEmptyBuffer(); + } + + @Override + public void releaseExclusiveBuffers() {} + }; + return new FilteredBufferDispatcherImpl( + storesByChannel, + ChannelStateWriter.NO_OP, + spillDirs, + BUFFER_SIZE, + newBufferPerRequest); + } + private Buffer createBufferWithRecords(Long... values) { try { StreamElementSerializer serializer = @@ -220,11 +219,33 @@ private Buffer createBufferWithRecords(Long... 
values) { } private Buffer createEmptyBuffer() { - return createEmptyBuffer(BUFFER_SIZE); + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(BUFFER_SIZE); + return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); } - private Buffer createEmptyBuffer(int size) { - MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size); - return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + /** + * A dispatcher that fails on the second write call. Used to test cleanup paths when + * filterAndRewrite encounters an error during serialization output. + */ + private static class FailingOutputWriter implements FilteredBufferDispatcher { + private int writeCount = 0; + + @Override + public void write(byte[] data, int length, InputChannelInfo channelInfo) + throws IOException { + writeCount++; + if (writeCount > 1) { + throw new IOException("Simulated write failure"); + } + } + + @Override + public void flush() {} + + @Override + public void drainPendingSpill() {} + + @Override + public void close() {} } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java index f02ce35fd867d..388215e9d410b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStoreImpl; import org.apache.flink.runtime.plugable.DeserializationDelegate; import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate; import org.apache.flink.streaming.runtime.io.recovery.RecordFilter; @@ -36,8 +37,10 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import java.io.IOException; +import java.nio.file.Path; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -50,17 +53,24 @@ class GateFilterHandlerTest { private static final int BUFFER_SIZE = 1024; private static final SubtaskConnectionDescriptor KEY = new SubtaskConnectionDescriptor(0, 0); + private static final InputChannelInfo TARGET_CHANNEL = new InputChannelInfo(0, 0); + + @TempDir private Path temporaryFolder; @Test void testAllRecordsPassFilter() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(TARGET_CHANNEL); + FilteredBufferDispatcher outputWriter = createTestOutputWriter(store); + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + handler.filterAndRewrite(0, 0, sourceBuffer, outputWriter, TARGET_CHANNEL); + outputWriter.flush(); + outputWriter.close(); - // deserializeBuffers consumes (recycles) each buffer via the deserializer - List values = deserializeBuffers(result); + List values = drainAndDeserialize(store); assertThat(values).containsExactly(1L, 2L, 3L); } @@ -69,10 +79,16 @@ void testAllRecordsFilteredOut() throws Exception { RecordFilter rejectAll = 
record -> false; ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(rejectAll); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(TARGET_CHANNEL); + FilteredBufferDispatcher outputWriter = createTestOutputWriter(store); + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + handler.filterAndRewrite(0, 0, sourceBuffer, outputWriter, TARGET_CHANNEL); + outputWriter.flush(); + outputWriter.close(); - assertThat(result).isEmpty(); + List values = drainAndDeserialize(store); + assertThat(values).isEmpty(); } @Test @@ -80,49 +96,37 @@ void testPartialFiltering() throws Exception { RecordFilter keepEven = record -> record.getValue() % 2 == 0; ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(keepEven); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(TARGET_CHANNEL); + FilteredBufferDispatcher outputWriter = createTestOutputWriter(store); + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + handler.filterAndRewrite(0, 0, sourceBuffer, outputWriter, TARGET_CHANNEL); + outputWriter.flush(); + outputWriter.close(); - List values = deserializeBuffers(result); + List values = drainAndDeserialize(store); assertThat(values).containsExactly(2L, 4L); } - @Test - void testSmallOutputBufferProducesMultipleBuffers() throws Exception { - // Use a very small output buffer size so records must span multiple buffers - int smallBufferSize = 8; - ChannelStateFilteringHandler.GateFilterHandler handler = - createHandler(RecordFilter.acceptAll()); - - Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = - handler.filterAndRewrite( - 0, 0, sourceBuffer, () -> createEmptyBuffer(smallBufferSize)); - - // Each Long record needs 4 bytes length + ~9 bytes data > 8-byte buffer - assertThat(result.size()).isGreaterThan(1); - - List values = deserializeBuffers(result); - assertThat(values).containsExactly(1L, 2L, 3L); - } - @Test void testEmptyBuffer() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(TARGET_CHANNEL); + FilteredBufferDispatcher outputWriter = createTestOutputWriter(store); + Buffer emptyBuffer = createEmptyBuffer(); emptyBuffer.setSize(0); - List result = handler.filterAndRewrite(0, 0, emptyBuffer, this::createEmptyBuffer); + handler.filterAndRewrite(0, 0, emptyBuffer, outputWriter, TARGET_CHANNEL); + outputWriter.flush(); + outputWriter.close(); - assertThat(result).isEmpty(); + List values = drainAndDeserialize(store); + assertThat(values).isEmpty(); } - // ------------------------------------------------------------------------------------------- - // Helper methods - // ------------------------------------------------------------------------------------------- - private ChannelStateFilteringHandler.GateFilterHandler createHandler( RecordFilter filter) { RecordDeserializer> deserializer = @@ -138,6 +142,35 @@ private ChannelStateFilteringHandler.GateFilterHandler createHandler( return new ChannelStateFilteringHandler.GateFilterHandler<>(channels, serializer); } + private FilteredBufferDispatcher createTestOutputWriter(RecoveredBufferStoreImpl store) + throws IOException { + Map storesByChannel = new HashMap<>(); + storesByChannel.put(TARGET_CHANNEL, store); + + String[] spillDirs = new 
String[] {temporaryFolder.toString()}; + BufferRequester newBufferPerRequest = + new BufferRequester() { + @Override + public Buffer requestBuffer(InputChannelInfo channelInfo) { + return createEmptyBuffer(); + } + + @Override + public Buffer requestBufferBlocking(InputChannelInfo channelInfo) { + return createEmptyBuffer(); + } + + @Override + public void releaseExclusiveBuffers() {} + }; + return new FilteredBufferDispatcherImpl( + storesByChannel, + ChannelStateWriter.NO_OP, + spillDirs, + BUFFER_SIZE, + newBufferPerRequest); + } + private Buffer createBufferWithRecords(Long... values) throws IOException { StreamElementSerializer serializer = new StreamElementSerializer<>(LongSerializer.INSTANCE); @@ -150,14 +183,11 @@ private Buffer serializeRecordsToBuffer( DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); for (Long value : values) { - // Serialize using the same length-prefixed format as Flink DataOutputSerializer recordOutput = new DataOutputSerializer(64); serializer.serialize(new StreamRecord<>(value), recordOutput); int recordLength = recordOutput.length(); - // Write 4-byte big-endian length prefix output.writeInt(recordLength); - // Write record bytes output.write(recordOutput.getSharedBuffer(), 0, recordLength); } @@ -179,7 +209,8 @@ private Buffer createEmptyBuffer(int size) { return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); } - private List deserializeBuffers(List buffers) throws IOException { + /** Drains all buffers from the store and deserializes Long values. */ + private List drainAndDeserialize(RecoveredBufferStoreImpl store) throws IOException { StreamElementSerializer serializer = new StreamElementSerializer<>(LongSerializer.INSTANCE); SpillingAdaptiveSpanningRecordDeserializer> @@ -190,7 +221,8 @@ private List deserializeBuffers(List buffers) throws IOException { new NonReusingDeserializationDelegate<>(serializer); List values = new ArrayList<>(); - for (Buffer buffer : buffers) { + Buffer buffer; + while ((buffer = takeUnderLock(store)) != null) { deserializer.setNextBuffer(buffer); while (true) { RecordDeserializer.DeserializationResult result = @@ -210,4 +242,10 @@ private List deserializeBuffers(List buffers) throws IOException { } return values; } + + private static Buffer takeUnderLock(RecoveredBufferStoreImpl store) { + synchronized (store) { + return store.tryTake(); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java index 9c4aab0bc7a5d..c790ab8ec9d47 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java @@ -83,7 +83,8 @@ private InputChannelRecoveredStateHandler buildInputChannelStateHandler( .MappingType.IDENTITY) }), null, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + null); } private InputChannelRecoveredStateHandler buildMultiChannelHandler() { @@ -111,7 +112,8 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() { .MappingType.RESCALING) }), null, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + null); } /** Builds a handler in filtering mode (non-null filtering handler, no-op stub). 
*/ @@ -136,7 +138,8 @@ private InputChannelRecoveredStateHandler buildFilteringInputChannelStateHandler .MappingType.IDENTITY) }), stubFilteringHandler, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + null); } @Test @@ -298,4 +301,57 @@ void testPreFilterSegmentFreedOnClose() throws Exception { assertThat(segment.isFreed()).isTrue(); assertThat(filteringHandler.getPreFilterSegmentForTesting()).isNull(); } + + /** + * AT-FRCV (input half): finishRecovery() triggers finishReadRecoveredState() on every gate + * exactly once; close() must NOT invoke it again (close is pure resource release). + * + *
<p>
The internal {@code recoveryFinished} idempotency guard is verified via reflection. We + * also verify that the pre-filter segment — allocated during recovery — is freed by close() and + * not by finishRecovery(), confirming the clean separation of lifecycle concerns. + */ + @Test + void testFinishRecoveryTriggersConversion() throws Exception { + InputChannelRecoveredStateHandler filteringHandler = + buildFilteringInputChannelStateHandler(); + + // Allocate and recycle a pre-filter buffer so preFilterSegment is live before close(). + RecoveredChannelStateHandler.BufferWithContext buf = + filteringHandler.getBuffer(channelInfo); + buf.context.recycleBuffer(); + MemorySegment segmentBeforeFinish = filteringHandler.getPreFilterSegmentForTesting(); + assertThat(segmentBeforeFinish).isNotNull(); + + // Before finishRecovery(): recoveryFinished == false. + assertThat(getRecoveryFinishedFlag(filteringHandler)).isFalse(); + + // finishRecovery() must complete without error and flip the guard. + filteringHandler.finishRecovery(); + assertThat(getRecoveryFinishedFlag(filteringHandler)).isTrue(); + + // Idempotency: second call keeps the guard at true and does not re-invoke the gate loop. + filteringHandler.finishRecovery(); + assertThat(getRecoveryFinishedFlag(filteringHandler)).isTrue(); + + // close() is pure resource release: segment freed, guard unchanged (still true). + filteringHandler.close(); + assertThat(getRecoveryFinishedFlag(filteringHandler)) + .as("close() must not alter the recoveryFinished flag") + .isTrue(); + assertThat(filteringHandler.getPreFilterSegmentForTesting()) + .as("close() must null-out the preFilterSegment reference") + .isNull(); + assertThat(segmentBeforeFinish.isFreed()) + .as("close() must free the preFilterSegment") + .isTrue(); + } + + /** Reads the private {@code recoveryFinished} field via reflection. */ + private static boolean getRecoveryFinishedFlag(InputChannelRecoveredStateHandler handler) + throws Exception { + java.lang.reflect.Field field = + InputChannelRecoveredStateHandler.class.getDeclaredField("recoveryFinished"); + field.setAccessible(true); + return (boolean) field.get(handler); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java index c77208f3ff749..711433462dac7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java @@ -74,6 +74,16 @@ public void addInputData( } } + @Override + public void addInputDataFromSpill( + long checkpointId, CloseableIterator chunks) { + try { + chunks.close(); + } catch (Exception e) { + rethrow(e); + } + } + @Override public void addOutputData( long checkpointId, ResultSubpartitionInfo info, int startSeqNum, Buffer... 
data) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java index 91d4800e6736a..d00323bf96b2a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java @@ -106,4 +106,47 @@ void testRecycleBufferAfterRecoverWasCalled() throws Exception { assertThat(networkBufferPool.getNumberOfAvailableMemorySegments()) .isEqualTo(preAllocatedSegments); } + + /** + * AT-FRCV (output half): finishRecovery() invokes finishReadRecoveredState( + * notifyAndBlockOnCompletion) on each CheckpointedResultPartition exactly once; close() must + * NOT invoke it again (close is a no-op resource release for the output handler). + * + *
<p>
Verification strategy: the {@code recoveryFinished} idempotency guard on the handler is + * inspected via reflection. We verify it flips from false to true on the first finishRecovery() + * call, stays true on the second (idempotent), and remains true after close() — confirming that + * close() does not re-enter the finishReadRecoveredState logic. + */ + @Test + void testFinishRecoveryTriggersFinishReadRecoveredState() throws Exception { + // Before finishRecovery(): recoveryFinished == false. + assertThat(getRecoveryFinishedFlag(rstHandler)).isFalse(); + + // finishRecovery() must flip the guard. + rstHandler.finishRecovery(); + assertThat(getRecoveryFinishedFlag(rstHandler)) + .as("finishRecovery() must set recoveryFinished to true") + .isTrue(); + + // Idempotency: second call keeps the guard at true, does not re-invoke partition loop. + rstHandler.finishRecovery(); + assertThat(getRecoveryFinishedFlag(rstHandler)) + .as("second finishRecovery() must leave recoveryFinished as true") + .isTrue(); + + // close() is a no-op for the output handler: guard must not change. + rstHandler.close(); + assertThat(getRecoveryFinishedFlag(rstHandler)) + .as("close() must NOT alter the recoveryFinished flag") + .isTrue(); + } + + /** Reads the private {@code recoveryFinished} field via reflection. */ + private static boolean getRecoveryFinishedFlag(ResultSubpartitionRecoveredStateHandler handler) + throws Exception { + java.lang.reflect.Field field = + ResultSubpartitionRecoveredStateHandler.class.getDeclaredField("recoveryFinished"); + field.setAccessible(true); + return (boolean) field.get(handler); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestBufferPool.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestBufferPool.java new file mode 100644 index 0000000000000..cb53d1a7bb121 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestBufferPool.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import javax.annotation.Nullable; + +import java.util.LinkedList; +import java.util.Queue; + +/** + * Test-only {@link BufferRequester} backed by two queues: one drawn by the non-blocking fast path + * (P1 / P3) and one drawn by the blocking close() drain. Tests provide concrete queues and control + * which buffers are available on each path. + */ +final class TestBufferPool implements BufferRequester { + + private final Queue writePool; + private final Queue drainPool; + + /** Uses the same queue for both paths. 
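+ * <p>(In same-queue mode a buffer taken by the non-blocking path is no longer available to + * the drain, loosely mimicking a shared pool.)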
*/ + TestBufferPool(Queue pool) { + this(pool, pool); + } + + /** Uses separate queues for the non-blocking and blocking paths. */ + TestBufferPool(Queue writePool, Queue drainPool) { + this.writePool = writePool; + this.drainPool = drainPool; + } + + /** Shorthand for a requester that only supplies buffers to the blocking drain path. */ + static TestBufferPool drainOnly(Queue drainPool) { + return new TestBufferPool(new LinkedList<>(), drainPool); + } + + /** Shorthand for a requester whose queues are both empty (no buffer available at all). */ + static TestBufferPool empty() { + return new TestBufferPool(new LinkedList<>(), new LinkedList<>()); + } + + @Override + @Nullable + public Buffer requestBuffer(InputChannelInfo channelInfo) { + return writePool.poll(); + } + + @Override + @Nullable + public Buffer requestBufferBlocking(InputChannelInfo channelInfo) { + return drainPool.poll(); + } + + @Override + public void releaseExclusiveBuffers() { + // No-op: tests pre-supply queues; nothing to return to a global pool. + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java index d96ed78b6a031..653cad55d7494 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java @@ -46,6 +46,7 @@ import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStore; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; @@ -68,7 +69,6 @@ import java.io.IOException; import java.net.InetSocketAddress; -import java.util.ArrayDeque; import java.util.stream.Stream; import static org.apache.flink.runtime.io.network.netty.PartitionRequestQueueTest.blockChannel; @@ -953,7 +953,7 @@ private static class TestRemoteInputChannelForError extends RemoteInputChannel { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + RecoveredBufferStore.EMPTY); this.expectedMessage = expectedMessage; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java index e3cfb55e3400f..0f0cb3a6c2ac4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.io.network.partition.TestingResultPartition; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStore; import 
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider; @@ -44,7 +45,6 @@ import org.junit.jupiter.api.Test; -import java.util.ArrayDeque; import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -250,7 +250,7 @@ private static class TestRemoteInputChannelForPartitionNotFound extends RemoteIn new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + RecoveredBufferStore.EMPTY); this.latch = latch; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/ChannelStatePersisterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/ChannelStatePersisterTest.java index 6ea9bf8712cfb..0f23abd8a1653 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/ChannelStatePersisterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/ChannelStatePersisterTest.java @@ -41,46 +41,73 @@ /** {@link ChannelStatePersister} test. */ class ChannelStatePersisterTest { + /** + * Build a persister bound to a fresh {@link RecoveredBufferStore#EMPTY} so the test holds the + * same store monitor that the persister asserts on. Tests that need a non-empty store pass one + * explicitly to {@link #newPersister(ChannelStateWriter, InputChannelInfo, + * RecoveredBufferStore)}. + */ + private static ChannelStatePersister newPersister( + ChannelStateWriter writer, InputChannelInfo channelInfo) { + return newPersister(writer, channelInfo, RecoveredBufferStore.EMPTY); + } + + private static ChannelStatePersister newPersister( + ChannelStateWriter writer, InputChannelInfo channelInfo, RecoveredBufferStore store) { + return new ChannelStatePersister(writer, channelInfo, store); + } + @Test void testNewBarrierNotOverwrittenByStopPersisting() throws Exception { RecordingChannelStateWriter channelStateWriter = new RecordingChannelStateWriter(); InputChannelInfo channelInfo = new InputChannelInfo(0, 0); - ChannelStatePersister persister = - new ChannelStatePersister(channelStateWriter, channelInfo); + RecoveredBufferStore store = RecoveredBufferStore.EMPTY; + ChannelStatePersister persister = newPersister(channelStateWriter, channelInfo, store); long checkpointId = 1L; channelStateWriter.start( checkpointId, CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault())); - persister.checkForBarrier(barrier(checkpointId)); - persister.startPersisting(checkpointId, Arrays.asList(buildSomeBuffer())); + synchronized (store) { + persister.checkForBarrier(barrier(checkpointId)); + persister.startPersisting(checkpointId, Arrays.asList(buildSomeBuffer())); + } assertThat(channelStateWriter.getAddedInput().get(channelInfo)).hasSize(1); - persister.maybePersist(buildSomeBuffer()); + synchronized (store) { + persister.maybePersist(buildSomeBuffer()); + } assertThat(channelStateWriter.getAddedInput().get(channelInfo)).hasSize(1); // meanwhile, checkpoint coordinator timed out the 1st checkpoint and started the 2nd // now task thread is picking up the barrier and aborts the 1st: - persister.checkForBarrier(barrier(checkpointId + 1)); - persister.maybePersist(buildSomeBuffer()); - persister.stopPersisting(checkpointId); - persister.maybePersist(buildSomeBuffer()); + synchronized (store) { + 
persister.checkForBarrier(barrier(checkpointId + 1)); + persister.maybePersist(buildSomeBuffer()); + persister.stopPersisting(checkpointId); + persister.maybePersist(buildSomeBuffer()); + } assertThat(channelStateWriter.getAddedInput().get(channelInfo)).hasSize(1); - assertThat(persister.hasBarrierReceived()).isTrue(); + synchronized (store) { + assertThat(persister.hasBarrierReceived()).isTrue(); + } } @Test void testNewBarrierNotOverwrittenByCheckForBarrier() throws Exception { + RecoveredBufferStore store = RecoveredBufferStore.EMPTY; ChannelStatePersister persister = - new ChannelStatePersister(ChannelStateWriter.NO_OP, new InputChannelInfo(0, 0)); + newPersister(ChannelStateWriter.NO_OP, new InputChannelInfo(0, 0), store); - persister.startPersisting(1L, Collections.emptyList()); - persister.startPersisting(2L, Collections.emptyList()); + synchronized (store) { + persister.startPersisting(1L, Collections.emptyList()); + persister.startPersisting(2L, Collections.emptyList()); - assertThat(persister.checkForBarrier(barrier(1L))).isNotPresent(); + assertThat(persister.checkForBarrier(barrier(1L))).isNotPresent(); - assertThat(persister.hasBarrierReceived()).isFalse(); + assertThat(persister.hasBarrierReceived()).isFalse(); + } } @Test @@ -103,41 +130,83 @@ private void testLateBarrier( throws Exception { RecordingChannelStateWriter channelStateWriter = new RecordingChannelStateWriter(); InputChannelInfo channelInfo = new InputChannelInfo(0, 0); + RecoveredBufferStore store = RecoveredBufferStore.EMPTY; - ChannelStatePersister persister = - new ChannelStatePersister(channelStateWriter, channelInfo); + ChannelStatePersister persister = newPersister(channelStateWriter, channelInfo, store); long lateCheckpointId = 1L; long checkpointId = 2L; - if (startCheckpointOnLateBarrier) { - persister.startPersisting(lateCheckpointId, Collections.emptyList()); + synchronized (store) { + if (startCheckpointOnLateBarrier) { + persister.startPersisting(lateCheckpointId, Collections.emptyList()); + } + if (cancelCheckpointBeforeLateBarrier) { + persister.stopPersisting(lateCheckpointId); + } + persister.checkForBarrier(barrier(lateCheckpointId)); } - if (cancelCheckpointBeforeLateBarrier) { - persister.stopPersisting(lateCheckpointId); - } - persister.checkForBarrier(barrier(lateCheckpointId)); channelStateWriter.start( checkpointId, CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault())); - persister.startPersisting(checkpointId, Arrays.asList(buildSomeBuffer())); - persister.maybePersist(buildSomeBuffer()); - persister.checkForBarrier(barrier(checkpointId)); - persister.maybePersist(buildSomeBuffer()); + synchronized (store) { + persister.startPersisting(checkpointId, Arrays.asList(buildSomeBuffer())); + persister.maybePersist(buildSomeBuffer()); + persister.checkForBarrier(barrier(checkpointId)); + persister.maybePersist(buildSomeBuffer()); - assertThat(persister.hasBarrierReceived()).isTrue(); + assertThat(persister.hasBarrierReceived()).isTrue(); + } assertThat(channelStateWriter.getAddedInput().get(channelInfo)).hasSize(2); } + @Test + void testStartPersistingRejectsNonEmptyStoreAndNonEmptyKnownBuffers() throws Exception { + // Invariant: store non-empty and knownBuffers non-empty must not coexist. This is + // guaranteed by UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM=true (upstream does not replay + // output state) combined with RemoteInputChannel#getNextBuffer draining the store before + // polling receivedBuffers. Violating this means one of those assumptions broke. 
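+ // A rough sketch of the guard this test expects startPersisting to enforce (hypothetical + // shape; the test itself only asserts the IllegalStateException and its "Invariant violated" + // message): + // checkState(store.isEmpty() || knownBuffers.isEmpty(), + // "Invariant violated: recovered-buffer store and known buffers are both non-empty");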
+ RecordingChannelStateWriter channelStateWriter = new RecordingChannelStateWriter(); + InputChannelInfo channelInfo = new InputChannelInfo(0, 0); + + long checkpointId = 1L; + channelStateWriter.start( + checkpointId, CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault())); + + RecoveredBufferStoreImpl nonEmptyStore = new RecoveredBufferStoreImpl(channelInfo); + synchronized (nonEmptyStore) { + nonEmptyStore.addBuffer(buildSomeBuffer()); + } + ChannelStatePersister persister = + newPersister(channelStateWriter, channelInfo, nonEmptyStore); + assertThatThrownBy( + () -> { + synchronized (nonEmptyStore) { + persister.startPersisting( + checkpointId, Arrays.asList(buildSomeBuffer())); + } + }) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Invariant violated"); + } + @Test void testLateBarrierTriggeringCheckpoint() throws Exception { + RecoveredBufferStore store = RecoveredBufferStore.EMPTY; ChannelStatePersister persister = - new ChannelStatePersister(ChannelStateWriter.NO_OP, new InputChannelInfo(0, 0)); + newPersister(ChannelStateWriter.NO_OP, new InputChannelInfo(0, 0), store); long lateCheckpointId = 1L; long checkpointId = 2L; - persister.checkForBarrier(barrier(checkpointId)); + synchronized (store) { + persister.checkForBarrier(barrier(checkpointId)); + } assertThatThrownBy( - () -> persister.startPersisting(lateCheckpointId, Collections.emptyList())) + () -> { + synchronized (store) { + persister.startPersisting( + lateCheckpointId, Collections.emptyList()); + } + }) .isInstanceOf(CheckpointException.class); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java index 08f65d9fe7265..1a0d5f589a50d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java @@ -34,7 +34,6 @@ import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import java.net.InetSocketAddress; -import java.util.ArrayDeque; import static org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager; @@ -166,7 +165,7 @@ public LocalInputChannel buildLocalChannel(SingleInputGate inputGate) { metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), stateWriter, - new ArrayDeque<>()); + RecoveredBufferStore.EMPTY); } public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) { @@ -184,7 +183,7 @@ public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) { metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter(), stateWriter, - new ArrayDeque<>()); + RecoveredBufferStore.EMPTY); } public LocalRecoveredInputChannel buildLocalRecoveredChannel(SingleInputGate inputGate) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index 86bda9866d204..c3fdf1eebf972 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -22,6 +22,7 @@ import 
org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.disk.NoOpFileChannelManager; @@ -61,7 +62,6 @@ import org.mockito.stubbing.Answer; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -678,11 +678,17 @@ void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { new TestingResultPartitionManager(subpartitionView); final SingleInputGate inputGate = createSingleInputGate(1); - // Create 3 recovered buffers - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); + // Create 3 recovered buffers in a store + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(new InputChannelInfo(0, 0)); + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(32)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(32)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(32)); + } final LocalInputChannel localChannel = new LocalInputChannel( @@ -697,7 +703,7 @@ void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + store); inputGate.setInputChannels(localChannel); @@ -714,13 +720,17 @@ void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { } @Test - void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { - // given: LocalInputChannel with recovered buffers migrated from RecoveredInputChannel + void testGetNextBufferWithRecoveredStore() throws Exception { + // given: LocalInputChannel with recovered buffers in a store SingleInputGate inputGate = createSingleInputGate(1); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(new InputChannelInfo(0, 0)); + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(10)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(20)); + } LocalInputChannel channel = new LocalInputChannel( @@ -735,7 +745,7 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + store); inputGate.setInputChannels(channel); @@ -752,13 +762,19 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { @Test void testCheckpointStartedPersistsRecoveredBuffers() throws Exception { - // given: Local input channel with recovered buffers + // given: Local input channel with recovered buffers in a store SingleInputGate inputGate = new SingleInputGateBuilder().build(); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); - 
recoveredBuffers.add(TestBufferFactory.createBuffer(30)); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(new InputChannelInfo(0, 0)); + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(10)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(20)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(30)); + } RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); @@ -775,7 +791,7 @@ void testCheckpointStartedPersistsRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), stateWriter, - recoveredBuffers); + store); inputGate.setInputChannels(channel); @@ -793,6 +809,56 @@ void testCheckpointStartedPersistsRecoveredBuffers() throws Exception { .containsExactly(10, 20, 30); } + // Verify that checkpoint delegates to RecoveredBufferStore when present + @Test + void testCheckpointWithRecoveredStore() throws Exception { + // given: LocalInputChannel with a RecoveredBufferStore containing buffers + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(new InputChannelInfo(0, 0)); + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(10)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(20)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(30)); + } + + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + + LocalInputChannel channel = + new LocalInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + new ResultPartitionManager(), + new TaskEventDispatcher(), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + stateWriter, + store); + + inputGate.setInputChannels(channel); + + // when: Checkpoint is started + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + CheckpointBarrier barrier = new CheckpointBarrier(1L, 0L, options); + channel.checkpointStarted(barrier); + + // then: All 3 recovered buffers should be persisted via store.checkpoint() + List<Buffer> persistedBuffers = stateWriter.getAddedInput().get(channel.getChannelInfo()); + assertThat(persistedBuffers).isNotNull().hasSize(3); + assertThat(persistedBuffers.stream().mapToInt(Buffer::getSize).toArray()) + .containsExactly(10, 20, 30); + } + @Test void testPriorityEventConsumedBeforeRecoveredBuffers() throws Exception { RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); @@ -824,8 +890,10 @@ void testPriorityEventFailsFastWhenSubpartitionViewIsNull() throws Exception { // given: Local input channel with recovered buffers but NO subpartition view initialized SingleInputGate inputGate = new SingleInputGateBuilder().build(); - ArrayDeque<Buffer> recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(new InputChannelInfo(0, 0)); + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(10)); + } LocalInputChannel channel = new LocalInputChannel( @@ -840,7 +908,7 @@ void testPriorityEventFailsFastWhenSubpartitionViewIsNull() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + store); inputGate.setInputChannels(channel); // Do NOT call channel.requestSubpartitions() — subpartitionView stays null @@ -943,6 +1011,74 
@@ void testNextDataTypeCorrectedToRecoveredBufferType() throws Exception { assertThat(next.get().buffer().getSize()).isEqualTo(10); } + @Test + void testEmptyRecoveredStoreHasNoBuffers() throws Exception { + // Callers with no recovered data pass RecoveredBufferStore.EMPTY explicitly. + // EMPTY.isEmpty() == true so getBuffersInUseCount() should count 0 from the store. + // EMPTY.releaseAll() is a no-op, so releaseAllResources() must not throw. + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = + new LocalInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + new ResultPartitionManager(), + new TaskEventDispatcher(), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + ChannelStateWriter.NO_OP, + RecoveredBufferStore.EMPTY); + + inputGate.setInputChannels(channel); + + // The EMPTY store contributes 0 to the queued buffer count. + assertThat(channel.getBuffersInUseCount()).isEqualTo(0); + // releaseAllResources() must not throw (EMPTY.releaseAll() is a no-op). + channel.releaseAllResources(); + } + + @Test + void testCheckpointStartedPassesEmptyKnownBuffers() throws Exception { + // LocalInputChannel has no network inflight buffers; it always passes emptyList to + // startPersisting so that toBeConsumedBuffers are NOT snapshotted (they are ordinary + // FullyFilledBuffer splits, not channel state). + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + + LocalInputChannel channel = + new LocalInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + new ResultPartitionManager(), + new TaskEventDispatcher(), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + stateWriter, + RecoveredBufferStore.EMPTY); + + inputGate.setInputChannels(channel); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + CheckpointBarrier barrier = new CheckpointBarrier(1L, 0L, options); + + // checkpointStarted must not throw and should produce no persisted inflight data + // (knownBuffers is always emptyList for Local). + channel.checkpointStarted(barrier); + + List<Buffer> persisted = stateWriter.getAddedInput().get(channel.getChannelInfo()); + // No inflight buffers persisted (store is EMPTY, knownBuffers is emptyList). + assertThat(persisted).isNullOrEmpty(); + } + /** * Creates a LocalInputChannel with recovered buffers and a live subpartition, ready for * priority event tests. The channel has already called requestSubpartitions(). 
@@ -962,9 +1098,11 @@ private static ChannelAndSubpartition createChannelWithRecoveredBuffers( TestingResultPartitionManager partitionManager = new TestingResultPartitionManager(subpartitionView); - ArrayDeque<Buffer> recoveredBuffers = new ArrayDeque<>(); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(new InputChannelInfo(0, 0)); for (int size : recoveredBufferSizes) { - recoveredBuffers.add(TestBufferFactory.createBuffer(size)); + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(size)); + } } LocalInputChannel channel = @@ -980,7 +1118,7 @@ private static ChannelAndSubpartition createChannelWithRecoveredBuffers( new SimpleCounter(), new SimpleCounter(), stateWriter, - recoveredBuffers); + store); inputGate.setInputChannels(channel); channel.requestSubpartitions(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStoreTest.java new file mode 100644 index 0000000000000..09a0624f72e27 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredBufferStoreTest.java @@ -0,0 +1,691 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.checkpoint.channel.EntryPosition; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveredBufferStoreCoordinator; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link RecoveredBufferStoreImpl}. */ +class RecoveredBufferStoreTest { + + private static final InputChannelInfo DEFAULT_CHANNEL_INFO = new InputChannelInfo(0, 0); + + /** addBuffer / tryTake lifecycle. */ + @Test + void testStoreLifecycle() { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + + // Query methods require holding the store monitor (locking contract). 
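+ // size() is deliberately exempt from this contract (see testGuardedMethodsAssertHoldsLock + // below), so every query used here runs inside a synchronized (store) block.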
+ synchronized (store) { + assertThat(store.isEmpty()).isTrue(); + assertThat(store.size()).isEqualTo(0); + assertThat(store.peekNextDataType()).isEqualTo(Buffer.DataType.NONE); + } + + NetworkBuffer buffer1 = createBuffer(new byte[] {1, 2, 3, 4}); + synchronized (store) { + store.addBuffer(buffer1); + } + + Buffer taken; + synchronized (store) { + assertThat(store.isEmpty()).isFalse(); + assertThat(store.size()).isEqualTo(1); + assertThat(store.peekNextDataType()).isEqualTo(Buffer.DataType.DATA_BUFFER); + + taken = store.tryTake(); + assertThat(taken).isNotNull(); + assertThat(store.isEmpty()).isTrue(); + assertThat(store.size()).isEqualTo(0); + } + taken.recycleBuffer(); + + synchronized (store) { + assertThat(store.tryTake()).isNull(); + } + } + + /** + * Checkpoint with ready buffers. Ready buffers should be retained and passed to the + * ChannelStateWriter. + */ + @Test + void testCheckpointWithReadyBuffers() throws Exception { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + + byte[] data = new byte[] {10, 20, 30, 40}; + NetworkBuffer buffer = createBuffer(data); + synchronized (store) { + store.addBuffer(buffer); + } + + RecordingChannelStateWriter writer = new RecordingChannelStateWriter(); + long checkpointId = 1L; + writer.start(checkpointId, null); + + store.checkpoint(writer, checkpointId); + + assertThat(writer.getAddedInput().get(DEFAULT_CHANNEL_INFO)).hasSize(1); + + // The original buffer should still be in the store (retained, not consumed) + synchronized (store) { + assertThat(store.size()).isEqualTo(1); + } + + // Clean up: recycle the buffer recorded by writer + writer.getAddedInput().get(DEFAULT_CHANNEL_INFO).forEach(Buffer::recycleBuffer); + store.releaseAll(); + } + + /** Concurrent access from two threads: one adds buffers while the other takes them. */ + @Test + void testConcurrentCheckpointAndReplay() throws Exception { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + int numBuffers = 100; + CyclicBarrier barrier = new CyclicBarrier(2); + AtomicReference<Throwable> error = new AtomicReference<>(); + + // Producer thread: adds buffers + Thread producer = + new Thread( + () -> { + try { + barrier.await(); + for (int i = 0; i < numBuffers; i++) { + NetworkBuffer buf = createBuffer(new byte[] {(byte) i}); + synchronized (store) { + store.addBuffer(buf); + } + } + } catch (Throwable t) { + error.set(t); + } + }); + + // Consumer thread: takes buffers + CountDownLatch consumedAll = new CountDownLatch(1); + Thread consumer = + new Thread( + () -> { + try { + barrier.await(); + int consumed = 0; + while (consumed < numBuffers) { + Buffer buf; + synchronized (store) { + buf = store.tryTake(); + } + if (buf != null) { + buf.recycleBuffer(); + consumed++; + } + } + consumedAll.countDown(); + } catch (Throwable t) { + error.set(t); + } + }); + + producer.start(); + consumer.start(); + producer.join(10_000); + consumer.join(10_000); + + assertThat(error.get()).isNull(); + synchronized (store) { + assertThat(store.isEmpty()).isTrue(); + } + } + + /** + * Simulate store transfer by adding buffers, then taking them in another "context" (simulating + * conversion). Continue consuming after conversion. 
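+ * In production this handover happens in RecoveredInputChannel#toInputChannelInternal, which + * passes the same store instance on to the post-recovery channel; the test only emulates it.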
+ */ + @Test + void testConsumptionAfterConversion() { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + + // Add buffers in "recovery" phase + NetworkBuffer buf1 = createBuffer(new byte[] {1, 2}); + NetworkBuffer buf2 = createBuffer(new byte[] {3, 4}); + NetworkBuffer buf3 = createBuffer(new byte[] {5, 6}); + synchronized (store) { + store.addBuffer(buf1); + } + synchronized (store) { + store.addBuffer(buf2); + } + synchronized (store) { + store.addBuffer(buf3); + } + + // Simulate partial consumption before conversion + Buffer taken1; + synchronized (store) { + taken1 = store.tryTake(); + } + assertThat(taken1).isNotNull(); + taken1.recycleBuffer(); + + // After conversion, continue consuming remaining buffers + Buffer taken2; + Buffer taken3; + synchronized (store) { + taken2 = store.tryTake(); + assertThat(taken2).isNotNull(); + taken3 = store.tryTake(); + assertThat(taken3).isNotNull(); + } + taken2.recycleBuffer(); + taken3.recycleBuffer(); + + synchronized (store) { + assertThat(store.isEmpty()).isTrue(); + assertThat(store.tryTake()).isNull(); + } + } + + /** Verify releaseAll recycles all buffers and clears state. */ + @Test + void testReleaseAll() { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + + NetworkBuffer buf1 = createBuffer(new byte[] {1}); + NetworkBuffer buf2 = createBuffer(new byte[] {2}); + synchronized (store) { + store.addBuffer(buf1); + } + synchronized (store) { + store.addBuffer(buf2); + } + + store.releaseAll(); + + assertThat(buf1.isRecycled()).isTrue(); + assertThat(buf2.isRecycled()).isTrue(); + synchronized (store) { + assertThat(store.isEmpty()).isTrue(); + assertThat(store.size()).isEqualTo(0); + } + } + + /** Verify releaseAll notifies the registered coordinator with the bound channel info. */ + @Test + void testReleaseAllNotifiesCoordinator() { + InputChannelInfo channelInfo = new InputChannelInfo(3, 7); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(channelInfo); + + RecordingCoordinator coordinator = new RecordingCoordinator(); + synchronized (store) { + store.setCoordinator(coordinator); + } + + // Add some in-memory and on-disk bookkeeping to make the release meaningful. + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {1})); + } + synchronized (store) { + store.incrementPending(); + } + + store.releaseAll(); + + assertThat(coordinator.released).containsExactly(channelInfo); + synchronized (store) { + assertThat(store.isEmpty()).isTrue(); + assertThat(store.size()).isEqualTo(0); + } + } + + /** Verify data-available listener fires when buffer is added to empty store. 
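+ * The listener is edge-triggered: only the empty-to-non-empty transition fires it, so one + * notification covers a whole refill rather than every individual buffer.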
*/ + @Test + void testDataAvailableListener() { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + int[] callbackCount = {0}; + synchronized (store) { + store.setDataAvailableListener(() -> callbackCount[0]++); + } + + // Add first buffer: should trigger listener (store was empty) + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {1})); + } + assertThat(callbackCount[0]).isEqualTo(1); + + // Add second buffer: should NOT trigger listener (store was not empty) + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {2})); + } + assertThat(callbackCount[0]).isEqualTo(1); + + // Drain both buffers + synchronized (store) { + store.tryTake().recycleBuffer(); + store.tryTake().recycleBuffer(); + } + + // Add buffer again to empty store: should trigger listener + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {3})); + } + assertThat(callbackCount[0]).isEqualTo(2); + + store.releaseAll(); + } + + /** Verify pending spill entry count tracking. */ + @Test + void testPendingCount() { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + + synchronized (store) { + store.incrementPending(); + + // Store not empty when pending entries exist + assertThat(store.isEmpty()).isFalse(); + // size() reports ready + pending so the channel-level backlog reflects on-disk data too + assertThat(store.size()).isEqualTo(1); + } + + // Drain the pending entry by handing back a buffer; addBuffer folds in the matching + // pending decrement. + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {1})); + } + synchronized (store) { + // pending consumed, buffer became ready — still size 1 but now in readyBuffers + assertThat(store.size()).isEqualTo(1); + store.tryTake().recycleBuffer(); + assertThat(store.isEmpty()).isTrue(); + assertThat(store.size()).isEqualTo(0); + } + } + + /** Verify size() aggregates ready buffers and pending on-disk entries. */ + @Test + void testSizeAggregatesReadyAndPending() { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {1})); + } + synchronized (store) { + store.incrementPending(); + store.incrementPending(); + + assertThat(store.size()).isEqualTo(3); + + store.tryTake().recycleBuffer(); + assertThat(store.size()).isEqualTo(2); + } + + // Drain both pending entries by handing back buffers; each addBuffer consumes one pending. + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {2})); + } + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {3})); + } + synchronized (store) { + assertThat(store.size()).isEqualTo(2); + } + + store.releaseAll(); + } + + /** + * Verify that the coordinator registered via setCoordinator receives onChannelCheckpointStarted + * during checkpoint() after snapshotting ready buffers, with the correct checkpointId and + * channelInfo. 
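+ * The snapshot-then-notify ordering matters: the startPos argument is presumably how the + * coordinator learns where the still-spilled portion begins once the ready buffers are captured.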
+ */ + @Test + void testCheckpointNotifiesCoordinatorAfterSnapshot() throws Exception { + InputChannelInfo channelInfo = new InputChannelInfo(0, 0); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(channelInfo); + + RecordingCoordinator coordinator = new RecordingCoordinator(); + synchronized (store) { + store.setCoordinator(coordinator); + } + + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {1, 2})); + } + + RecordingChannelStateWriter writer = new RecordingChannelStateWriter(); + long checkpointId = 42L; + writer.start(checkpointId, null); + + store.checkpoint(writer, checkpointId); + + // Coordinator must have been notified exactly once with correct args + assertThat(coordinator.checkpointIds).containsExactly(42L); + assertThat(coordinator.checkpointChannels).containsExactly(channelInfo); + + // Writer received the ready buffer before notification fired (snapshot happened first) + assertThat(writer.getAddedInput().get(channelInfo)).hasSize(1); + + writer.getAddedInput().get(channelInfo).forEach(Buffer::recycleBuffer); + store.releaseAll(); + } + + /** Verify checkpoint() without any ready buffers still notifies the coordinator. */ + @Test + void testCheckpointNotifiesCoordinatorEvenWhenNoReadyBuffers() throws Exception { + InputChannelInfo channelInfo = new InputChannelInfo(1, 2); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(channelInfo); + + RecordingCoordinator coordinator = new RecordingCoordinator(); + synchronized (store) { + store.setCoordinator(coordinator); + } + + RecordingChannelStateWriter writer = new RecordingChannelStateWriter(); + writer.start(1L, null); + store.checkpoint(writer, 1L); + + assertThat(coordinator.checkpointIds).containsExactly(1L); + } + + /** Verify no notification is fired if setCoordinator was never called. */ + @Test + void testCheckpointWithNoCoordinatorSetDoesNotThrow() throws Exception { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {1})); + } + + RecordingChannelStateWriter writer = new RecordingChannelStateWriter(); + writer.start(1L, null); + // Should not throw even without a coordinator registered + store.checkpoint(writer, 1L); + + writer.getAddedInput().get(DEFAULT_CHANNEL_INFO).forEach(Buffer::recycleBuffer); + store.releaseAll(); + } + + /** + * Verify notifyCheckpointStopped forwards the call to the registered coordinator with the bound + * channel info. + */ + @Test + void testNotifyCheckpointStoppedNotifiesCoordinator() { + InputChannelInfo channelInfo = new InputChannelInfo(2, 5); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(channelInfo); + + RecordingCoordinator coordinator = new RecordingCoordinator(); + synchronized (store) { + store.setCoordinator(coordinator); + } + + store.notifyCheckpointStopped(11L); + store.notifyCheckpointStopped(12L); + + assertThat(coordinator.stoppedCheckpointIds).containsExactly(11L, 12L); + assertThat(coordinator.stoppedChannels).containsExactly(channelInfo, channelInfo); + } + + /** Verify notifyCheckpointStopped is a safe no-op when no coordinator is registered. 
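+ * A store wired up on a recovery path without spilled state may never receive a coordinator, + * so the forwarding has to tolerate its absence.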
*/ + @Test + void testNotifyCheckpointStoppedWithoutCoordinatorIsNoOp() { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + // Should not throw without a coordinator registered + store.notifyCheckpointStopped(7L); + } + + /** + * Verify setDataAvailableListener can be called through the RecoveredBufferStore interface + * without instanceof casts. + */ + @Test + void testSetDataAvailableListenerViaInterface() { + RecoveredBufferStore store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + int[] callCount = {0}; + // Must compile and run without instanceof check + synchronized (store) { + store.setDataAvailableListener(() -> callCount[0]++); + } + + synchronized (store) { + ((RecoveredBufferStoreImpl) store).addBuffer(createBuffer(new byte[] {1})); + } + assertThat(callCount[0]).isEqualTo(1); + + store.releaseAll(); + } + + /** Verify all methods of EMPTY return expected no-op / sentinel values. */ + @Test + void testEmptySingletonBehavior() throws Exception { + RecoveredBufferStore empty = RecoveredBufferStore.EMPTY; + + assertThat(empty.tryTake()).isNull(); + assertThat(empty.peekNextDataType()).isEqualTo(Buffer.DataType.NONE); + assertThat(empty.isEmpty()).isTrue(); + assertThat(empty.size()).isEqualTo(0); + } + + /** Verify checkpoint() on EMPTY is a no-op and does not write any channel state. */ + @Test + void testEmptySingletonCheckpointIsNoOp() throws Exception { + RecoveredBufferStore empty = RecoveredBufferStore.EMPTY; + + RecordingChannelStateWriter writer = new RecordingChannelStateWriter(); + writer.start(1L, null); + empty.checkpoint(writer, 1L); + + // No data must have been written + assertThat(writer.getAddedInput().isEmpty()).isTrue(); + } + + /** Verify releaseAll() on EMPTY does not throw. */ + @Test + void testEmptySingletonReleaseAllIsNoOp() { + RecoveredBufferStore.EMPTY.releaseAll(); + } + + /** Verify notifyCheckpointStopped() on EMPTY does not throw. */ + @Test + void testEmptySingletonNotifyCheckpointStoppedIsNoOp() { + RecoveredBufferStore.EMPTY.notifyCheckpointStopped(99L); + } + + /** Verify all setters on EMPTY are no-ops (accept and discard without throwing). */ + @Test + void testEmptySingletonSettersAreNoOp() { + RecoveredBufferStore empty = RecoveredBufferStore.EMPTY; + + empty.setCoordinator(new RecordingCoordinator()); + empty.setDataAvailableListener(() -> {}); + // No exception == pass + } + + /** + * Calling a {@code @GuardedBy("this")} method without holding the store monitor must trip the + * {@code assert Thread.holdsLock(this)} guard under {@code -ea}. This locks the contract in: + * future refactors that accidentally drop the synchronized wrapper at a call site will fail + * loudly in tests instead of silently producing torn reads. + */ + @Test + void testGuardedMethodsAssertHoldsLock() { + // The AssertionError surfaces only with assertions enabled; flink test JVMs run with -ea + // by default. Skip the test cleanly if for some reason this JVM was started without -ea + // so the suite does not turn red on a JVM configuration issue. 
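+ // Checking the test class's assertion status is a proxy: it assumes the store class runs + // under the same JVM-wide -ea setting, which holds for a standard surefire configuration.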
+ if (!RecoveredBufferStoreTest.class.desiredAssertionStatus()) { + return; + } + + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + try { + assertThat(catchAssertion(() -> store.tryTake())).isTrue(); + assertThat(catchAssertion(() -> store.peekNextDataType())).isTrue(); + assertThat(catchAssertion(() -> store.isEmpty())).isTrue(); + assertThat(catchAssertion(() -> store.incrementPending())).isTrue(); + assertThat(catchAssertion(() -> store.setCoordinator(new RecordingCoordinator()))) + .isTrue(); + assertThat(catchAssertion(() -> store.setDataAvailableListener(() -> {}))).isTrue(); + // size() is the deliberate exception: lock-free for metric / gate-bookkeeping paths. + // Calling it without holding the monitor must NOT trip the assertion guard. + assertThat(catchAssertion(() -> store.size())).isFalse(); + } finally { + store.releaseAll(); + } + } + + /** + * Concurrent drain test: when the producer keeps appending and the consumer keeps polling, each + * {@code tryTake + peekNextDataType} pair observed by the consumer must be self-consistent: if + * {@code peekNextDataType()} returns {@code NONE} after a successful tryTake it must mean + * the next tryTake on the same thread also returns null (modulo any further producer activity + * that happened strictly after the peek), and if it returns a non-NONE type the next tryTake + * must return a buffer with that type. The test guards against future regressions where someone + * splits the take/peek pair across two store-lock acquisitions. + */ + @Test + void testTryTakePeekPairAtomicUnderConcurrency() throws Exception { + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(DEFAULT_CHANNEL_INFO); + int numBuffers = 500; + CyclicBarrier startBarrier = new CyclicBarrier(2); + AtomicReference<Throwable> error = new AtomicReference<>(); + + Thread producer = + new Thread( + () -> { + try { + startBarrier.await(); + for (int i = 0; i < numBuffers; i++) { + synchronized (store) { + store.addBuffer(createBuffer(new byte[] {(byte) i})); + } + } + } catch (Throwable t) { + error.set(t); + } + }, + "atomic-pair-producer"); + + Thread consumer = + new Thread( + () -> { + try { + startBarrier.await(); + int consumed = 0; + while (consumed < numBuffers) { + Buffer taken; + Buffer.DataType peekedType; + synchronized (store) { + taken = store.tryTake(); + peekedType = store.peekNextDataType(); + } + if (taken == null) { + // peeked type with no taken buffer must be NONE + assertThat(peekedType).isEqualTo(Buffer.DataType.NONE); + continue; + } + taken.recycleBuffer(); + consumed++; + } + } catch (Throwable t) { + error.set(t); + } + }, + "atomic-pair-consumer"); + + producer.start(); + consumer.start(); + producer.join(10_000); + consumer.join(10_000); + + assertThat(error.get()).isNull(); + synchronized (store) { + assertThat(store.isEmpty()).isTrue(); + } + } + + private static boolean catchAssertion(Runnable r) { + try { + r.run(); + return false; + } catch (AssertionError ae) { + return true; + } + } + + /** + * Test-only coordinator that records all coordinator notifications: started, stopped, and + * released. 
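+ * The recording lists are only read on the test thread after the store calls have returned, + * so plain ArrayLists are sufficient here.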
+ */ + private static class RecordingCoordinator implements RecoveredBufferStoreCoordinator { + final List<Long> checkpointIds = new ArrayList<>(); + final List<InputChannelInfo> checkpointChannels = new ArrayList<>(); + final List<EntryPosition> checkpointStartPositions = new ArrayList<>(); + final List<Long> stoppedCheckpointIds = new ArrayList<>(); + final List<InputChannelInfo> stoppedChannels = new ArrayList<>(); + final List<InputChannelInfo> released = new ArrayList<>(); + volatile EntryPosition currentDrainHead = EntryPosition.END; + + @Override + public EntryPosition getCurrentDrainHead() { + return currentDrainHead; + } + + @Override + public void onChannelCheckpointStarted( + long checkpointId, InputChannelInfo channelInfo, EntryPosition startPos) { + checkpointIds.add(checkpointId); + checkpointChannels.add(channelInfo); + checkpointStartPositions.add(startPos); + } + + @Override + public void onChannelCheckpointStopped(long checkpointId, InputChannelInfo channelInfo) { + stoppedCheckpointIds.add(checkpointId); + stoppedChannels.add(channelInfo); + } + + @Override + public void onChannelReleased(InputChannelInfo channelInfo) { + released.add(channelInfo); + } + } + + private static NetworkBuffer createBuffer(byte[] data) { + org.apache.flink.core.memory.MemorySegment segment = + MemorySegmentFactory.allocateUnpooledSegment(data.length); + segment.put(0, data, 0, data.length); + NetworkBuffer buffer = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + buffer.setSize(data.length); + return buffer; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java index f40fd09702ede..5b1ead7e77fb7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java @@ -21,15 +21,22 @@ import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointType; +import org.apache.flink.runtime.io.network.NettyShuffleEnvironment; +import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; +import org.apache.flink.runtime.memory.MemoryManager; import org.junit.jupiter.api.Test; import java.io.IOException; -import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import static org.apache.flink.runtime.checkpoint.CheckpointOptions.unaligned; import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault; @@ -81,7 +88,9 @@ void testToInputChannelAllowedWhenBufferFilteringCompleteAndConfigEnabled() thro .hasMessageContaining("buffer filtering is not complete"); // After finishReadRecoveredState(), bufferFilteringCompleteFuture should be done - channel.finishReadRecoveredState(); + synchronized (channel.inputGate.getGateLock()) { + channel.finishReadRecoveredState(); + } 
assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); assertThat(channel.getStateConsumedFuture()).isNotDone(); @@ -103,7 +112,9 @@ void testToInputChannelAllowedWhenStateConsumedAndConfigDisabled() throws IOExce // After finishReadRecoveredState(), bufferFilteringCompleteFuture is done // but stateConsumedFuture is not - channel.finishReadRecoveredState(); + synchronized (channel.inputGate.getGateLock()) { + channel.finishReadRecoveredState(); + } assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); assertThat(channel.getStateConsumedFuture()).isNotDone(); @@ -113,7 +124,9 @@ void testToInputChannelAllowedWhenStateConsumedAndConfigDisabled() throws IOExce .hasMessageContaining("recovered state is not fully consumed"); // Consume the EndOfInputChannelStateEvent to complete stateConsumedFuture - assertThat(channel.getNextBuffer()).isNotPresent(); + synchronized (channel.inputGate.getGateLock()) { + assertThat(channel.getNextBuffer()).isNotPresent(); + } assertThat(channel.getStateConsumedFuture()).isDone(); // Now conversion should succeed @@ -127,7 +140,9 @@ void testBufferFilteringCompleteFutureAlwaysCompletes() throws IOException { for (boolean configEnabled : new boolean[] {true, false}) { RecoveredInputChannel channel = buildChannel(configEnabled); assertThat(channel.getBufferFilteringCompleteFuture()).isNotDone(); - channel.finishReadRecoveredState(); + synchronized (channel.inputGate.getGateLock()) { + channel.finishReadRecoveredState(); + } assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); } } @@ -141,16 +156,97 @@ void testStateConsumedFutureCompletesAfterConsumingAllBuffers() throws IOExcepti assertThat(channel.getStateConsumedFuture()).isNotDone(); - channel.finishReadRecoveredState(); + synchronized (channel.inputGate.getGateLock()) { + channel.finishReadRecoveredState(); + } assertThat(channel.getStateConsumedFuture()).isNotDone(); // Consuming the EndOfInputChannelStateEvent should complete the future. // getNextBuffer() returns empty when it encounters the event internally. - assertThat(channel.getNextBuffer()).isNotPresent(); + synchronized (channel.inputGate.getGateLock()) { + assertThat(channel.getNextBuffer()).isNotPresent(); + } assertThat(channel.getStateConsumedFuture()).isDone(); } } + @Test + void testRequestBufferNonBlockingAndBlockingHasNoHeapFallback() throws Exception { + int numBuffers = 3; + NettyShuffleEnvironment environment = + new NettyShuffleEnvironmentBuilder() + .setNumNetworkBuffers(numBuffers) + .setBufferSize(MemoryManager.DEFAULT_PAGE_SIZE) + .build(); + try { + SingleInputGate filteringGate = + new SingleInputGateBuilder() + .setChannelFactory(InputChannelBuilder::buildLocalRecoveredChannel) + .setupBufferPoolFactory(environment) + .setCheckpointingDuringRecoveryEnabled(true) + .build(); + filteringGate.setup(); + + RecoveredInputChannel channel = (RecoveredInputChannel) filteringGate.getChannel(0); + + // requestBuffer() is non-blocking: once the exclusive buffers are drained, it + // returns null instead of waiting. + List<Buffer> allBuffers = new ArrayList<>(); + while (true) { + Buffer b = channel.requestBuffer(); + if (b == null) { + break; + } + allBuffers.add(b); + } + assertThat(channel.requestBuffer()).isNull(); + + // Also drain the gate's floating buffer pool so requestBufferBlocking() has nothing + // left and is forced to block. 
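+ // Every buffer drained here and below is tracked in allBuffers and recycled before the + // environment is closed in the finally block, so no segment outlives the test.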
+ BufferPool bufferPool = filteringGate.getBufferPool(); + while (true) { + Buffer b = bufferPool.requestBuffer(); + if (b == null) { + break; + } + allBuffers.add(b); + } + + // requestBufferBlocking() must block (not fall back to heap) when the pool is empty. + CompletableFuture<Buffer> blockingFuture = new CompletableFuture<>(); + Thread blockingThread = + new Thread( + () -> { + try { + blockingFuture.complete(channel.requestBufferBlocking()); + } catch (Exception e) { + blockingFuture.completeExceptionally(e); + } + }); + blockingThread.start(); + + long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (blockingThread.getState() != Thread.State.WAITING + && System.nanoTime() < deadline) { + Thread.onSpinWait(); + } + assertThat(blockingFuture.isDone()).isFalse(); + + // Recycle one buffer; blocking thread then gets a pool buffer. + allBuffers.remove(0).recycleBuffer(); + Buffer poolBuffer = blockingFuture.get(5, TimeUnit.SECONDS); + assertThat(poolBuffer).isNotNull(); + poolBuffer.recycleBuffer(); + + for (Buffer b : allBuffers) { + b.recycleBuffer(); + } + blockingThread.join(5000); + filteringGate.close(); + } finally { + environment.close(); + } + } + private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEnabled) { try { SingleInputGate inputGate = @@ -169,7 +265,8 @@ private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEn new SimpleCounter(), 10) { @Override - protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers) { + protected InputChannel toInputChannelInternal( + RecoveredBufferStoreImpl recoveredStore) { throw new AssertionError("channel conversion succeeded"); } }; @@ -210,7 +307,7 @@ private static class TestableRecoveredInputChannel extends RecoveredInputChannel } @Override - protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers) { + protected InputChannel toInputChannelInternal(RecoveredBufferStoreImpl recoveredStore) { return new TestInputChannel(inputGate, 0); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index e47de93c9e8bd..8164b5c79bc89 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -162,6 +162,95 @@ void testGateNotifiedOnBarrierConversion() throws IOException, InterruptedExcept } } + @Test + void testPriorityFlagSetUnderLockOnPriorityEnqueue() throws Exception { + // Producer-side invariant: when onBuffer enqueues a priority element under the + // receivedBuffers lock, hasPendingPriorityEvent is also set true *inside* that same + // synchronized block. Externally observable consequence: the flag is already true the + // moment onBuffer returns to its caller (network thread), with no window in which the + // priority element is queued but the flag is still false. 
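+ // The flag is read back reflectively on this thread right after onBuffer returns, which + // is exactly the externally observable window described above.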
+ SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannels(inputChannel); + inputChannel.requestSubpartitions(); + + Buffer priority = + toBuffer( + new CheckpointBarrier( + CHECKPOINT_ID, + System.currentTimeMillis(), + CheckpointOptions.unaligned(CHECKPOINT, getDefault())), + true); + inputChannel.onBuffer(priority, 0, -1, 0); + + assertThat(getHasPendingPriorityEvent(inputChannel)).isTrue(); + + inputChannel.releaseAllResources(); + } + + @Test + void testNormalPathPollClearsPriorityFlagInvariant() throws Exception { + // Consumer-side invariant: if a normal-path getNextBuffer drains the last priority + // element via PrioritizedDeque.poll() (which can happen when a stale `false` read of + // the flag routes the consumer through the normal path), the same synchronized block + // resets the flag so a subsequent producer flag write cannot leave the channel in a + // `flag=true && numPriorityElements==0` state. + // + // We simulate the "stale read sent the consumer down the normal path" outcome by + // manually clearing the flag (post-onBuffer) before calling getNextBuffer, so the + // consumer's `if (hasPendingPriorityEvent)` check sees false and falls through to the + // normal poll. The point under test is what getNextBuffer does under its receivedBuffers + // lock, not the upstream stale-read window itself. + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannels(inputChannel); + inputChannel.requestSubpartitions(); + + Buffer priority = + toBuffer( + new CheckpointBarrier( + CHECKPOINT_ID, + System.currentTimeMillis(), + CheckpointOptions.unaligned(CHECKPOINT, getDefault())), + true); + inputChannel.onBuffer(priority, 0, -1, 0); + // Force the consumer into the normal path even though a priority element is queued. + setHasPendingPriorityEvent(inputChannel, false); + + Optional<BufferAndAvailability> first = inputChannel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getDataType().hasPriority()).isTrue(); + assertThat(getHasPendingPriorityEvent(inputChannel)).isFalse(); + + // Subsequent DATA must not trip the priority invariant: with the flag correctly cleared + // by the previous normal-path poll, the consumer takes the normal path again. 
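+ // A plain DATA_BUFFER is enough here; the invariant under test concerns the flag and + // queue bookkeeping, not the payload type.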
+ Buffer dataBuffer = createBuffer(TestBufferFactory.BUFFER_SIZE); + inputChannel.onBuffer(dataBuffer, 1, -1, 0); + + Optional<BufferAndAvailability> second = inputChannel.getNextBuffer(); + assertThat(second).isPresent(); + assertThat(second.get().buffer().getDataType()).isEqualTo(DataType.DATA_BUFFER); + assertThat(getHasPendingPriorityEvent(inputChannel)).isFalse(); + + inputChannel.releaseAllResources(); + } + + private static void setHasPendingPriorityEvent(RemoteInputChannel channel, boolean value) + throws ReflectiveOperationException { + java.lang.reflect.Field f = + RemoteInputChannel.class.getDeclaredField("hasPendingPriorityEvent"); + f.setAccessible(true); + f.setBoolean(channel, value); + } + + private static boolean getHasPendingPriorityEvent(RemoteInputChannel channel) + throws ReflectiveOperationException { + java.lang.reflect.Field f = + RemoteInputChannel.class.getDeclaredField("hasPendingPriorityEvent"); + f.setAccessible(true); + return f.getBoolean(channel); + } + @Test void testExceptionOnReordering() throws Exception { // Setup @@ -2075,13 +2164,52 @@ void verifyResult( } @Test - void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { - // given: RemoteInputChannel with recovered buffers migrated from RecoveredInputChannel + void testNullRecoveredStoreDefaultsToEmpty() throws Exception { + // When no recovered data is passed (null), the constructor must substitute EMPTY. + SingleInputGate inputGate = createSingleInputGate(1); + ConnectionID connectionId = + new ConnectionID( + org.apache.flink.runtime.clusterframework.types.ResourceID.generate(), + new java.net.InetSocketAddress("localhost", 0), + 0); + RemoteInputChannel channel = + new RemoteInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + connectionId, + InputChannelTestUtils.mockConnectionManagerWithPartitionRequestClient( + mock(PartitionRequestClient.class)), + 0, + 0, + 0, + 2 /* initialCredit */, + new SimpleCounter(), + new SimpleCounter(), + ChannelStateWriter.NO_OP, + null /* recovered store; the constructor defaults this to EMPTY */); + + inputGate.setInputChannels(channel); + + assertThat(channel.getInitialCredit()).isEqualTo(2); + // releaseAllResources() must not throw (EMPTY.releaseAll() is a no-op). 
+ channel.releaseAllResources(); + } + + @Test + void testGetNextBufferWithRecoveredStore() throws Exception { + // given: RemoteInputChannel with recovered buffers in a store SingleInputGate inputGate = createSingleInputGate(1); - ArrayDeque<Buffer> recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(new InputChannelInfo(0, 0)); + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(10)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(20)); + } ConnectionID connectionId = new ConnectionID( @@ -2104,7 +2232,7 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + store); inputGate.setInputChannels(channel); @@ -2119,6 +2247,81 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { assertThat(second.get().buffer().getSize()).isEqualTo(20); } + @Test + void testNextDataTypeReflectsReceivedBuffersWhenRecoveredStoreExhausted() throws Exception { + // When the very last tryTake empties the recovered store but receivedBuffers + // already has a buffer queued by onBuffer (this happens in production when the + // channel was already in inputChannelsWithData with the bit set, so the + // notifyChannelNonEmpty triggered by onBuffer is short-circuited by + // alreadyEnqueued), getNextBuffer must surface the receivedBuffers head as the + // next data type so moreAvailable() == true and the gate keeps the channel + // enqueued. Otherwise the queued buffer becomes invisible to the gate. + SingleInputGate inputGate = createSingleInputGate(1); + + RecoveredBufferStoreImpl store = new RecoveredBufferStoreImpl(new InputChannelInfo(0, 0)); + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(10)); + } + synchronized (store) { + store.addBuffer(TestBufferFactory.createBuffer(20)); + } + + ConnectionID connectionId = + new ConnectionID( + org.apache.flink.runtime.clusterframework.types.ResourceID.generate(), + new java.net.InetSocketAddress("localhost", 0), + 0); + RemoteInputChannel channel = + new RemoteInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + connectionId, + InputChannelTestUtils.mockConnectionManagerWithPartitionRequestClient( + mock(PartitionRequestClient.class)), + 0, + 0, + 0, + 2, + new SimpleCounter(), + new SimpleCounter(), + ChannelStateWriter.NO_OP, + store); + + inputGate.setInputChannels(channel); + channel.requestSubpartitions(); + + // First take: store still has one more, moreAvailable=true (purely from store). + Optional<BufferAndAvailability> first = channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().moreAvailable()).isTrue(); + first.get().buffer().recycleBuffer(); + + // Producer-side message arrives via the network thread and lands in + // receivedBuffers. The data type is irrelevant for this assertion (production + // hits this with RECOVERY_COMPLETION but any non-priority buffer reproduces it). + Buffer received = TestBufferFactory.createBuffer(30); + channel.onBuffer(received, 0, 0, 0); + + // Last take from the recovered store. Without the fix the next-data-type peek + // only consults the recovered store and returns NONE — the gate then sees + // moreAvailable=false and never re-enqueues the channel. 
+ Optional<BufferAndAvailability> second = channel.getNextBuffer(); + assertThat(second).isPresent(); + assertThat(second.get().moreAvailable()) + .as( + "moreAvailable must reflect receivedBuffers when the recovered store is exhausted") + .isTrue(); + second.get().buffer().recycleBuffer(); + + // Sanity: the buffer queued via onBuffer is still consumable from the post-recovery path. + Optional<BufferAndAvailability> third = channel.getNextBuffer(); + assertThat(third).isPresent(); + assertThat(third.get().buffer().getSize()).isEqualTo(30); + third.get().buffer().recycleBuffer(); + } + private static final class TestBufferPool extends NoOpBufferPool { @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java index 1ed1a42a66ea0..1aff6d8c7f0a1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -300,12 +300,16 @@ void testBufferFilteringCompleteFutureAggregation() throws IOException { assertThat(union.getStateConsumedFuture()).isNotDone(); // Complete buffer filtering on first gate only - channel1.finishReadRecoveredState(); + synchronized (ig1.getGateLock()) { + channel1.finishReadRecoveredState(); + } assertThat(ig1.getBufferFilteringCompleteFuture()).isDone(); assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); // Complete buffer filtering on second gate - channel2.finishReadRecoveredState(); + synchronized (ig2.getGateLock()) { + channel2.finishReadRecoveredState(); + } assertThat(ig2.getBufferFilteringCompleteFuture()).isDone(); assertThat(union.getBufferFilteringCompleteFuture()).isDone(); @@ -325,9 +329,7 @@ private static RecoveredInputChannel buildRecoveredChannel(SingleInputGate input new SimpleCounter(), 10) { @Override - protected InputChannel toInputChannelInternal( - java.util.ArrayDeque<Buffer> - remainingBuffers) { + protected InputChannel toInputChannelInternal(RecoveredBufferStoreImpl recoveredStore) { throw new UnsupportedOperationException(); } }; diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java index b850a7cc55370..29a2b3df65d95 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredBufferStore; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateFactory; @@ -37,7 +38,6 @@ import org.apache.flink.runtime.taskmanager.NettyShuffleEnvironmentConfiguration; import java.io.IOException; -import java.util.ArrayDeque; /** * A benchmark-specific input gate factory which overrides 
the respective methods of creating {@link @@ -130,7 +130,7 @@ public TestLocalInputChannel( metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + RecoveredBufferStore.EMPTY); } @Override @@ -186,7 +186,7 @@ public TestRemoteInputChannel( metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + RecoveredBufferStore.EMPTY); } @Override