diff --git a/client/src/main/java/org/apache/rocketmq/client/ClientConfig.java b/client/src/main/java/org/apache/rocketmq/client/ClientConfig.java index 79cb04af1d0..9e012254329 100644 --- a/client/src/main/java/org/apache/rocketmq/client/ClientConfig.java +++ b/client/src/main/java/org/apache/rocketmq/client/ClientConfig.java @@ -101,6 +101,10 @@ public class ClientConfig { private boolean enableHeartbeatChannelEventListener = true; + private boolean enableConcurrentHeartbeat = false; + + private int concurrentHeartbeatThreadPoolSize = Runtime.getRuntime().availableProcessors(); + /** * The switch for message trace */ @@ -240,6 +244,8 @@ public void resetClientConfig(final ClientConfig cc) { this.namespaceV2 = cc.namespaceV2; this.enableTrace = cc.enableTrace; this.traceTopic = cc.traceTopic; + this.enableConcurrentHeartbeat = cc.enableConcurrentHeartbeat; + this.concurrentHeartbeatThreadPoolSize = cc.concurrentHeartbeatThreadPoolSize; } public ClientConfig cloneClientConfig() { @@ -272,6 +278,8 @@ public ClientConfig cloneClientConfig() { cc.namespaceV2 = namespaceV2; cc.enableTrace = enableTrace; cc.traceTopic = traceTopic; + cc.enableConcurrentHeartbeat = enableConcurrentHeartbeat; + cc.concurrentHeartbeatThreadPoolSize = concurrentHeartbeatThreadPoolSize; return cc; } @@ -525,6 +533,22 @@ public void setMaxPageSizeInGetMetadata(int maxPageSizeInGetMetadata) { this.maxPageSizeInGetMetadata = maxPageSizeInGetMetadata; } + public boolean isEnableConcurrentHeartbeat() { + return this.enableConcurrentHeartbeat; + } + + public void setEnableConcurrentHeartbeat(boolean enableConcurrentHeartbeat) { + this.enableConcurrentHeartbeat = enableConcurrentHeartbeat; + } + + public int getConcurrentHeartbeatThreadPoolSize() { + return concurrentHeartbeatThreadPoolSize; + } + + public void setConcurrentHeartbeatThreadPoolSize(int concurrentHeartbeatThreadPoolSize) { + this.concurrentHeartbeatThreadPoolSize = concurrentHeartbeatThreadPoolSize; + } + @Override public String toString() { return "ClientConfig{" + @@ -558,6 +582,8 @@ public String toString() { ", enableHeartbeatChannelEventListener=" + enableHeartbeatChannelEventListener + ", enableTrace=" + enableTrace + ", traceTopic='" + traceTopic + '\'' + + ", enableConcurrentHeartbeat=" + enableConcurrentHeartbeat + + ", concurrentHeartbeatThreadPoolSize=" + concurrentHeartbeatThreadPoolSize + '}'; } } diff --git a/client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java b/client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java index bb838a62650..0c39aa97b20 100644 --- a/client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java +++ b/client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java @@ -42,6 +42,7 @@ import org.apache.rocketmq.common.MQVersion; import org.apache.rocketmq.common.MixAll; import org.apache.rocketmq.common.ServiceState; +import org.apache.rocketmq.common.ThreadFactoryImpl; import org.apache.rocketmq.common.constant.PermName; import org.apache.rocketmq.common.filter.ExpressionType; import org.apache.rocketmq.common.message.MessageExt; @@ -68,6 +69,7 @@ import org.apache.rocketmq.remoting.protocol.route.QueueData; import org.apache.rocketmq.remoting.protocol.route.TopicRouteData; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -79,7 +81,10 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; @@ -125,7 +130,7 @@ public class MQClientInstance { */ private final ConcurrentMap> brokerAddrTable = new ConcurrentHashMap<>(); - private final ConcurrentMap> brokerVersionTable = new ConcurrentHashMap<>(); + private final ConcurrentMap> brokerVersionTable = new ConcurrentHashMap<>(); private final Set brokerSupportV2HeartbeatSet = new HashSet<>(); private final ConcurrentMap brokerAddrHeartbeatFingerprintTable = new ConcurrentHashMap<>(); private final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(r -> new Thread(r, "MQClientFactoryScheduledThread")); @@ -142,6 +147,7 @@ public Thread newThread(Runnable r) { private final AtomicLong sendHeartbeatTimesTotal = new AtomicLong(0); private ServiceState serviceState = ServiceState.CREATE_JUST; private final Random random = new Random(); + private ExecutorService concurrentHeartbeatExecutor; public MQClientInstance(ClientConfig clientConfig, int instanceIndex, String clientId) { this(clientConfig, instanceIndex, clientId, null); @@ -217,6 +223,12 @@ public void onChannelActive(String remoteAddr, Channel channel) { this.consumerStatsManager = new ConsumerStatsManager(this.scheduledExecutorService); + if (this.clientConfig.isEnableConcurrentHeartbeat()) { + this.concurrentHeartbeatExecutor = Executors.newFixedThreadPool( + clientConfig.getConcurrentHeartbeatThreadPoolSize(), + new ThreadFactoryImpl("MQClientConcurrentHeartbeatThread_", true)); + } + log.info("Created a new client Instance, InstanceIndex:{}, ClientID:{}, ClientConfig:{}, ClientVersion:{}, SerializerType:{}", instanceIndex, this.clientId, @@ -537,6 +549,8 @@ public boolean sendHeartbeatToAllBrokerWithLock() { try { if (clientConfig.isUseHeartbeatV2()) { return this.sendHeartbeatToAllBrokerV2(false); + } else if (clientConfig.isEnableConcurrentHeartbeat()) { + return this.sendHeartbeatToAllBrokerConcurrently(); } else { return this.sendHeartbeatToAllBroker(); } @@ -641,7 +655,7 @@ private boolean sendHeartbeatToBroker(long id, String brokerName, String addr, H try { int version = this.mQClientAPIImpl.sendHeartbeat(addr, heartbeatData, clientConfig.getMqClientApiTimeout()); if (!this.brokerVersionTable.containsKey(brokerName)) { - this.brokerVersionTable.put(brokerName, new HashMap<>(4)); + this.brokerVersionTable.put(brokerName, new ConcurrentHashMap<>(4)); } this.brokerVersionTable.get(brokerName).put(addr, version); long times = this.sendHeartbeatTimesTotal.getAndIncrement(); @@ -721,7 +735,7 @@ private boolean sendHeartbeatToBrokerV2(long id, String brokerName, String addr, } version = heartbeatV2Result.getVersion(); if (!this.brokerVersionTable.containsKey(brokerName)) { - this.brokerVersionTable.put(brokerName, new HashMap<>(4)); + this.brokerVersionTable.put(brokerName, new ConcurrentHashMap<>(4)); } this.brokerVersionTable.get(brokerName).put(addr, version); long times = this.sendHeartbeatTimesTotal.getAndIncrement(); @@ -780,6 +794,100 @@ private boolean sendHeartbeatToAllBrokerV2(boolean isRebalance) { return true; } + private class ClientHeartBeatTask { + private final String brokerName; + private final Long brokerId; + private final String brokerAddr; + private final HeartbeatData heartbeatData; + + public ClientHeartBeatTask(String brokerName, Long brokerId, String brokerAddr, HeartbeatData heartbeatData) { + this.brokerName = brokerName; + this.brokerId = brokerId; + this.brokerAddr = brokerAddr; + this.heartbeatData = heartbeatData; + } + + public void execute() throws Exception { + int version = MQClientInstance.this.mQClientAPIImpl.sendHeartbeat( + brokerAddr, heartbeatData, MQClientInstance.this.clientConfig.getMqClientApiTimeout()); + + ConcurrentHashMap inner = MQClientInstance.this.brokerVersionTable + .computeIfAbsent(brokerName, k -> new ConcurrentHashMap<>(4)); + inner.put(brokerAddr, version); + } + } + + private boolean sendHeartbeatToAllBrokerConcurrently() { + final HeartbeatData heartbeatData = this.prepareHeartbeatData(false); + final boolean producerEmpty = heartbeatData.getProducerDataSet().isEmpty(); + final boolean consumerEmpty = heartbeatData.getConsumerDataSet().isEmpty(); + + if (producerEmpty && consumerEmpty) { + log.warn("sending heartbeat, but no consumer and no producer. [{}]", this.clientId); + return false; + } + + if (this.brokerAddrTable.isEmpty()) { + return false; + } + + long times = this.sendHeartbeatTimesTotal.getAndIncrement(); + List tasks = new ArrayList<>(); + for (Entry> entry : this.brokerAddrTable.entrySet()) { + String brokerName = entry.getKey(); + HashMap oneTable = entry.getValue(); + if (oneTable != null) { + for (Map.Entry entry1 : oneTable.entrySet()) { + Long id = entry1.getKey(); + String addr = entry1.getValue(); + if (addr == null) continue; + if (consumerEmpty && id != MixAll.MASTER_ID) continue; + tasks.add(new ClientHeartBeatTask(brokerName, id, addr, heartbeatData)); + } + } + } + + if (tasks.isEmpty()) { + return false; + } + + final CountDownLatch latch = new CountDownLatch(tasks.size()); + + for (ClientHeartBeatTask task : tasks) { + try { + this.concurrentHeartbeatExecutor.execute(() -> { + try { + task.execute(); + if (times % 20 == 0) { + log.info("send heart beat to broker[{} {} {}] success", task.brokerName, task.brokerId, task.brokerAddr); + } + } catch (Exception e) { + if (MQClientInstance.this.isBrokerInNameServer(task.brokerAddr)) { + log.warn("send heart beat to broker[{} {} {}] failed", task.brokerName, task.brokerId, task.brokerAddr, e); + } else { + log.warn("send heart beat to broker[{} {} {}] exception, because the broker not up, forget it", + task.brokerName, task.brokerId, task.brokerAddr, e); + } + } finally { + latch.countDown(); + } + }); + } catch (RejectedExecutionException rex) { + log.warn("heartbeat submission rejected for broker[{} {} {}], will skip this round", task.brokerName, task.brokerId, task.brokerAddr, rex); + latch.countDown(); + } + } + + try { + // wait all tasks finish + latch.await(); + } catch (InterruptedException ie) { + log.warn("Interrupted while waiting for broker heartbeat tasks to complete", ie); + Thread.currentThread().interrupt(); + } + return true; + } + public boolean updateTopicRouteInfoFromNameServer(final String topic, boolean isDefault, DefaultMQProducer defaultMQProducer) { try { @@ -971,6 +1079,7 @@ public void shutdown() { this.scheduledExecutorService.shutdown(); this.mQClientAPIImpl.shutdown(); this.rebalanceService.shutdown(); + this.concurrentHeartbeatExecutor.shutdown(); MQClientManager.getInstance().removeClientFactory(this.clientId); log.info("the client factory [{}] shutdown OK", this.clientId); diff --git a/client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java b/client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java index 39cff5db82b..a82ec3f3588 100644 --- a/client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java @@ -74,6 +74,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; @@ -82,9 +83,11 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -510,4 +513,45 @@ private List createBrokerDatas() { brokerData.setBrokerAddrs(brokerAddrs); return Collections.singletonList(brokerData); } + + @Test + public void testSendHeartbeatToAllBrokerConcurrently() { + try { + String brokerName = "BrokerA"; + HashMap addrMap = new HashMap<>(); + addrMap.put(0L, "127.0.0.1:10911"); + addrMap.put(1L, "127.0.0.1:10912"); + addrMap.put(2L, "127.0.0.1:10913"); + brokerAddrTable.put(brokerName, addrMap); + + DefaultMQPushConsumerImpl mockConsumer = mock(DefaultMQPushConsumerImpl.class); + when(mockConsumer.subscriptions()).thenReturn(Collections.singleton(new SubscriptionData())); + mqClientInstance.registerConsumer("TestConsumerGroup", mockConsumer); + + ClientConfig clientConfig = new ClientConfig(); + FieldUtils.writeDeclaredField(clientConfig, "enableConcurrentHeartbeat", true, true); + FieldUtils.writeDeclaredField(mqClientInstance, "clientConfig", clientConfig, true); + + ExecutorService mockExecutor = mock(ExecutorService.class); + doAnswer(invocation -> { + try { + Runnable task = invocation.getArgument(0); + task.run(); + } catch (Exception e) { + // ignore + } + return null; + }).when(mockExecutor).execute(any(Runnable.class)); + FieldUtils.writeDeclaredField(mqClientInstance, "concurrentHeartbeatExecutor", mockExecutor, true); + MQClientAPIImpl mockMqClientAPIImpl = mock(MQClientAPIImpl.class); + FieldUtils.writeDeclaredField(mqClientInstance, "mQClientAPIImpl", mockMqClientAPIImpl, true); + + mqClientInstance.sendHeartbeatToAllBrokerWithLock(); + + assertTrue(true); + + } catch (Exception e) { + fail("failed: " + e.getMessage()); + } + } }