Java - Spliterator 接口

2019-11-04  本文已影响0人  sschrodinger

Java - Spliterator 接口

sschrodinger

2019/11/01


Spliterator 接口简介


在 Java 中,对于一个集合的遍历前前后后经历了三种方式,第一种方式是使用 Enumeration 接口,第二种方式是使用 Iterator 接口,第三种方式就是使用 Spliterator 接口。

Enumeration 接口提供了两个函数用于实现遍历,如下:

public interface Enumeration<E> {

    boolean hasMoreElements();
    E nextElement();
    
}

不过,现在 Enumration 接口基本上都被 Iterator 接口取代,两个接口提供的方法基本相似,但是 Iterator 多提供了一个 remove 方法(在 Java 1.8 之前),如下:

public interface Iterator<E> {

    boolean hasNext();

    E next();

    default void remove() {
        throw new UnsupportedOperationException("remove");
    }
    
    // 该函数暂时略过
    default void forEachRemaining(Consumer<? super E> action) {
        Objects.requireNonNull(action);
        while (hasNext())
            action.accept(next());
    }
}

为什么加入 remove 方法,因为这样可以实现快速失败机制,保证一定的线程安全性(保证只有迭代器持有的线程能够对集合做结构上的修改,避免多个线程同时修改集合结构)。

forEachRemaining 可以使用函数式编程的写法,大大减少编写代码的长度,比如说遍历打印元素数组,传统写法如下:

List<String> list = new LinkedList<>();
list.add("a");
list.add("b");
list.add("c");
Iterator<String> iterator = list.iterator();
while (iterator.hasNext()) {
    System.out.println(iterator.next());
}

现在我们可以使用 forEachRemaining 这样写:

List<String> list = new LinkedList<>();
list.add("a");
list.add("b");
list.add("c");
Iterator<String> iterator = list.iterator();
iterator.forEachRemaining(System.out::println);

当然,对于集合的遍历,也可以使用更加简洁的写法,如下:

List<String> list = new LinkedList<>();
list.add("a");
list.add("b");
list.add("c");
list.forEach(System.out::println);

这两种传统的遍历方式,最大的问题,就是使用的顺序遍历的方式,顺序遍历的方式,在多线程的情况下,并不能够很好的利用多线程的性能,所以, Java 库的开发者构建了 Spliterator 接口,不使用顺序遍历的方式,用来加快遍历处理的速度。


Spliterator 接口实现


首先来看 Spiterator 接口定义,如下:

public interface Spliterator<T> {

    boolean tryAdvance(Consumer<? super T> action);

    default void forEachRemaining(Consumer<? super T> action) {
        do { } while (tryAdvance(action));
    }

    Spliterator<T> trySplit();

    long estimateSize();

    default long getExactSizeIfKnown() {
        return (characteristics() & SIZED) == 0 ? -1L : estimateSize();
    }

    int characteristics();

    default boolean hasCharacteristics(int characteristics) {
        return (characteristics() & characteristics) == characteristics;
    }

    default Comparator<? super T> getComparator() {
        throw new IllegalStateException();
    }

    public static final int ORDERED    = 0x00000010;
    public static final int DISTINCT   = 0x00000001;
    public static final int SORTED     = 0x00000004;
    public static final int SIZED      = 0x00000040;
    public static final int NONNULL    = 0x00000100;
    public static final int IMMUTABLE  = 0x00000400;
    public static final int CONCURRENT = 0x00001000;
    public static final int SUBSIZED = 0x00004000;

    // ... 
}

提供的函数非常的多,我们对照官方文档查看。

可以用 Spitetor 接口处理的对象不仅仅限于集合(Collection),还包括 arrayIO channelgenerator function

对于 Spiterator 提供的基本功能,最重要的当然是遍历,对于 Spliterator 接口来说,遍历方法有两个,第一个是 tryAdvance,一个是 forEachRemaining

tryAdvance 类似于 Iteratornext() 方法,但是他和 next() 方法不同的是,他的返回值不是下一个元素,而是一个布尔值,代表的是如果有剩下的元素,那么使用函数参数对应的方法处理,并返回 true,否则返回 false

我们考虑一个 Spliterator 接口的实现类 MySpliterator,那么使用 tryAdvance 打印其中的一个元素就可以写成这样:

public void func(MySpliterator spliterator) {
    // 使用 tryAdvance;
    if (spliterator.tryAdvance(System.out::println)) {
        System.out.println(打印一个元素成功成功);
    } else {
        System.out.println(打印一个元素成功失败,没有更多的元素);
    }
}

forEachRemaining 就是遍历所有的元素,从他的实现上来看,就是对 tryAdvance 的一个再次封装。

如果仅仅是对元素进行遍历,那么就又会变成一个顺序的过程,并不能使用在多线程中,所以 Spliterator 接口还提供了函数 trySplit,将部分元素分开来形成一个新的 Spliterator,这个新的 Spliterator 就可以使用其他的一些线程并发的进行遍历。

调用 trySplit 函数,可以将集合元素分成不相交且并集为全集的两部分,一部分在新的 Spliterator 中,一部分仍然在该 Spliterator 中,那么并行处理的一般框架如下:

public <T> void func(Spliterator spliterator) {
    Spliterator newSpliterator = spliterator.trySplit();
    if (newSpliterator != null) {
        // 另一个线程处理一部分元素
        new Thread(new MyRunnable<T>(newSpliterator, System.out::println)).start();
    }
    // 该线程处理一部分元素
    spliterator.forEachRemaining(System.out::println);
}

private static class MyRunnable<T>  implements Runnable {
    private Spliterator spliterator;
    private Consumer<? super T> action;
    MyRunnable(Spliterator spliterator, Consumer<? super T> action) {
        this.spliterator = spliterator;
        this.action = action;
    }

    @Override
    public void run() {
        spliterator.forEachRemaining(action);
    }
}    

除了以上最核心的两个功能,Spliterator 接口还提供了 characteristics 属性,所谓的 characteristics 属性,代表迭代的八种属性,使用八个常量代替,如下(解释见注释):

// 代表迭代的元素是有序的( trySplit 一定返回前缀元素,forEachRemainning 也一定按照前缀迭代)
public static final int ORDERED    = 0x00000010;

// 代表迭代的元素没有重复元素,由 Set 产生的 Spliterator 都会有这个属性
public static final int DISTINCT   = 0x00000001;

// 代表迭代的元素是有序的,必须根据 Comparator 进行排序,由 NavigableSet 和 SortedSet 产生的一定有该属性
public static final int SORTED     = 0x00000004;

// 代表在遍历或拆分之前可以从 valuateSize() 返回有限大小的 spliterator。
public static final int SIZED      = 0x00000040;

// 表示没有空元素
public static final int NONNULL    = 0x00000100;

// 表示不能够修改原始结构
public static final int IMMUTABLE  = 0x00000400;

// 表示在多线程情况下可能可以对其进行修改
public static final int CONCURRENT = 0x00001000;

// 代表 trySplit 返回的 Spliterator 也是有 sized 属性和 subsized 属性的
public static final int SUBSIZED = 0x00004000;

这些属性可能可以加快遍历的速度(比如说,如果是无序的,就可以多个线程同时遍历)。

举一个 Java 文档的 Spliterator 实现,如下:

static class IteratorSpliterator<T> implements Spliterator<T> {
        static final int BATCH_UNIT = 1 << 10;  // batch array size increment
        static final int MAX_BATCH = 1 << 25;  // max batch array size;
        private final Collection<? extends T> collection; // null OK
        private Iterator<? extends T> it;
        private final int characteristics;
        private long est;             // size estimate
        private int batch;            // batch size for splits

        public IteratorSpliterator(Collection<? extends T> collection, int characteristics) {
            this.collection = collection;
            this.it = null;
            this.characteristics = (characteristics & Spliterator.CONCURRENT) == 0
                                   ? characteristics | Spliterator.SIZED | Spliterator.SUBSIZED
                                   : characteristics;
        }

        public IteratorSpliterator(Iterator<? extends T> iterator, long size, int characteristics) {
            this.collection = null;
            this.it = iterator;
            this.est = size;
            this.characteristics = (characteristics & Spliterator.CONCURRENT) == 0
                                   ? characteristics | Spliterator.SIZED | Spliterator.SUBSIZED
                                   : characteristics;
        }

        public IteratorSpliterator(Iterator<? extends T> iterator, int characteristics) {
            this.collection = null;
            this.it = iterator;
            this.est = Long.MAX_VALUE;
            this.characteristics = characteristics & ~(Spliterator.SIZED | Spliterator.SUBSIZED);
        }

        @Override
        public Spliterator<T> trySplit() {

            Iterator<? extends T> i;
            long s;
            if ((i = it) == null) {
                i = it = collection.iterator();
                s = est = (long) collection.size();
            }
            else
                s = est;
            if (s > 1 && i.hasNext()) {
                int n = batch + BATCH_UNIT;
                if (n > s)
                    n = (int) s;
                if (n > MAX_BATCH)
                    n = MAX_BATCH;
                Object[] a = new Object[n];
                int j = 0;
                do { a[j] = i.next(); } while (++j < n && i.hasNext());
                batch = j;
                if (est != Long.MAX_VALUE)
                    est -= j;
                return new ArraySpliterator<>(a, 0, j, characteristics);
            }
            return null;
        }

        @Override
        public void forEachRemaining(Consumer<? super T> action) {
            if (action == null) throw new NullPointerException();
            Iterator<? extends T> i;
            if ((i = it) == null) {
                i = it = collection.iterator();
                est = (long)collection.size();
            }
            i.forEachRemaining(action);
        }

        @Override
        public boolean tryAdvance(Consumer<? super T> action) {
            if (action == null) throw new NullPointerException();
            if (it == null) {
                it = collection.iterator();
                est = (long) collection.size();
            }
            if (it.hasNext()) {
                action.accept(it.next());
                return true;
            }
            return false;
        }

        @Override
        public long estimateSize() {
            if (it == null) {
                it = collection.iterator();
                return est = (long)collection.size();
            }
            return est;
        }

        @Override
        public int characteristics() { return characteristics; }

        @Override
        public Comparator<? super T> getComparator() {
            if (hasCharacteristics(Spliterator.SORTED))
                return null;
            throw new IllegalStateException();
        }
    }

BaTCH_UNIT 代表 split 的大小,我们看 trySplit,如下:

public Spliterator<T> trySplit() {

    Iterator<? extends T> i;
    long s;
    if ((i = it) == null) {
        i = it = collection.iterator();
        s = est = (long) collection.size();
    }
    else
        s = est;
    if (s > 1 && i.hasNext()) {
        int n = batch + BATCH_UNIT;
        if (n > s)
            n = (int) s;
        if (n > MAX_BATCH)
            n = MAX_BATCH;
        Object[] a = new Object[n];
        int j = 0;
        do { a[j] = i.next(); } while (++j < n && i.hasNext());
        batch = j;
        if (est != Long.MAX_VALUE)
            est -= j;
        return new ArraySpliterator<>(a, 0, j, characteristics);
    }
    return null;
}

实际上就是通过迭代将部分的数据分到一个新的数组,然后返回。不过,因为集合只能使用顺序的方式遍历,所以效率不高,这时就会有 ArraySpliterator。他的 trySplit 方法如下:

public Spliterator<T> trySplit() {
    int lo = index, mid = (lo + fence) >>> 1;
    return (lo >= mid)
           ? null
           : new ArraySpliterator<>(array, lo, index = mid, characteristics);
}

这就是用了随机访问的方式真正实现了并行处理。


Spliterator 与并行 stream


具体 stream 知识参见。主要就是利用 ForkJoin 框架实现的并行。

stream 处理并行的函数如下:

// package java.util.stream.TerminalOp
default <P_IN> R evaluateParallel(PipelineHelper<E_IN> helper,
                              Spliterator<P_IN> spliterator) {
    if (Tripwire.ENABLED)
        Tripwire.trip(getClass(), "{0} triggering TerminalOp.evaluateParallel serial default");
    return evaluateSequential(helper, spliterator);
}

我们看其子类(reduceOp),如下:

@Override
public <P_IN> R evaluateParallel(PipelineHelper<T> helper,
                                 Spliterator<P_IN> spliterator) {
    return new ReduceTask<>(this, helper, spliterator).invoke().get();
}

ReduceTask 是一个 ForkJoinTask 的实现类,继承关系如下:

|--------------|   |----------------|   |--------------------|   |----------------|
|  ReduceTask  |-->|  AbstractTask  |-->|  CountedCompleter  |-->|  ForkJoinTask  |
|--------------|   |----------------|   |--------------------|   |----------------|

ForkJoinTask 的执行由 invoke 开始,如下:

public final V invoke() {
    int s;
    // important function : doInvoke
    if ((s = doInvoke() & DONE_MASK) != NORMAL)
        reportException(s);
    return getRawResult();
}
private int doInvoke() {
    int s; Thread t; ForkJoinWorkerThread wt;
    // important function : doExec
    return (s = doExec()) < 0 ? s :
    ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
        (wt = (ForkJoinWorkerThread)t).pool.
        awaitJoin(wt.workQueue, this, 0L) :
        externalAwaitDone();
}
final int doExec() {
    int s; boolean completed;
    if ((s = status) >= 0) {
        try {
            // important function : exec
            completed = exec();
        } catch (Throwable rex) {
            return setExceptionalCompletion(rex);
        }
        if (completed)
            s = setCompletion(NORMAL);
    }
    return s;
}

protected abstract boolean exec();

主要就是在 exec 函数做业务逻辑处理,包括计算,任务拆分等。

exec 函数由 CountedCompleter 类重写,如下:

protected final boolean exec() {
    compute();
    return false;
}

public abstract void compute();

同理,compute 函数由 AbstractTask 类重写,如下:

@Override
public void compute() {
    Spliterator<P_IN> rs = spliterator, ls; // right, left spliterators
    long sizeEstimate = rs.estimateSize();
    long sizeThreshold = getTargetSize(sizeEstimate);
    boolean forkRight = false;
    @SuppressWarnings("unchecked") K task = (K) this;
    while (sizeEstimate > sizeThreshold && (ls = rs.trySplit()) != null) {
        K leftChild, rightChild, taskToFork;
        task.leftChild  = leftChild = task.makeChild(ls);
        task.rightChild = rightChild = task.makeChild(rs);
        task.setPendingCount(1);
        if (forkRight) {
            forkRight = false;
            rs = ls;
            task = leftChild;
            taskToFork = rightChild;
        }
        else {
            forkRight = true;
            task = rightChild;
            taskToFork = leftChild;
        }
        taskToFork.fork();
        sizeEstimate = rs.estimateSize();
    }
    task.setLocalResult(task.doLeaf());
    task.tryComplete();
}

注意到 while 循环条件的 ls = rs.trySplit() 语句,实际上就是新开一个 spliterator,此时,rs 为当前的 spliteratorls 为新生成的 spliterator

如下的语句主要是生成一个任务树,便于任务的合并。

task.leftChild  = leftChild = task.makeChild(ls);
task.rightChild = rightChild = task.makeChild(rs);

如下的语句主要是交替将左子树和右子树的任务 fork 到线程池中。

if (forkRight) {
    forkRight = false;
    rs = ls;
    task = leftChild;
    taskToFork = rightChild;
} else {
    forkRight = true;
    task = rightChild;
    taskToFork = leftChild;
}
taskToFork.fork();

sizeEstimate 代表剩下的任务大小,上面交替 fork 的步骤要等到 sizeEstimate 小于阈值时才完成。

doLeaf 函数就是处理的函数,如下:

// package java.util.stream.ReduceOps.
@Override
protected S doLeaf() {
    return helper.wrapAndCopyInto(op.makeSink(), spliterator);
}

这实际上就是开启了了 stream 的流水线操作。如下:

@Override
final <P_IN, S extends Sink<E_OUT>> S wrapAndCopyInto(S sink, Spliterator<P_IN> spliterator) {
    copyInto(wrapSink(Objects.requireNonNull(sink)), spliterator);
    return sink;
}

@Override
final <P_IN> void copyInto(Sink<P_IN> wrappedSink, Spliterator<P_IN> spliterator) {
    Objects.requireNonNull(wrappedSink);
    if (!StreamOpFlag.SHORT_CIRCUIT.isKnown(getStreamAndOpFlags())) {
        wrappedSink.begin(spliterator.getExactSizeIfKnown());
        // 这实际上就是在遍历每一个元素
        spliterator.forEachRemaining(wrappedSink);
        wrappedSink.end();
    } else {
        copyIntoWithCancel(wrappedSink, spliterator);
    }
}

之后,就是对结果的组合,如下:

task.tryComplete();

public final void tryComplete() {
    CountedCompleter<?> a = this, s = a;
    for (int c;;) {
        if ((c = a.pending) == 0) {
            a.onCompletion(s);
            if ((a = (s = a).completer) == null) {
                s.quietlyComplete();
                return;
            }
        }
        else if (U.compareAndSwapInt(a, PENDING, c, c - 1))
            return;
    }
}

@Override
public void onCompletion(CountedCompleter<?> caller) {
    if (!isLeaf()) {
        S leftResult = leftChild.getLocalResult();
        leftResult.combine(rightChild.getLocalResult());
        setLocalResult(leftResult);
    }
    // GC spliterator, left and right child
    super.onCompletion(caller);
}

小结


Spliterator 是 Java 提供的方便并行迭代的方式,相对于 Iterator,可以更加高效的处理集合。

上一篇下一篇

猜你喜欢

热点阅读