Kotlin 协程源码阅读笔记 —— Mutex

2023-12-27  本文已影响0人  BlueSocks

Kotlin 协程源码阅读笔记 —— Mutex

我们在 Java / Kotlin 编程时如果需要某段代码块同一时间只有一个线程能够执行时,通常是使用 synchronized,但是协程中可不能使用 synchronized,为什么呢?如果你了解过协程的工作方式就不会觉得奇怪(如果不了解协程工作方式的同学,可以看
以下代码为什么不能正常工作呢(其中 hello()world() 方法是 suspend 方法):

    suspend fun helloWorld(): String {
        synchronized(this) {
            val hello = hello()
            val world = world()
            return hello + world
        }
    }

因为协程调用 suspend 方法后就相当于调用了一个异步函数,而后续的恢复执行时就相当于异步函数的回调成功。这个时候你想想:synchronized 可以在调用异步函数的时候获取锁,然后在回调成功的时候释放前面获取的锁吗?当然不行。那么我们要在协程中如何写出安全的代码(这里感觉还不能用线程安全这个词,因为协程中的锁不是基于线程的,或者叫协程安全?反正感觉也怪怪的)?答案是使用 Kotlin 协程中提供的 Mutex(它是一种不允许重入的互斥锁)。上面的代码改成以下代码就可以执行了:

    val lock = Mutex()
    suspend fun helloWorld(): String {
        lock.lock()
        val hello = hello()
        val world = world()
        lock.unlock()
        return hello + world
    }

emmm,发散一下,那么通过 Java 中的可重入锁 ReentrantLock 可以对协程代码加锁吗?答案也是不行的,因为 ReentrantLock 的锁的状态也是针对线程的,在某个线程中获取了锁,也一定要在那个线程释放锁。而协程挂起前执行的线程和恢复后所在的线程不一定是同一个线程,所以不能使用 ReentrantLock。举一个例子:协程中的代码执行的线程都是由 ContinuationInterceptor 来决定的,通常也是通过线程池来管理,在挂起前执行的线程,在执行挂起后,这个线程也就闲置了,闲置后他还可以去做其他的工作,比如在别的协程中做另外的工作,而挂起的协程恢复后需要线程执行,但是挂起前的线程有别的事情要做处于忙碌状态,这样 ContinuationInterceptor 就需要指定一个新的线程去执行恢复后的任务。所以这样挂起前后恢复后的线程就不一致了。

emmm,继续发散,像 Java 中线程安全的集合类,比如 ConcurrentHashMap 可以使用吗?可以使用,因为他们的内部又没有 suspend 函数,当然可以使用。只要没有使用 synchronized 或者 ReentrantLock 来锁 suspend 的函数都是没有问题的。

前面废话说的有点多,我们进入主题:Mutex 是如何工作的。

源码阅读基于 Kotlin 协程 1.8.0-RC2

获取锁

首先通过以下方法创建一个 Mutex 对象:

@Suppress("FunctionName")
public fun Mutex(locked: Boolean = false): Mutex =
    MutexImpl(locked)

它的实现类是 MutexImpl,它继承于 SemaphoreImpl,这里不要和 Java 中的 Semaphore 搞混了,他们完全不是一回事儿。

我们看看 MutexImpl 获取锁的方法 lock()

override suspend fun lock(owner: Any?) {
    // 尝试获取锁
    if (tryLock(owner)) return
    // 获取锁失败,挂起当前协程,直到锁可用后再恢复
    lockSuspend(owner)
}

这里需要解释一下这个 owner 参数,它主要是用来 debug 发现我们代码中锁使用的问题的。lock()unlock() 中传入的 owner 必须是同一个实例,如果两次不一样就会抛出异常。如果你在代码中成对调用 lock()unlock() 的,那就不会有问题,也没有必要传一个 owner 对象。
首先通过 tryLock() 方法尝试去获取锁,如果获取成功就返回;如果获取失败通过 lockSuspend() 方法挂起当前的协程,等到锁被释放后再唤醒当前协程。

tryLock()

override fun tryLock(owner: Any?): Boolean = when (tryLockImpl(owner)) {
    TRY_LOCK_SUCCESS -> true
    TRY_LOCK_FAILED -> false
    TRY_LOCK_ALREADY_LOCKED_BY_OWNER -> error("This mutex is already locked by the specified owner: $owner")
    else -> error("unexpected")
}

这里继续调用 tryLockImpl() 方法去获取锁,如果返回值为 TRY_LOCK_SUCCESS 表示获取锁成功;TRY_LOCK_FAILED 表示获取锁失败;TRY_LOCK_ALREADY_LOCKED_BY_OWNER 表示同一个 owner 重复调用 lock() 方法,抛出异常。

我们再来看看 tryLockImpl() 方法:

private fun tryLockImpl(owner: Any?): Int {
    while (true) {
        // 获取锁
        if (tryAcquire()) {
            // 获取锁成功
            assert { this.owner.value === NO_OWNER }
            // 设置新的 owner
            this.owner.value = owner
            return TRY_LOCK_SUCCESS
        } else {
            // 获取锁失败
            // The semaphore permit acquisition has failed.
            // However, we need to check that this mutex is not
            // locked by our owner.
            if (owner == null) return TRY_LOCK_FAILED
            // 检查 owner 状态
            when (holdsLockImpl(owner)) {
                // This mutex is already locked by our owner.
                HOLDS_LOCK_YES -> return TRY_LOCK_ALREADY_LOCKED_BY_OWNER
                // This mutex is locked by another owner, `trylock(..)` must return `false`.
                HOLDS_LOCK_ANOTHER_OWNER -> return TRY_LOCK_FAILED
                // This mutex is no longer locked, restart the operation.
                HOLDS_LOCK_UNLOCKED -> continue
            }
        }
    }
}

private fun holdsLockImpl(owner: Any?): Int {
    while (true) {
        // Is this mutex locked?
        if (!isLocked) return HOLDS_LOCK_UNLOCKED
        val curOwner = this.owner.value
        // Wait in a spin-loop until the owner is set
        if (curOwner === NO_OWNER) continue // <-- ATTENTION, BLOCKING PART HERE
        // Check the owner
        return if (curOwner === owner) HOLDS_LOCK_YES else HOLDS_LOCK_ANOTHER_OWNER
    }
}

通过 tryAcquire() 尝试获取锁,后面的逻辑又分为两部分,获取锁成功和获取锁失败。

  1. 获取锁成功
    代码很简单,直接更新 owner 然后返回 TRY_LOCK_SUCCESS 表示获取锁成功。
  2. 获取锁失败
    如果是 owner 为空,直接返回 TRY_LOCK_FAILED 表示获取锁失败;如果 owner 不为空,那么会通过 holdsLockImpl() 方法检查 owner 状态,根据不同的返回值有不同的处理方式:

继续看看 SemaphoreImpl#tryAcquire() 方法的实现:

override fun tryAcquire(): Boolean {
    while (true) {
        // Get the current number of available permits.
        val p = _availablePermits.value
        // Is the number of available permits greater
        // than the maximal one because of an incorrect
        // `release()` call without a preceding `acquire()`?
        // Change it to `permits` and start from the beginning.
        if (p > permits) {
            coerceAvailablePermitsAtMaximum()
            continue
        }
        // Try to decrement the number of available
        // permits if it is greater than zero.
        if (p <= 0) return false
        if (_availablePermits.compareAndSet(p, p - 1)) return true
    }
}

这里要解释一下 _availablePermits,只有当它大于 0 时才可以获取锁,默认的 permits 是 1,也就是同时只有一个 lock() 方法能够获取到锁,其他的 lock() 方法只能等获取到锁的地方释放后才能继续执行。
_availablePermits 大于 0 时,通过 CAS 的方式把 _availablePermits 中的值减 1,如果 CAS 操作失败就重试,成功就直接返回 true。在 Mutex 中很多地方用到了 CAS 自旋的方式去修改值,如果不懂的同学可以去网上找找 CAS 的概念。

lockSuspend()

private suspend fun lockSuspend(owner: Any?) = suspendCancellableCoroutineReusable<Unit> { cont ->
    val contWithOwner = CancellableContinuationWithOwner(cont, owner)
    acquire(contWithOwner)
}

这里通过 suspendCancellableCoroutineReusable() 方法获取到了协程的 Continuation 对象,然后使用 CancellableContinuationWithOwner 对象将原来的 Continuation 封装了一下,然后继续调用 acquire() 方法。

继续看看 SemaphoreImpl#acquire() 方法的实现:

protected fun acquire(waiter: CancellableContinuation<Unit>) = acquire(
    waiter = waiter,
    suspend = { cont -> addAcquireToQueue(cont as Waiter) },
    onAcquired = { cont -> cont.resume(Unit, onCancellationRelease) }
)

private inline fun <W> acquire(waiter: W, suspend: (waiter: W) -> Boolean, onAcquired: (waiter: W) -> Unit) {
    while (true) {
        // Decrement the number of available permits at first.
        // 先获取 _availablePermits 值,然后再减 1
        val p = decPermits()
        // Is the permit acquired?
        if (p > 0) {
            // 表示获取锁成功
            onAcquired(waiter)
            return
        }
        // Permit has not been acquired, try to suspend.
        // 执行挂起操作。
        if (suspend(waiter)) return
    }
}

suspendonAcquired 这两个函数对象分别表示挂起操作和获取锁成功的操作。suspend 中通过 addAcquireToQueue() 方法将当前 Continaution 对象添加到等待队列;而 onAcquired 就非常简单了直接通过 Continuation#resume() 方法恢复协程。

我们看看 addAcquireToQueue() 方法的实现:

private fun addAcquireToQueue(waiter: Waiter): Boolean {
    val curTail = this.tail.value
    // 获取 id
    val enqIdx = enqIdx.getAndIncrement()
    // 获取创建 Segment 的方法
    val createNewSegment = ::createSegment
    // 从链表尾部开始查找 Segment,如果没有查找到就创建一个新的 Semgent
    val segment = this.tail.findSegmentAndMoveForward(id = enqIdx / SEGMENT_SIZE, startFrom = curTail,
        createNewSegment = createNewSegment).segment // cannot be closed
        
    // 计算 Continuation 存储在 Segemnt 中的 index    
    val i = (enqIdx % SEGMENT_SIZE).toInt()
    // the regular (fast) path -- if the cell is empty, try to install continuation
    // 通过 CAS 的方式将 Continuation 添加到 Segment 中去
    if (segment.cas(i, null, waiter)) { // installed continuation successfully
        // 添加成功,注册 Continuation 被取消的监听,取消后会通知 Segment
        waiter.invokeOnCancellation(segment, i)
        return true
    }
    // ...
    // CAS 操作失败,在 acquire() 方法中会进行重试。
    return false // broken cell, need to retry on a different cell
}

private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0)

和等待队列相关的代码就要复杂一丢丢了,先解释一下,我们的等待的 Continuation 是存放在 Segment 中的,每个 Segment 最多能够存放 SEGMENT_SIZE (默认 16) 个 Continuation,存放的方式是数组, Segment 存放满了,就会创建新的 SegmentSegment 之间是链表的存储方式。

简单整理下上面方法的流程:

  1. 通过 enqIdx 用来计算入队的 id,它是一个原子类,依次递增的。
  2. 通过 id / SEGMENT_SIZE 的方式计算 Segmentid,然后通过方法 findSegmentAndMoveForward()Segment 链表尾开始查找一个可用的 Segment,如果没有可用的了,就通过 createSegment() 方法创建一个新的。
  3. 通过 id % SEGEMNT_SIZE 的方式计算出 Continuation 存放在 Segment 中的数组的位置。
  4. 通过 CAS 的方法将 Continuation 添加到 Segment 中,如果添加成功会调用 Continuaiton#invokeOnCancellation() 方法来监听协程取消时的消息;如果添加失败会触发 acquire() 方法重试。

我们再简单看看 findSegmentAndMoveForward() 方法是如何查找和创建一个新的 Segment 的:

@Suppress("NOTHING_TO_INLINE")
internal inline fun <S : Segment<S>> AtomicRef<S>.findSegmentAndMoveForward(
    id: Long,
    startFrom: S,
    noinline createNewSegment: (id: Long, prev: S) -> S
): SegmentOrClosed<S> {
    while (true) {
        val s = startFrom.findSegmentInternal(id, createNewSegment)
        // 检查查询到的 Segment 状态
        if (s.isClosed || moveForward(s.segment)) return s
    }
}

internal fun <S : Segment<S>> S.findSegmentInternal(
    id: Long,
    createNewSegment: (id: Long, prev: S) -> S
): SegmentOrClosed<S> {
    var cur: S = this
    // 当前的 id 如果小于目标 id 或者当前已经 remove了 执行查找
    while (cur.id < id || cur.isRemoved) {
        val next = cur.nextOrIfClosed { return SegmentOrClosed(CLOSED) }
        // 如果 next 为空就表示需要创建新的 Segment,反之继续进入循环判断 id
        if (next != null) { // there is a next node -- move there
            cur = next
            continue
        }
        // 创建一个新的 Segemnt
        val newTail = createNewSegment(cur.id + 1, cur)
        // 将旧的 tail 的 next 指向新的 Segemnt
        if (cur.trySetNext(newTail)) { // successfully added new node -- move there
            if (cur.isRemoved) cur.remove()
            cur = newTail
        }
    }
    return SegmentOrClosed(cur)
}

@Suppress("NOTHING_TO_INLINE", "RedundantNullableReturnType") // Must be inline because it is an AtomicRef extension
internal inline fun <S : Segment<S>> AtomicRef<S>.moveForward(to: S): Boolean = loop { cur ->
    if (cur.id >= to.id) return true
    if (!to.tryIncPointers()) return false
    if (compareAndSet(cur, to)) { // the segment is moved
        if (cur.decPointers()) cur.remove()
        return true
    }
    if (to.decPointers()) to.remove() // undo tryIncPointers
}

获取/创建 Segment 的代码我就不重点介绍了,我添加了一些注释,大家自己看看。

到这里我们知道了如果 lock() 获取锁失败,对应的 Continaution 就会被挂起,然后 Continaution 对象会被添加到 Segment 中。聪明的你应该也猜到了,获取到锁的协程调用 unlock() 方法后就会尝试恢复 Segment 中的一个 Continaution 那么对应的调用 lock() 的协程就可以恢复执行了,是的,确实是这样,我们继续看看后面是怎么释放锁的。

释放锁

override fun unlock(owner: Any?) {
    while (true) {
        // Is this mutex locked?
        // 检查锁状态
        check(isLocked) { "This mutex is not locked" }
        // Read the owner, waiting until it is set in a spin-loop if required.
        // 检查 owner 状态
        val curOwner = this.owner.value
        if (curOwner === NO_OWNER) continue // <-- ATTENTION, BLOCKING PART HERE
        // Check the owner.
        check(curOwner === owner || owner == null) { "This mutex is locked by $curOwner, but $owner is expected" }
        // Try to clean the owner first. We need to use CAS here to synchronize with concurrent `unlock(..)`-s.
        // 将 owner 状态修改成 NO_OWNER
        if (!this.owner.compareAndSet(curOwner, NO_OWNER)) continue
        // Release the semaphore permit at the end.
        // 释放锁操作
        release()
        return
    }
}

上面代码比较简单,检查锁状态;检查 owner 状态;将 owner 设置为 NO_OWNER;如果设置成功调用 release() 方法执行释放操作。上面的一些 CAS 操作如果失败,然后会再重试。

我们再看看 SemaphoreImpl#release() 方法的实现:

override fun release() {
    while (true) {
        // 获取 _availablePermits 后,再把它的值加 1
        val p = _availablePermits.getAndIncrement()
        if (p >= permits) {
            // Revert the number of available permits
            // back to the correct one and fail with error.
            coerceAvailablePermitsAtMaximum()
            error("The number of released permits cannot be greater than $permits")
        }
        // 如果 p 大于等于 0,表示没有 Continuation 在等待锁,直接返回
        if (p >= 0) return
        // 尝试从等待队列中恢复一个 Continuation
        if (tryResumeNextFromQueue()) return
    }
}

简单解释一下上面代码:

  1. 获取 _availablePermits 的值,并把它的值加 1.
  2. 如果上次 _availablePermits 大于等于 0 就表示没有 Continuation 在等待锁,反之就是有等待锁.
  3. 如果有 Continuation 在等待锁,通过 tryResumeNextFromQueue() 方法尝试从等待队列中获取等待最久的一个 Continuation 来获取锁,并恢复它。

我们继续看看 tryResumeNextFromQueue() 方法的实现:

private fun tryResumeNextFromQueue(): Boolean {
    val curHead = this.head.value
    // 从 deqIdx 中获取基础 id
    val deqIdx = deqIdx.getAndIncrement()
    
    // 计算 Segment 的 id
    val id = deqIdx / SEGMENT_SIZE
    val createNewSegment = ::createSegment
    
    // 和插入队列一样,先去查找一个 Segment,这里不同的是从 head 开始查找
    val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
        createNewSegment = createNewSegment).segment // cannot be closed
    segment.cleanPrev()
    // 判断查找到的 Segemnt 的 id 和目标的 id 是否对应,如果不对应就继续查找
    if (segment.id > id) return false
    // 获取对应的 Continuation 在 Segemnt 中对应的位置
    val i = (deqIdx % SEGMENT_SIZE).toInt()
    // 获取到对应的 Continuation
    val cellState = segment.getAndSet(i, PERMIT) // set PERMIT and retrieve the prev cell state
    when {
        cellState === null -> {
            // Acquire has not touched this cell yet, wait until it comes for a bounded time
            // The cell state can only transition from PERMIT to TAKEN by addAcquireToQueue
            repeat(MAX_SPIN_CYCLES) {
                if (segment.get(i) === TAKEN) return true
            }
            // Try to break the slot in order not to wait
            return !segment.cas(i, PERMIT, BROKEN)
        }
        cellState === CANCELLED -> return false // the acquirer has already been cancelled
        
        // 恢复对应的 Continuation
        else -> return cellState.tryResumeAcquire()
    }
}

和插入队列时的代码类似,通过 deqIdx 来生成对应的目标的 SegmentidContinuationSegment 中的 index(这里注意插入的时候是用的 enqIdx,他们都是每次获取都会加 1)。查询 Semgent 同样是使用的 findSegmentAndMoveForward() 方法,不同的是出队列是从 head 开始查询。查询到对应的 Segment 后,会通过它的 getAndSet() 来获取一个 Continuation (正常情况下是一个 Continuation),然后调用它的 tryResumeAcquire() 方法,我们再来看看它的实现:

private fun Any.tryResumeAcquire(): Boolean = when(this) {
    is CancellableContinuation<*> -> {
        this as CancellableContinuation<Unit>
        // 尝试 Resume
        val token = tryResume(Unit, null, onCancellationRelease)
        if (token != null) {
            // 可以 Resume,通过 completeResume() 方法确认执行下发 Resume
            completeResume(token)
            true
        } else false
    }
    is SelectInstance<*> -> {
        trySelect(this@SemaphoreImpl, Unit)
    }
    else -> error("unexpected: $this")
}

我们的 Continuation 就是一个 CancellableContinuation,首先会通过 tryResume() 方法去尝试恢复协程,如果返回的 token 不为空,就表示当前的协程可以恢复,然后通过 completeResume() 方法确认执行协程恢复。

我们看看 CancellableContinuationWithOwner#tryResume() 方法的实现:

        override fun tryResume(value: Unit, idempotent: Any?, onCancellation: ((cause: Throwable) -> Unit)?): Any? {
            // 校验 owner 状态是 NO_OWNER
            assert { this@MutexImpl.owner.value === NO_OWNER }
            // 执行被代理的 Continaution 的 tryResume 方法
            val token = cont.tryResume(value, idempotent) {
            
                // 这个 Lambda 会在 Dispatcher 把任务取消时才执行,也就是表示 resume 失败了。
                assert { this@MutexImpl.owner.value.let { it === NO_OWNER ||it === owner } }
                this@MutexImpl.owner.value = owner
                // 重新解锁
                unlock(owner)
            }
            
            // token 不为空表示 tryResume 成功
            if (token != null) {
                assert { this@MutexImpl.owner.value === NO_OWNER }
                // 将 owner 修改为当前 Continaution 的 owner.
                this@MutexImpl.owner.value = owner
            }
            return token
        }

上面代码很简单了,就是调用被代理 ContinautiontryResume() 方法,如果返回值不为空就表示协程恢复成功(如果 Dispatcher 将该 resume 任务取消就会解锁),协程恢复成功就会将它对应的 owner 添加到 MutexImpl 中去,表示由该 owner 占有当前锁。

最后

到这里你理解了 Mutex 的工作方式了吗?如果还有困惑,可以再看一遍,有问题也可以在评论区中指出。MutexReentrantLocksynchronized 的最大区别是它是非重入锁,而且它也不以线程作为锁的拥有者,它是专门为协程设计它获取锁的方式是 lock() 方法,而且它是一个 suspend 方法,不过释放锁的 unlock() 方法它不是一个 suspend 方法。ReentrantLocksynchronized 都有可能会使用 monitor 锁,但是 Mutext 是不会使用的,它通过 CAS 自旋的方式来修改各种状态,这样做到线程安全。

上一篇下一篇

猜你喜欢

热点阅读