本文目录导读:

我来详细介绍Java中实现一致性哈希的完整案例。
基础一致性哈希实现
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
 * A consistent-hash ring that maps arbitrary keys onto a set of physical nodes.
 *
 * <p>Every physical node is placed on the ring {@code virtualNodeCount} times, at the
 * positions obtained by hashing {@code node.toString() + "#" + i}. A key is served by
 * the first ring position at or after the key's own hash, wrapping around to the
 * lowest position when the key hashes past the last one.
 *
 * <p>This class performs no synchronization; callers that share an instance between
 * threads must coordinate access themselves.
 *
 * @param <T> the physical node type; its {@code toString()} determines ring positions
 *            and its {@code equals}/{@code hashCode} determine node identity
 */
public class ConsistentHash<T> {
    /** Number of ring positions created for each physical node. */
    private final int virtualNodeCount;

    /** The ring itself: position -> owning physical node, ordered by position. */
    private final NavigableMap<Integer, T> ring = new TreeMap<>();

    /** Ring positions owned by each physical node, so that a node can be removed again. */
    private final Map<T, List<Integer>> nodeVirtualNodes = new HashMap<>();

    /**
     * Creates a ring and registers every node in {@code nodes}.
     *
     * @param virtualNodeCount ring positions to create per physical node
     * @param nodes            the initial physical nodes
     */
    public ConsistentHash(int virtualNodeCount, Collection<T> nodes) {
        this.virtualNodeCount = virtualNodeCount;
        for (T node : nodes) {
            addNode(node);
        }
    }

    /**
     * Places {@code node} on the ring.
     *
     * <p>If a ring position is already held by a different node (a hash collision), the
     * existing owner keeps it and the position is not recorded for {@code node}; this
     * keeps a later {@link #removeNode} from deleting a position it does not own.
     * Adding a node that is already present leaves the ring unchanged.
     *
     * @param node the physical node to add
     */
    public void addNode(T node) {
        List<Integer> owned = new ArrayList<>();
        for (int i = 0; i < virtualNodeCount; i++) {
            int hash = getHash(node.toString() + "#" + i);
            T current = ring.putIfAbsent(hash, node);
            if (current == null || current.equals(node)) {
                owned.add(hash);
            }
        }
        nodeVirtualNodes.put(node, owned);
    }

    /**
     * Removes {@code node} and all ring positions it owns. Unknown nodes are ignored.
     *
     * @param node the physical node to remove
     */
    public void removeNode(T node) {
        List<Integer> owned = nodeVirtualNodes.remove(node);
        if (owned == null) {
            return;
        }
        for (Integer hash : owned) {
            // Two-argument remove: only delete the position while it still maps to this node.
            ring.remove(hash, node);
        }
    }

    /**
     * Finds the node responsible for {@code key}.
     *
     * @param key the lookup key; its {@code toString()} is hashed
     * @return the owning node, or {@code null} when the ring is empty
     */
    public T getNode(Object key) {
        if (ring.isEmpty()) {
            return null;
        }
        Map.Entry<Integer, T> owner = ring.ceilingEntry(getHash(key.toString()));
        if (owner == null) {
            // Past the highest position: wrap around to the start of the ring.
            owner = ring.firstEntry();
        }
        return owner.getValue();
    }

    /**
     * Hashes {@code key} to a non-negative ring position.
     *
     * <p>The key is encoded as UTF-8 (so the result does not depend on the platform
     * default charset), digested with MD5, and the first four digest bytes are combined
     * with {@code digest[0]} as the least-significant byte. The sign bit is then cleared.
     *
     * @param key the string to hash
     * @return a value in {@code [0, Integer.MAX_VALUE]}
     * @throws RuntimeException if the MD5 algorithm is unavailable
     */
    public static int getHash(String key) {
        try {
            MessageDigest md5 = MessageDigest.getInstance("MD5");
            byte[] digest = md5.digest(key.getBytes(StandardCharsets.UTF_8));
            int hash = ((digest[3] & 0xFF) << 24)
                    | ((digest[2] & 0xFF) << 16)
                    | ((digest[1] & 0xFF) << 8)
                    | (digest[0] & 0xFF);
            return hash & 0x7FFFFFFF;
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("MD5 algorithm not found", e);
        }
    }

    /**
     * Returns the distinct physical nodes that currently hold at least one ring position.
     *
     * @return a new, independent set of nodes
     */
    public Set<T> getNodes() {
        return new HashSet<>(ring.values());
    }

    /**
     * Returns the number of registered physical nodes.
     *
     * @return the physical node count
     */
    public int size() {
        return nodeVirtualNodes.size();
    }
}
测试和演示代码
/**
 * Console demonstration of {@code ConsistentHash}: key lookup, the effect of adding a
 * node on existing key placements, and how the virtual-node count affects load spread.
 */
public class ConsistentHashDemo {
    /**
     * Runs the three demonstrations in order.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        testBasicFunctionality();
        testNodeAdditionAndRemoval();
        testCacheLoadBalancing();
    }

    /** Prints which of three cache servers owns each of a handful of sample keys. */
    public static void testBasicFunctionality() {
        System.out.println("========== 基础功能测试 ==========");

        // Three servers, 100 virtual nodes each.
        ConsistentHash<String> hashRing = new ConsistentHash<>(100, Arrays.asList(
                "192.168.1.1:6379",
                "192.168.1.2:6379",
                "192.168.1.3:6379"));

        String[] sampleKeys = {
                "user:1001", "user:1002", "user:1003",
                "order:2023001", "product:10001"
        };

        System.out.println("缓存服务器分布:");
        for (String cacheKey : sampleKeys) {
            System.out.println(cacheKey + " -> " + hashRing.getNode(cacheKey));
        }
    }

    /**
     * Records where 100 keys land on a three-node ring, adds a fourth node, and prints
     * how many of those keys now resolve to a different node.
     */
    public static void testNodeAdditionAndRemoval() {
        System.out.println("\n========== 节点增减测试 ==========");

        ConsistentHash<String> hashRing = new ConsistentHash<>(
                100, new ArrayList<>(Arrays.asList("Server-A", "Server-B", "Server-C")));

        // Snapshot of key -> node before the topology change.
        Map<String, String> before = new HashMap<>();
        for (int n = 1; n <= 100; n++) {
            String cacheKey = "key" + n;
            before.put(cacheKey, hashRing.getNode(cacheKey));
        }

        System.out.println("添加新节点 Server-D:");
        hashRing.addNode("Server-D");

        int moved = 0;
        for (String cacheKey : before.keySet()) {
            String after = hashRing.getNode(cacheKey);
            if (!after.equals(before.get(cacheKey))) {
                moved++;
            }
        }

        System.out.println("添加节点后,100个key中 " + moved +
                " 个映射发生变化(理想情况下约 25%)");
    }

    /**
     * Routes 10000 keys across four servers for several virtual-node counts and prints
     * the per-server request count together with its share of the total.
     */
    public static void testCacheLoadBalancing() {
        System.out.println("\n========== 负载均衡测试 ==========");

        List<String> servers = Arrays.asList(
                "Server-A", "Server-B", "Server-C", "Server-D");

        for (int virtualCount : new int[] {1, 10, 100, 200}) {
            System.out.println("\n虚拟节点数量: " + virtualCount);

            ConsistentHash<String> hashRing = new ConsistentHash<>(virtualCount, servers);

            // Start every server at zero so idle servers still show up in the report.
            Map<String, Integer> hitsPerServer = new HashMap<>();
            for (String server : servers) {
                hitsPerServer.put(server, 0);
            }

            for (int request = 0; request < 10000; request++) {
                String target = hashRing.getNode("test:key:" + request);
                hitsPerServer.put(target, hitsPerServer.get(target) + 1);
            }

            System.out.println("负载分布:");
            hitsPerServer.forEach((server, hits) -> {
                // 10000 requests in total, so count / 100 is the percentage.
                double share = hits / 100.0;
                System.out.println(server + ": " + hits + " (" + share + "%)");
            });
        }
    }
}
优化版实现(带监控和统计)
/**
 * A consistent-hash ring with per-node request and hit counters.
 *
 * <p>Each physical node is placed on the ring {@code virtualNodeCount} times. Ring
 * positions are computed with a 32-bit FNV-1a style hash over the characters of
 * {@code node.toString() + "#" + replicaIndex}.
 *
 * <p>Constructing an instance starts a background daemon thread that calls
 * {@link #printStats()} once a minute until {@link #stopMonitoring()} is called. That
 * thread only reads the concurrent counter maps. The ring itself is not synchronized,
 * so {@link #addNode}, {@link #removeNode} and {@link #getNode} must not be called
 * concurrently with each other without external coordination.
 *
 * @param <T> the physical node type; its {@code toString()} determines ring positions
 *            and its {@code equals}/{@code hashCode} determine node identity
 */
public class OptimizedConsistentHash<T> {
    /** Pause between two automatic statistics printouts. */
    private static final long STATS_INTERVAL_MILLIS = 60_000L;

    /** 32-bit FNV offset basis (2166136261 as an unsigned value). */
    private static final int FNV_OFFSET_BASIS = 0x811C9DC5;

    /** 32-bit FNV prime (16777619). */
    private static final int FNV_PRIME = 0x01000193;

    private final int virtualNodeCount;
    private final NavigableMap<Integer, VirtualNode<T>> ring = new TreeMap<>();
    private final Map<T, List<VirtualNode<T>>> nodeVirtualNodes = new HashMap<>();
    private volatile boolean monitoring = false;

    /** The statistics thread, kept so that {@link #stopMonitoring()} can wake it up. */
    private volatile Thread monitorThread;

    // Monitoring data; concurrent maps because the statistics thread reads them.
    private final Map<T, AtomicLong> requestCount = new ConcurrentHashMap<>();
    private final Map<T, AtomicLong> hitCount = new ConcurrentHashMap<>();

    /**
     * Creates a ring, registers every node in {@code nodes} and starts the statistics
     * thread.
     *
     * @param virtualNodeCount ring positions to create per physical node
     * @param nodes            the initial physical nodes
     */
    public OptimizedConsistentHash(int virtualNodeCount, Collection<T> nodes) {
        this.virtualNodeCount = virtualNodeCount;
        for (T node : nodes) {
            addNode(node);
        }
        startMonitoring();
    }

    /**
     * Places {@code node} on the ring and (re)initialises its counters to zero.
     *
     * <p>If a ring position is already held by a different physical node (a hash
     * collision), the existing owner keeps it, so that removing either node later never
     * deletes a position belonging to the other.
     *
     * @param node the physical node to add
     */
    public void addNode(T node) {
        List<VirtualNode<T>> owned = new ArrayList<>();
        for (int i = 0; i < virtualNodeCount; i++) {
            VirtualNode<T> virtualNode = new VirtualNode<>(node, i);
            int hash = getHash(virtualNode.getKey());
            VirtualNode<T> current = ring.get(hash);
            if (current == null || current.getPhysicalNode().equals(node)) {
                ring.put(hash, virtualNode);
                owned.add(virtualNode);
            }
        }
        nodeVirtualNodes.put(node, owned);
        requestCount.put(node, new AtomicLong(0));
        hitCount.put(node, new AtomicLong(0));
    }

    /**
     * Removes {@code node}, the ring positions it owns and its counters. Unknown nodes
     * are ignored.
     *
     * @param node the physical node to remove
     */
    public void removeNode(T node) {
        List<VirtualNode<T>> owned = nodeVirtualNodes.remove(node);
        if (owned == null) {
            return;
        }
        for (VirtualNode<T> virtualNode : owned) {
            int hash = getHash(virtualNode.getKey());
            VirtualNode<T> current = ring.get(hash);
            // Only delete the position while it still belongs to this physical node.
            if (current != null && current.getPhysicalNode().equals(node)) {
                ring.remove(hash);
            }
        }
        requestCount.remove(node);
        hitCount.remove(node);
    }

    /**
     * Finds the node responsible for {@code key}.
     *
     * @param key        the lookup key; its {@code toString()} is hashed
     * @param trackStats when {@code true}, the owning node's request counter is incremented
     * @return the owning node, or {@code null} when the ring is empty
     */
    public T getNode(Object key, boolean trackStats) {
        if (ring.isEmpty()) {
            return null;
        }
        Map.Entry<Integer, VirtualNode<T>> owner = ring.ceilingEntry(getHash(key.toString()));
        if (owner == null) {
            // Past the highest position: wrap around to the start of the ring.
            owner = ring.firstEntry();
        }
        T node = owner.getValue().getPhysicalNode();
        if (trackStats) {
            AtomicLong requests = requestCount.get(node);
            if (requests != null) {
                requests.incrementAndGet();
            }
        }
        return node;
    }

    /**
     * Records a cache hit for {@code node}. A {@code null} node or a node that is not
     * (or no longer) registered is ignored rather than raising an exception.
     *
     * @param node the node that served the hit
     */
    public void recordHit(T node) {
        if (node == null) {
            return;
        }
        AtomicLong hits = hitCount.get(node);
        if (hits != null) {
            hits.incrementAndGet();
        }
    }

    /**
     * Computes the hit rate of every registered node.
     *
     * @return a new map from node to hit rate in percent; nodes without any tracked
     *         request report {@code 0}
     */
    public Map<T, Double> getHitRates() {
        Map<T, Double> hitRates = new HashMap<>();
        for (Map.Entry<T, AtomicLong> entry : requestCount.entrySet()) {
            long requests = entry.getValue().get();
            long hits = currentHits(entry.getKey());
            double rate = requests > 0 ? (double) hits / requests : 0;
            hitRates.put(entry.getKey(), rate * 100);
        }
        return hitRates;
    }

    /** Starts the daemon thread that prints statistics once per interval. */
    private void startMonitoring() {
        monitoring = true;
        Thread worker = new Thread(() -> {
            while (monitoring) {
                try {
                    Thread.sleep(STATS_INTERVAL_MILLIS);
                    printStats();
                } catch (InterruptedException e) {
                    // Preserve the interrupt for whoever inspects this thread next.
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        }, "consistent-hash-monitor");
        // Daemon so that a forgotten stopMonitoring() cannot keep the JVM alive.
        worker.setDaemon(true);
        monitorThread = worker;
        worker.start();
    }

    /** Prints request count, hit count and hit rate for every registered node. */
    public void printStats() {
        System.out.println("\n=== 缓存节点统计 ===");
        // Iterate the concurrent counter map so this is safe from the monitor thread.
        for (Map.Entry<T, AtomicLong> entry : requestCount.entrySet()) {
            T node = entry.getKey();
            long requests = entry.getValue().get();
            long hits = currentHits(node);
            double hitRate = requests > 0 ? (double) hits / requests * 100 : 0;
            System.out.printf("节点 %s: 请求=%d, 命中=%d, 命中率=%.2f%%\n",
                    node, requests, hits, hitRate);
        }
    }

    /** Stops the statistics thread, waking it up if it is currently sleeping. */
    public void stopMonitoring() {
        this.monitoring = false;
        Thread worker = monitorThread;
        if (worker != null) {
            worker.interrupt();
        }
    }

    /** Returns the hit counter value for {@code node}, or 0 if it has no counter. */
    private long currentHits(T node) {
        AtomicLong hits = hitCount.get(node);
        return hits == null ? 0 : hits.get();
    }

    /** One placement of a physical node on the ring. */
    private static class VirtualNode<T> {
        private final T physicalNode;
        private final int replicaIndex;

        public VirtualNode(T physicalNode, int replicaIndex) {
            this.physicalNode = physicalNode;
            this.replicaIndex = replicaIndex;
        }

        public T getPhysicalNode() {
            return physicalNode;
        }

        /** Returns the string that is hashed to obtain this placement's ring position. */
        public String getKey() {
            return physicalNode.toString() + "#" + replicaIndex;
        }
    }

    /**
     * Hashes {@code key} to a non-negative ring position using the 32-bit FNV-1a
     * xor-then-multiply scheme applied to each {@code char} of the string, with the
     * sign bit cleared at the end.
     *
     * @param key the string to hash
     * @return a value in {@code [0, Integer.MAX_VALUE]}
     */
    public static int getHash(String key) {
        // The offset basis does not fit a decimal int literal; the hex form wraps as intended.
        int hash = FNV_OFFSET_BASIS;
        for (int i = 0; i < key.length(); i++) {
            hash ^= key.charAt(i);
            hash *= FNV_PRIME;
        }
        return hash & 0x7FFFFFFF;
    }
}
使用示例
/**
 * End-to-end example: routes 1000 keys through an {@code OptimizedConsistentHash} of
 * three cache nodes, simulates hits for the first 500 keys and prints the statistics.
 */
public class UsageExample {
    /**
     * Runs the example.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Three cache nodes on distinct loopback addresses, 200 virtual nodes each.
        List<CacheNode> cacheNodes = Arrays.asList(
                new CacheNode("127.0.0.1", 6379),
                new CacheNode("127.0.0.2", 6379),
                new CacheNode("127.0.0.3", 6379));
        OptimizedConsistentHash<CacheNode> hashRing =
                new OptimizedConsistentHash<>(200, cacheNodes);

        // Write phase: each lookup is counted as a request against the owning node.
        Map<String, String> cacheData = new HashMap<>();
        for (int id = 1; id <= 1000; id++) {
            String cacheKey = "cache:data:" + id;
            CacheNode target = hashRing.getNode(cacheKey, true);
            // A real application would store the value on the selected Redis node here.
            cacheData.put(cacheKey, "value-" + id);
            System.out.println("Key: " + cacheKey + " -> 节点: " + target.getAddress());
        }

        // Read phase: every lookup is again counted as a request and recorded as a hit.
        for (int id = 1; id <= 500; id++) {
            CacheNode target = hashRing.getNode("cache:data:" + id, true);
            hashRing.recordHit(target);
        }

        hashRing.printStats();
        hashRing.stopMonitoring();
    }
}
// 缓存节点类
class CacheNode {
private final String host;
private final int port;
public CacheNode(String host, int port) {
this.host = host;
this.port = port;
}
public String getAddress() {
return host + ":" + port;
}
@Override
public String toString() {
return getAddress();
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
CacheNode cacheNode = (CacheNode) o;
return port == cacheNode.port && Objects.equals(host, cacheNode.host);
}
@Override
public int hashCode() {
return Objects.hash(host, port);
}
}
这个实现的主要特点:
- 均匀分布:通过虚拟节点让数据在各物理节点之间分布得更均匀(虚拟节点越多,分布越接近均匀)
- 最小干扰:节点增减时只影响少量数据
- 高效查找:使用TreeMap实现O(log n)的查找效率(n为哈希环上的虚拟节点总数)
- 可监控:支持统计和监控功能
应用场景包括分布式缓存、负载均衡、数据库分片等。