算法 数据结构 leetcode

Java 实现 Snowflake 算法

2019-07-13  本文已影响47人  又语

本文介绍 Java 实现 Snowflake 算法生成分布式 ID。


目录


Snowflake 算法简介

Snowflake 算法是 Twitter 开源的分布式 ID 生成算法,将 64 bit 划分为多个不同组成部分,每部分代表不同含义。

Snowflake 算法生成的分布式 ID 并非绝对唯一,但已满足绝大多数应用场景需求。


示例

package tutorial.java.util;

import java.rmi.UnexpectedException;
import java.util.concurrent.atomic.AtomicLong;

public class SnowflakeDistributedId {

    /**
     * Snowflake算法中第三部分长度,即数据中心和工作机器ID总共占位长度
     */
    private static final long DATA_CENTER_AND_WORKER_ID_BITS = 10;

    /**
     * Snowflake算法中第四部分长度,即自增序列占位长度
     */
    private static final long AUTO_INCREMENT_SEQUENCE_BITS = 12;

    /**
     * 自增序列最大值
     */
    private static final long MAX_SEQUENCE = 4095;

    /**
     * 开始时间戳
     */
    private final long epoch;

    /**
     * 数据中心ID
     */
    private final long dataCenterId;

    /**
     * 机器ID占位长度
     */
    private final long workerIdBits;

    /**
     * 机器ID
     */
    private final long workerId;

    /**
     * 保存上一次生成ID的时间戳
     */
    private long lastTimestamp;

    /**
     * 分布式ID自增序列
     */
    private AtomicLong autoIncrementSequence;

    /**
     * @param dataCenterIdBits 数据中心ID占位长度
     * @param dataCenterId     数据中心ID
     * @param workerId         工作机器ID
     */
    public SnowflakeDistributedId(long epoch, long dataCenterIdBits, long dataCenterId, long workerId) {
        this.epoch = epoch;
        this.dataCenterId = validateDataCenterId(dataCenterIdBits, dataCenterId);
        workerIdBits = DATA_CENTER_AND_WORKER_ID_BITS - dataCenterIdBits;
        this.workerId = validateWorkerId(workerId);
        this.lastTimestamp = -1L;
        this.autoIncrementSequence = new AtomicLong(0);
    }

    /**
     * 初始化数据中心ID
     *
     * @param dataCenterIdBits 数据中心ID占位长度
     * @param dataCenterId     数据中心ID
     * @return 校验通过的数据中心ID
     */
    private long validateDataCenterId(long dataCenterIdBits, long dataCenterId) {
        if (dataCenterIdBits < 0 || dataCenterIdBits >= DATA_CENTER_AND_WORKER_ID_BITS) {
            throw new IllegalArgumentException("Data center ID bits must be in [0, 10)!");
        }
        if (dataCenterIdBits > 0) {
            // 支持的最大数据中心 ID
            long maxDataCenterId = ~(-1 << dataCenterIdBits);
            if (dataCenterId < 0 || dataCenterId > maxDataCenterId) {
                throw new IllegalArgumentException("Data center ID must be in [0, " + maxDataCenterId + "]!");
            }
            return dataCenterId;
        }
        return -1;
    }

    /**
     * 初始化工作机器ID
     *
     * @param workerId 工作机器ID
     * @return 校验通过的工作机器ID
     */
    private long validateWorkerId(long workerId) {
        // 支持的最大机器ID
        long maxWorkerId = ~(-1 << this.workerIdBits);
        if (workerId < 0 || workerId > maxWorkerId) {
            throw new IllegalArgumentException("Worker ID must be in [0, " + maxWorkerId + "]!");
        }
        return workerId;
    }

    /**
     * 生成分布式ID
     *
     * @return long类型ID
     * @throws UnexpectedException 如果系统时间回退则抛出此异常
     */
    public long generate() throws UnexpectedException {
        long currentTimestamp = System.currentTimeMillis();
        // 如果当前时间小于上一次ID生成时间,说明系统时间回退
        if (currentTimestamp < lastTimestamp) {
            throw new UnexpectedException("System clock moved backward, refused to generate ID!");
        }
        long currentSequence;
        if (currentTimestamp == lastTimestamp) {
            // 如果当前时间等于上一次ID生成时间,获取自增序列值后加1
            currentSequence = autoIncrementSequence.getAndIncrement();
            // 如果获取的自增序列值大于允许的最大值
            if (currentSequence > MAX_SEQUENCE) {
                // 等待到下一毫秒
                currentTimestamp = block(currentTimestamp);
                // 更新时间戳
                lastTimestamp = currentTimestamp;
                // 重新获取自增序列值
                currentSequence = resetAutoIncrementSequence();
            }
        } else {
            // 如果当前时间大于上一次ID生成时间,重置自增序列并获取自增序列值后加1
            currentSequence = resetAutoIncrementSequence();
            // 更新时间戳
            lastTimestamp = currentTimestamp;
        }
        // 时间戳左移
        long id = (currentTimestamp - epoch) << (DATA_CENTER_AND_WORKER_ID_BITS + AUTO_INCREMENT_SEQUENCE_BITS);
        if (dataCenterId != -1) {
            // 数据中心ID左移
            id = id | (this.dataCenterId << (workerIdBits + AUTO_INCREMENT_SEQUENCE_BITS));
        }
        return id | (this.workerId << AUTO_INCREMENT_SEQUENCE_BITS) | currentSequence;
    }

    /**
     * 重置自增序列
     *
     * @return 自增序列值
     */
    private synchronized long resetAutoIncrementSequence() {
        autoIncrementSequence = new AtomicLong(0);
        return autoIncrementSequence.getAndIncrement();
    }

    /**
     * 阻塞至下一毫秒
     *
     * @param timestamp 当前时间戳
     * @return 下一毫秒时间戳
     */
    private long block(long timestamp) {
        long currentTimestamp = System.currentTimeMillis();
        while (currentTimestamp <= timestamp) {
            currentTimestamp = System.currentTimeMillis();
        }
        return currentTimestamp;
    }
}

单元测试

import org.junit.Assert;
import org.junit.Test;

import java.rmi.UnexpectedException;
import java.time.Instant;
import java.util.HashSet;
import java.util.Set;

public class SnowflakeDistributedIdTest {

    @Test
    public void test() {
        SnowflakeDistributedId id = new SnowflakeDistributedId(Instant.now().toEpochMilli(),
                5, 1, 8);
        Set<Long> ids = new HashSet<>();
        int iteratorTimes = 100000;
        Runnable runnable = () -> {
            for (int i = 0; i < iteratorTimes; i++) {
                try {
                    ids.add(id.generate());
                } catch (UnexpectedException e) {
                    Assert.fail();
                }
            }
        };
        Set<Thread> threads = new HashSet<>();
        int threadCount = 10;
        for (int i = 0; i < threadCount; i++) {
            threads.add(new Thread(runnable));
        }
        threads.forEach(thread -> {
            thread.start();
            try {
                thread.join();
            } catch (InterruptedException e) {
                Assert.fail();
            }
        });
        Assert.assertEquals(iteratorTimes * threadCount, ids.stream().distinct().count());
    }
}

单元测试说明:共启动 10 个线程,每个线程循环 100000 次执行生成 ID 操作,生成的 ID 全部放入 SET 数据结构中,执行过程抛出任何异常都会导致单元测试失败,最后检查 SET 中元素数量是否等于 10 * 100000,测试结果略。


总结

  1. Java 中 long 类型长度为 64 bit,因此 Java 实现 Snowflake 算法生成的 ID 即保存为 long 类型。
  2. 除 Snowflake 算法外,常见的分布式 ID 生成方案还包括:
    • UUID
    • 数据库生成
    • Redis 生成
    • 百度 UidGenerator
    • 美团 Leaf
上一篇下一篇

猜你喜欢

热点阅读