Java - CyclicBarrier 源码阅读笔记

in TCEH Java

简介

   CyclicBarrier是同步辅助工具类。CyclicBarrier就像是一个阀门,实现指定数量的线程全部执行到等待点或超时或其中有一个线程被中断才会继续往下执行。支持循环使用。

示例

@Log4j2
public class CyclicBarrierTest {
    public static void main(String[] args) throws Exception {
        int threadNum = 10;
        CyclicBarrier cyclicBarrier = new CyclicBarrier(threadNum, () -> log.info("到达栏栅点!"));
        for (int i = 0; i < threadNum - 1; i++) {
            new Thread(() -> {
                try {
                    log.info("子线程等待");
                    cyclicBarrier.await();
                    log.info("子线程执行完毕!");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                } catch (BrokenBarrierException e) {
                    e.printStackTrace();
                }
            }).start();
        }
        log.info("主线程等待!");
        cyclicBarrier.await();
        log.info("主线程执行完毕!");
    }
}

CyclicBarrier 类

public class CyclicBarrier {
    // 每一代表示一个栏栅
    private static class Generation {
        // 是否损坏
        boolean broken = false;
    }

    /** 使用可重入锁确保并发线程安全 */
    private final ReentrantLock lock = new ReentrantLock();
    private final Condition trip = lock.newCondition();
    /** 栏栅线程数 */
    private final int parties;
    /* 到达栏栅点运行线程 */
    private final Runnable barrierCommand;
    /** 代数 */
    private Generation generation = new Generation();
    /** 剩余栏栅线程数 */
    private int count;
    ...
}

构造方法

    public CyclicBarrier(int parties) {
        this(parties, null);
    }    
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

await() 方法

    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

dowait() 方法

    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock();// 加锁
        try {
            final Generation g = generation;
            // 损坏的话抛异常
            if (g.broken)
                throw new BrokenBarrierException();
            // 线程中断
            if (Thread.interrupted()) {
                // 中断当前代
                breakBarrier();
                throw new InterruptedException();
            }

            int index = --count;
            if (index == 0) {  // 到达栏栅点
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    // 如果传递了栏栅点线程,则运行指定线程
                    if (command != null)
                        command.run();
                    ranAction = true;
                    // 一个栏栅完成,重置各项参数。开始下一个栏栅
                    nextGeneration();
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier();
                }
            }

            // 循环直到损坏、中断或超时
            for (;;) {
                try {
                    // 未设置超时,直接等待
                    if (!timed)
                        trip.await();
                    // 设置超时等待
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    // 代数一致且未损坏,则中断当前代
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // 延时中断
                        Thread.currentThread().interrupt();
                    }
                }
                // 损坏
                if (g.broken)
                    throw new BrokenBarrierException();
                // 代数不一致返回索引
                if (g != generation)
                    return index;
                // 超时
                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();// 解锁
        }
    }

breakBarrier() 方法

    private void breakBarrier() {
        generation.broken = true;// 标记损坏
        count = parties;// 重置计数
        trip.signalAll();// 放行等待的线程
    }

nextGeneration() 方法

    private void nextGeneration() {
        // 放行等待的线程
        trip.signalAll();
        // 设置下一代
        count = parties;
        generation = new Generation();
    }

reset() 方法

如遇中断超时等异常,可重置

    public void reset() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            breakBarrier();   // 打破当前代
            nextGeneration(); // 开始新代
        } finally {
            lock.unlock();
        }
    }