带着BAT大厂的面试问题去理解

  • 什么是CountDownLatch?
  • CountDownLatch底层实现原理?
  • CountDownLatch一次可以唤醒几个任务? 多个
  • CountDownLatch有哪些主要方法? await(),countDown()
  • CountDownLatch适用于什么场景?
  • 写道题:实现一个容器,提供两个方法,add,size 写两个线程,线程1添加10个元素到容器中,线程2实现监控元素的个数,当个数到5个时,线程2给出提示并结束? 使用CountDownLatch 代替wait notify 好处。

CountDownLatch介绍

从源码可知,其底层是由AQS提供支持,所以其数据结构可以参考AQS的数据结构,而AQS的数据结构核心就是两个虚拟队列: 同步队列sync queue 和条件队列condition queue,不同的条件会有不同的条件队列。

CountDownLatch典型的用法是将一个程序分为n个互相独立的可解决任务,并创建值为n的CountDownLatch。
等待问题被解决的任务调用这个锁存器的await(),将自己拦住,直至锁存器计数结束。
当每一个任务完成时,都会在这个锁存器上调用countDown()

源码分析

类的继承关系

CountDownLatch没有显示继承哪个父类或者实现哪个父接口, 它底层是AQS提供支持,是通过内部类Sync来实现的。

1
2
3
public class CountDownLatch {
private static final class Sync extends AbstractQueuedSynchronizer {}
}

内部类

CountDownLatch类存在一个内部类Sync,继承自AbstractQueuedSynchronizer。
对CountDownLatch方法的调用会转发到对Sync或AQS的方法的调用,所以,AQS对CountDownLatch提供支持。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
private static final class Sync extends AbstractQueuedSynchronizer {
// 版本号
private static final long serialVersionUID = 4982264981922014374L;

// 构造器
Sync(int count) {
setState(count);
}

// 返回当前计数
int getCount() {
return getState();
}

// 试图在共享模式下获取对象状态
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

// 试图设置状态来反映共享模式下的一个释放
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// 无限循环
for (;;) {
// 获取状态
int c = getState();
if (c == 0) // 没有被线程占有
return false;
// 下一个状态
int nextc = c-1;
if (compareAndSetState(c, nextc)) // 比较并且设置成功
return nextc == 0;
}
}
}

类的属性和构造函数

1
2
3
4
5
6
7
8
9
10
public class CountDownLatch {
// 同步队列
private final Sync sync;

public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
// 初始化状态数
this.sync = new Sync(count);
}
}
  • 可以看到CountDownLatch类的内部只有一个Sync类型的属性。
  • 该构造函数可以构造一个给定计数初始化的CountDownLatch,并且构造函数内完成了sync的初始化,并设置了状态数。

核心函数 await()

此函数将会使当前线程在锁存器倒计数至零之前一直等待,除非线程被中断。

1
2
3
4
5
6
7
8
9
public void await() throws InterruptedException {
// 转发到sync对象上
sync.acquireSharedInterruptibly(1);
}

public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

对CountDownLatch对象的await的调用会转发为对Sync的acquireSharedInterruptibly(从AQS继承的方法)方法的调用。

CountDownLatch的await调用链:

  • AQS#acquireSharedInterruptibly

    1
    2
    3
    4
    5
    6
    public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    if (Thread.interrupted())
    throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
    doAcquireSharedInterruptibly(arg);
    }

    acquireSharedInterruptibly又调用了CountDownLatch的内部类Sync的tryAcquireSharedAQS的doAcquireSharedInterruptibly函数

  • CountDownLatch#Sync#tryAcquireShared

    1
    2
    3
    protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
    }

    该函数只是简单的判断AQS的state是否为0,为0则返回1,不为0则返回-1。

  • AQS#doAcquireSharedInterruptibly

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
    // 添加节点至等待队列
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
    for (;;) { // 无限循环
    // 获取node的前驱节点
    final Node p = node.predecessor();
    if (p == head) { // 前驱节点为头节点
    // 试图在共享模式下获取对象状态
    int r = tryAcquireShared(arg);
    if (r >= 0) { // 获取成功
    // 设置头节点并进行繁殖
    setHeadAndPropagate(node, r);
    // 设置节点next域
    p.next = null; // help GC
    failed = false;
    return;
    }
    }
    if (shouldParkAfterFailedAcquire(p, node) &&
    parkAndCheckInterrupt()) // 在获取失败后是否需要禁止线程并且进行中断检查
    // 抛出异常
    throw new InterruptedException();
    }
    } finally {
    if (failed)
    cancelAcquire(node);
    }
    }

    在AQS的doAcquireSharedInterruptibly中可能会再次调用CountDownLatch的内部类Sync的tryAcquireShared方法和AQS的setHeadAndPropagate方法。

核心函数 countDown()

此函数将递减锁存器的计数,如果计数到达零,则释放所有等待的线程。

1
2
3
public void countDown() {
sync.releaseShared(1);
}

对countDown的调用转换为对Sync对象的releaseShared(从AQS继承而来)方法的调用。

CountDownLatch的countDown调用链:

深入理解

实现一个容器,提供两个方法,add,size
写两个线程,线程1添加10个元素到容器中,线程2实现监控元素的个数,当个数到5个时,线程2给出提示并结束.

使用wait和notify实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import java.util.ArrayList;
import java.util.List;

/**
* 必须先让t2先进行启动 使用wait 和 notify 进行相互通讯,wait会释放锁,notify不会释放锁
*/
public class T2 {

volatile List list = new ArrayList();

public void add (int i){
list.add(i);
}

public int getSize(){
return list.size();
}

public static void main(String[] args) {

T2 t2 = new T2();

Object lock = new Object();

new Thread(() -> {
synchronized(lock){
System.out.println("t2 启动");
if(t2.getSize() != 5){
try {
/**会释放锁*/
lock.wait();
System.out.println("t2 结束");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
lock.notify();
}
},"t2").start();

new Thread(() -> {
synchronized (lock){
System.out.println("t1 启动");
for (int i=0;i<9;i++){
t2.add(i);
System.out.println("add"+i);
if(t2.getSize() == 5){
/**不会释放锁*/
lock.notify();
try {
lock.wait();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}
}).start();
}
}

CountDownLatch实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;

/**
* 使用CountDownLatch 代替wait notify 好处是通讯方式简单,不涉及锁定。 Count 值为0时当前线程继续执行
*/
public class T3 {

volatile List list = new ArrayList();

public void add(int i){
list.add(i);
}

public int getSize(){
return list.size();
}


public static void main(String[] args) {
T3 t = new T3();
CountDownLatch countDownLatch = new CountDownLatch(1);

new Thread(() -> {
System.out.println("t2 start");
if(t.getSize() != 5){
try {
countDownLatch.await();
System.out.println("t2 end");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
},"t2").start();

new Thread(()->{
System.out.println("t1 start");
for (int i = 0;i<9;i++){
t.add(i);
System.out.println("add"+ i);
if(t.getSize() == 5){
System.out.println("countdown is open");
countDownLatch.countDown();
}
}
System.out.println("t1 end");
},"t1").start();
}

}