Evolution of Algorithm Service Call Optimization

2021-06-30  EmilioWong

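Our project needed to call an NLU algorithm service concurrently from multiple threads. As we kept squeezing out more performance and ran into some unusual problems, the way we make these calls was optimized several times. This post records how it evolved.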

Lombok configuration

Lombok is configured with lombok.fieldDefaults.defaultPrivate=true, so the class fields in the code below are declared without the private modifier.
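As a reading aid, the sketch below spells out in plain Java what a field declaration turns into under this setting when the class is also annotated with @FieldDefaults(makeFinal = true), as most classes below are: int batchSize; becomes private final int batchSize;. Classes without that annotation, such as BatchController in Stage 4, should still get private, but not final, fields. BatchSettings is a hypothetical class, and this is our reading of Lombok's documented behavior rather than generated output.

```java
// Lombok source this hypothetical class corresponds to
// (with lombok.fieldDefaults.defaultPrivate=true in lombok.config):
//
//   @FieldDefaults(makeFinal = true)
//   class BatchSettings {
//       int batchSize;
//       long batchWaitMs;
//       ...
//   }
//
// Plain-Java equivalent: each field without an explicit access modifier becomes private,
// and makeFinal = true additionally makes it final, so it must be set in the constructor.
class BatchSettings {
    private final int batchSize;
    private final long batchWaitMs;

    BatchSettings(int batchSize, long batchWaitMs) {
        this.batchSize = batchSize;
        this.batchWaitMs = batchWaitMs;
    }

    int getBatchSize() {
        return batchSize;
    }

    long getBatchWaitMs() {
        return batchWaitMs;
    }
}
```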

How it evolved

Project background

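The project works with chat content: it provides detection on both real-time and offline chat messages, which requires calling the NLU algorithm service provided by the algorithm team. To speed up detection, chats are processed by multiple threads; roughly speaking, one thread handles one chat.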

Stage 1

Each thread calls the NLU service synchronously, roughly like this:

NluResult result = nluService.analyze(chatpeer);

Problem

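The algorithm team runs the computation on both CPU and GPU. On the GPU, multithreaded calls bring no speedup; worse, the extra network I/O generated by many threads can make the calls take even longer.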

Stage 2

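To fix the Stage 1 problem, the multithreaded calls had to become batched calls from a single thread, and the change must not disturb the structure of the existing code.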

Approach

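Wrap the call in a batchService that accepts a chatpeer and returns a Future. Internally, batchService keeps a queue: the worker threads keep adding elements to it, while a separate thread keeps dequeuing until either a full batch has been collected or a timeout expires, and then sends that batch to the NLU batch-processing service. The original code still calls from multiple threads, but each thread now calls get() and blocks until its batch has been processed.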
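The core of this idea, waiting for the first element and then collecting more until either the batch is full or the wait window closes, fits in one method. The sketch below uses only java.util.concurrent; BatchCollector, collectBatch, batchSize and waitMs are illustrative names, not the project's API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;

final class BatchCollector {

    private BatchCollector() {
    }

    /**
     * Blocks until at least one element is available, then keeps collecting until the
     * batch holds batchSize elements or waitMs has elapsed since the first one arrived.
     */
    static <T> List<T> collectBatch(BlockingQueue<T> queue, int batchSize, long waitMs)
            throws InterruptedException {
        List<T> batch = new ArrayList<>(batchSize);
        batch.add(queue.take()); // wait indefinitely for the first element
        long deadline = System.currentTimeMillis() + waitMs; // the window opens at the first element
        while (batch.size() < batchSize) {
            long remaining = deadline - System.currentTimeMillis();
            // with a non-positive timeout, poll() returns an already-queued element or null at once
            T next = queue.poll(remaining, TimeUnit.MILLISECONDS);
            if (next == null) {
                break; // window closed and nothing else arrived
            }
            batch.add(next);
        }
        return batch;
    }
}
```

Because the window only opens once the first element arrives, an idle queue never produces empty batches, and a lone request waits at most about waitMs before it is sent.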

code

After the change, the call looks roughly like this:

Future<NluResult> resultFuture = nluBatchService.analyze(chatpeer);
NluResult result = resultFuture.get(); // blocks until the batch has been processed

The abstract batchService implementation:

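import com.google.common.base.Preconditions;
import lombok.experimental.FieldDefaults;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
// UpdaterProperties and ThreadUtils are project classes (not shown here)

@Slf4j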
@FieldDefaults(makeFinal = true)
public abstract class AbstractBatchProcessor<Q, R> {

    ConcurrentHashMap<Q, FutureTask<R>> chatpeerIdFutureMap = new ConcurrentHashMap<>();

    Map<Q, R> processResultsMap = new ConcurrentHashMap<>();

    ExecutorService singleThreadPool;

    int batchSize;

    long batchWaitMs;

    BlockingQueue<Q> requestWaitQueue;

    public AbstractBatchProcessor(UpdaterProperties updaterProperties, String singleThreadPoolThreadNamePrefix) {
        this.batchSize = updaterProperties.getUnicornBatchSize();
        this.batchWaitMs = updaterProperties.getUnicornBatchWaitMs();
        this.singleThreadPool = Executors.newSingleThreadExecutor(new CustomizableThreadFactory(singleThreadPoolThreadNamePrefix));
        this.requestWaitQueue = new LinkedBlockingQueue<>(); // unbounded for now
    }

    public Future<R> doProcess(Q request) {
        FutureTask<R> future = new FutureTask<>(() -> {
            R result = processResultsMap.get(request);
            if (result == null) {
                log.warn("batch processing failed, no result available");
            }
            processResultsMap.remove(request);
            chatpeerIdFutureMap.remove(request);
            return result;
        });
        Future<R> theRequestIdValue = chatpeerIdFutureMap.putIfAbsent(request, future);
        Preconditions.checkArgument(theRequestIdValue == null);
        while (!requestWaitQueue.offer(request)) {
            // the queue is unbounded, so in theory this branch is never reached
            log.warn("queue is full, waiting 100 ms");
            ThreadUtils.sleep(100);
        }
        return future;
    }

    public void startup() {
        singleThreadPool.execute(() -> {
            List<Q> batchRequests = new ArrayList<>();
            while (true) {
                try {
                    batchRequests.clear();
                    Q firstRequest = requestWaitQueue.take();
                    batchRequests.add(firstRequest);
                    long currentTimeMillis = System.currentTimeMillis();
                    long maxWait = currentTimeMillis + batchWaitMs;
                    while (batchRequests.size() < batchSize) {
                        long timeout = maxWait - System.currentTimeMillis();
                        Q request = requestWaitQueue.poll(timeout, TimeUnit.MILLISECONDS);
                        if (request != null) {
                            batchRequests.add(request);
                        } else {
                            break;
                        }
                    }
                    List<R> results = batchProcess(batchRequests);
                    if (batchRequests.size() != results.size()) {
                        // results cannot be matched to requests; store nothing (ConcurrentHashMap rejects null values),
                        // each future treats a missing entry as a failed result
                        log.warn("analyzeBatch returned the wrong number of results, result size:{}, request size:{}", results.size(), batchRequests.size());
                    } else {
                        for (int i = 0; i < batchRequests.size(); i++) {
                            Q request = batchRequests.get(i);
                            R result = results.get(i);
                            if (result != null) { // ConcurrentHashMap rejects null values
                                processResultsMap.put(request, result);
                            }
                        }
                    }
                } catch (Throwable t) {
                    log.error("batch processing failed", t);
                } finally {
                    for (Q request : batchRequests) {
                        FutureTask<R> futureTask = chatpeerIdFutureMap.get(request);
                        if (futureTask == null) {
                            log.error("missing futureTask! request: {}", request);
                        } else {
                            futureTask.run();
                        }
                    }
                }
            }
        });
    }

    protected abstract List<R> batchProcess(List<Q> requests);

}
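A design note on the class above: results travel from the batch thread to the callers through two shared maps keyed by the request itself, so requests need sound equals/hashCode, duplicates have to be rejected, and a null result cannot be stored at all, because ConcurrentHashMap rejects null values. One alternative, sketched below as an assumption rather than what the project did, is to queue each request together with its own CompletableFuture and complete those futures directly once the batch call returns. Pending, BatchCompleter, completeBatch and batchFn are hypothetical names; only the completion step is shown, since collecting the batch works as above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

/** Hypothetical queue element: the request plus the future its caller is waiting on. */
final class Pending<Q, R> {
    final Q request;
    final CompletableFuture<R> future = new CompletableFuture<>();

    Pending(Q request) {
        this.request = request;
    }
}

final class BatchCompleter {

    private BatchCompleter() {
    }

    /** Calls batchFn once for the whole batch, then completes every pending future exactly once. */
    static <Q, R> void completeBatch(List<Pending<Q, R>> batch, Function<List<Q>, List<R>> batchFn) {
        List<Q> requests = new ArrayList<>(batch.size());
        for (Pending<Q, R> pending : batch) {
            requests.add(pending.request);
        }
        try {
            List<R> results = batchFn.apply(requests);
            if (results == null || results.size() != batch.size()) {
                // results cannot be matched to requests: fail every caller explicitly
                IllegalStateException mismatch =
                        new IllegalStateException("expected " + batch.size() + " results for the batch");
                batch.forEach(pending -> pending.future.completeExceptionally(mismatch));
                return;
            }
            for (int i = 0; i < batch.size(); i++) {
                // unlike a ConcurrentHashMap entry, a future can carry a null result
                batch.get(i).future.complete(results.get(i));
            }
        } catch (RuntimeException e) {
            // callers' get() throws ExecutionException instead of returning null;
            // futures already completed before the failure keep their values
            batch.forEach(pending -> pending.future.completeExceptionally(e));
        }
    }
}
```

A worker thread would still block on future.get(), but a failed call or a result-count mismatch now surfaces as an ExecutionException instead of a silent null.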

Problem

The GPU was now kept busy, but CPU utilization stayed low.

Stage 3

To fix the Stage 2 problem, the single thread inside batchService can be replaced by several threads, though not too many of them.

Approach

Introduce a producer-consumer model: a single thread produces batches from the queue, and multiple threads consume them.
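What makes this hand-off work is that the queue between producer and consumers is bounded (the Storage class below sizes it to the number of consumer threads): once every consumer is busy and the queue is full, put() blocks the producer, so new requests pile up in the request queue and the next batch tends to fill faster. A minimal runnable sketch of that shape follows; the batch contents, thread counts and the POISON end-of-stream marker are illustrative assumptions, and ArrayBlockingQueue stands in for the capacity-bounded LinkedBlockingQueue used in the article.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

final class ProducerConsumerDemo {

    /** End-of-stream marker, compared by identity so it cannot be mistaken for a real batch. */
    private static final List<Integer> POISON = new ArrayList<>();

    private ProducerConsumerDemo() {
    }

    /** One producer hands `batches` batches of `batchSize` items to `consumers` threads; returns the items consumed. */
    static int run(int consumers, int batches, int batchSize) throws InterruptedException {
        // Bounded hand-off sized to the consumer count: put() blocks the producer once it is full.
        BlockingQueue<List<Integer>> storage = new ArrayBlockingQueue<>(consumers);
        AtomicInteger processed = new AtomicInteger();
        ExecutorService pool = Executors.newFixedThreadPool(consumers);
        for (int c = 0; c < consumers; c++) {
            pool.execute(() -> {
                try {
                    while (true) {
                        List<Integer> batch = storage.take();
                        if (batch == POISON) {
                            return; // each consumer stops at its own marker
                        }
                        processed.addAndGet(batch.size()); // stand-in for the real batchProcess call
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
        }
        // The calling thread plays the single producer.
        for (int b = 0; b < batches; b++) {
            storage.put(Collections.nCopies(batchSize, b));
        }
        for (int c = 0; c < consumers; c++) {
            storage.put(POISON); // queued after every real batch, so all batches are consumed first
        }
        pool.shutdown();
        pool.awaitTermination(10, TimeUnit.SECONDS);
        return processed.get();
    }
}
```

ProducerConsumerDemo.run(3, 20, 8) hands 20 batches of 8 items to 3 consumers and should return 160. The poison markers only give the demo a clean end; the article's long-running service shuts its pools down in destroy() instead.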

code

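import com.google.common.base.Preconditions;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.experimental.FieldDefaults;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
// UpdaterProperties and ThreadUtils are project classes (not shown here)

@Slf4j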
@FieldDefaults(makeFinal = true)
public abstract class AbstractBatchProcessor<Q, R> implements DisposableBean {

    ConcurrentHashMap<Q, FutureTask<R>> chatpeerIdFutureMap = new ConcurrentHashMap<>();

    Map<Q, R> processResultsMap = new ConcurrentHashMap<>();

    ExecutorService producerThreadPool;

    ExecutorService consumerThreadPool;

    int batchSize;

    long batchWaitMs;

    BlockingQueue<Q> requestWaitQueue;

    public AbstractBatchProcessor(UpdaterProperties updaterProperties, String batchProcessThreadPoolThreadNamePrefix) {
        this.batchSize = updaterProperties.getUnicornBatchSize();
        this.batchWaitMs = updaterProperties.getUnicornBatchWaitMs();
        this.requestWaitQueue = new LinkedBlockingQueue<>(); // unbounded for now

        this.producerThreadPool = Executors.newSingleThreadExecutor(new CustomizableThreadFactory(batchProcessThreadPoolThreadNamePrefix + "producer-"));
        Integer unicornProcessThreadSize = updaterProperties.getUnicornProcessThreadSize();
        this.consumerThreadPool = Executors.newFixedThreadPool(unicornProcessThreadSize, new CustomizableThreadFactory(batchProcessThreadPoolThreadNamePrefix + "consumer-"));
        Storage storage = new Storage(unicornProcessThreadSize);
        Producer producer = new Producer(storage);
        this.producerThreadPool.execute(producer);
        for (int i = 0; i < unicornProcessThreadSize; i++) {
            Consumer consumer = new Consumer(storage);
            this.consumerThreadPool.execute(consumer);
        }
    }

    public Future<R> doProcess(Q request) {
        FutureTask<R> future = new FutureTask<>(() -> {
            R result = processResultsMap.get(request);
            if (result == null) {
                log.warn("batch processing failed, no result available");
            }
            processResultsMap.remove(request);
            chatpeerIdFutureMap.remove(request);
            return result;
        });
        Future<R> theRequestIdValue = chatpeerIdFutureMap.putIfAbsent(request, future);
        Preconditions.checkArgument(theRequestIdValue == null);
        while (!requestWaitQueue.offer(request)) {
            // the queue is unbounded, so in theory this branch is never reached
            log.warn("queue is full, waiting 100 ms");
            ThreadUtils.sleep(100);
        }
        return future;
    }

    @Override
    public void destroy() throws Exception {
        // shutdown() alone never stops the endless worker loops (and the threads are non-daemon),
        // so interrupt them; the loops below exit when interrupted
        producerThreadPool.shutdownNow();
        consumerThreadPool.shutdownNow();
    }

    protected abstract List<R> batchProcess(List<Q> requests);

    @AllArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class Producer implements Runnable {

        Storage storage;

        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                try {
                    List<Q> batchRequests = new ArrayList<>();
                    Q firstRequest = requestWaitQueue.take();
                    batchRequests.add(firstRequest);
                    long currentTimeMillis = System.currentTimeMillis();
                    long maxWait = currentTimeMillis + batchWaitMs;
                    while (batchRequests.size() < batchSize) {
                        long timeout = maxWait - System.currentTimeMillis();
                        Q request = requestWaitQueue.poll(timeout, TimeUnit.MILLISECONDS);
                        if (request != null) {
                            batchRequests.add(request);
                        } else {
                            break;
                        }
                    }
                    Product product = new Product(batchRequests);
                    storage.push(product);
                } catch (InterruptedException e) {
                    // interrupted by destroy(): restore the flag so the loop condition ends the thread
                    Thread.currentThread().interrupt();
                } catch (Throwable t) {
                    log.error("producer failed", t);
                }
            }
        }
    }

    @AllArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class Consumer implements Runnable {
        Storage storage;

        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                List<Q> batchRequests = null;
                try {
                    Product product = storage.pop();
                    batchRequests = product.getBatchRequests();
                    List<R> results = batchProcess(batchRequests);
                    if (batchRequests.size() != results.size()) {
                        // results cannot be matched to requests; store nothing (ConcurrentHashMap rejects null values),
                        // each future treats a missing entry as a failed result
                        log.warn("analyzeBatch returned the wrong number of results, result size:{}, request size:{}", results.size(), batchRequests.size());
                    } else {
                        for (int i = 0; i < batchRequests.size(); i++) {
                            Q request = batchRequests.get(i);
                            R result = results.get(i);
                            if (result != null) { // ConcurrentHashMap rejects null values
                                processResultsMap.put(request, result);
                            }
                        }
                    }
                } catch (InterruptedException e) {
                    // interrupted by destroy(): restore the flag so the loop condition ends the thread
                    Thread.currentThread().interrupt();
                } catch (Throwable t) {
                    log.error("consumer failed", t);
                } finally {
                    // always run the futures of this batch so callers blocked in get() are released
                    if (batchRequests != null) {
                        for (Q request : batchRequests) {
                            FutureTask<R> futureTask = chatpeerIdFutureMap.get(request);
                            if (futureTask == null) {
                                log.error("missing futureTask! request: {}", request);
                            } else {
                                futureTask.run();
                            }
                        }
                    }
                }
            }
        }
    }

    @FieldDefaults(makeFinal = true)
    private class Storage {
        BlockingQueue<Product> queues;

        public Storage(int size) {
            this.queues = new LinkedBlockingQueue<>(size);
        }

        public void push(Product p) throws InterruptedException {
            queues.put(p);
        }

        public Product pop() throws InterruptedException {
            return queues.take();
        }
    }

    @Getter
    @AllArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class Product {
        List<Q> batchRequests;
    }
}

Problem

If a single chatpeer holds too much data, the algorithm call times out.

Stage 4

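To fix the Stage 3 problem, we decided to split an individual chatpeer into segments. Because a segment lacks the full surrounding context, NLU may misinterpret some content; we judged that acceptable.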

Approach

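Introduce Forkable and Joinable interfaces, named after ForkJoinPool: a request is forked into several segments, and once all segments are done their results are joined back into one result. In addition, the producer can no longer size a batch by the number of requests. Each request should instead expose a method that reports how much data it contains, and the producer adds up these sizes to decide whether the next request still goes into the current batch.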
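Packing by accumulated data size instead of request count can be isolated as a small pure function. The sketch below follows the same rule as the BatchController further down: the limit is checked before a request is added, so a batch is closed once it has reached the limit and may overshoot it by its last request. SizePacker and pack are illustrative names, and request sizes are plain integers here.

```java
import java.util.ArrayList;
import java.util.List;

final class SizePacker {

    private SizePacker() {
    }

    /**
     * Greedily groups sizes into batches: a new batch is started only when the current
     * one has already reached `limit`, so a batch may exceed it by its last element.
     */
    static List<List<Integer>> pack(List<Integer> sizes, int limit) {
        List<List<Integer>> batches = new ArrayList<>();
        List<Integer> current = new ArrayList<>();
        int currentSize = 0;
        for (int size : sizes) {
            if (!current.isEmpty() && currentSize >= limit) {
                batches.add(current); // batch full: hand it off
                current = new ArrayList<>();
                currentSize = 0;
            }
            current.add(size);
            currentSize += size;
        }
        if (!current.isEmpty()) {
            batches.add(current); // last, partially filled batch (sent on timeout in the real code)
        }
        return batches;
    }
}
```

For example, pack([3, 4, 2, 5, 1], 6) yields [[3, 4], [2, 5], [1]]. Checking before adding means even a single oversized request still forms a batch instead of stalling; the price is batches somewhat above the nominal limit.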

code

The Forkable interface

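import java.util.List;

public interface Forkable<SELF extends Forkable<SELF>> {

    /**
     * Offset of this (forked) request within the original request
     */
    int getOffset();

    /**
     * Decides, based on the partition capacity, whether this request needs to be forked
     *
     * @param partitionCapacity capacity of each partition
     */
    boolean shouldFork(int partitionCapacity);

    /**
     * Performs the actual fork
     *
     * @param partitionCapacity capacity of each partition
     * @return the requests produced by the fork
     */
    List<SELF> fork(int partitionCapacity);

    /**
     * Amount of data in this request
     */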
    int size();
}
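To make the contract concrete, here is a hypothetical implementation for a chat request that carries a list of messages: size() is the message count, fork() cuts the list into consecutive slices of at most partitionCapacity messages, and each slice records the offset of its first message in the original list. ChatRequest and its fields are assumptions for illustration, and the interface is repeated without Javadoc so the block compiles on its own. Because the batch processor uses requests as hash-map keys, equals/hashCode include the offset and the messages, so slices stay distinct from each other and from their source request.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

interface Forkable<SELF extends Forkable<SELF>> {
    int getOffset();

    boolean shouldFork(int partitionCapacity);

    List<SELF> fork(int partitionCapacity);

    int size();
}

/** Hypothetical request: one chat's messages, possibly a slice starting at `offset`. */
final class ChatRequest implements Forkable<ChatRequest> {
    private final String chatId;
    private final int offset;
    private final List<String> messages;

    ChatRequest(String chatId, int offset, List<String> messages) {
        this.chatId = chatId;
        this.offset = offset;
        this.messages = Collections.unmodifiableList(new ArrayList<>(messages)); // defensive copy
    }

    @Override
    public int getOffset() {
        return offset;
    }

    @Override
    public boolean shouldFork(int partitionCapacity) {
        return messages.size() > partitionCapacity;
    }

    @Override
    public List<ChatRequest> fork(int partitionCapacity) {
        List<ChatRequest> forks = new ArrayList<>();
        for (int start = 0; start < messages.size(); start += partitionCapacity) {
            int end = Math.min(start + partitionCapacity, messages.size());
            // offsets stay relative to the original, unforked message list
            forks.add(new ChatRequest(chatId, offset + start, messages.subList(start, end)));
        }
        return forks;
    }

    @Override
    public int size() {
        return messages.size();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof ChatRequest)) {
            return false;
        }
        ChatRequest other = (ChatRequest) o;
        return offset == other.offset && chatId.equals(other.chatId) && messages.equals(other.messages);
    }

    @Override
    public int hashCode() {
        return Objects.hash(chatId, offset, messages);
    }
}
```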

The Joinable interface

public interface Joinable<SELF extends Joinable<SELF>> {

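    /**
     * Applies the offset to this result's own data; in a functional style, returns a new instance
     *
     * @param offset the offset
     * @return the offset-adjusted result set
     */
    SELF join(int offset);

    /**
     * Merges another result set, applying its offset; in a functional style, returns a new instance
     *
     * @param other  the other result set
     * @param offset offset of the other result set
     * @return the new, merged result set
     */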
    SELF join(SELF other, int offset);
}
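A matching hypothetical result type: assume each NLU result is the list of indices of flagged messages, counted from the start of the request that produced it. join(offset) then shifts those indices into the original request's coordinates, and join(other, offset) appends another slice's shifted indices; both return new instances, as the Javadoc asks. HitResult and the index-list representation are assumptions, and the interface is repeated so the block compiles on its own.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

interface Joinable<SELF extends Joinable<SELF>> {
    SELF join(int offset);

    SELF join(SELF other, int offset);
}

/** Hypothetical NLU result: indices of flagged messages, relative to the request that produced it. */
final class HitResult implements Joinable<HitResult> {
    private final List<Integer> hitIndices;

    HitResult(List<Integer> hitIndices) {
        this.hitIndices = Collections.unmodifiableList(new ArrayList<>(hitIndices));
    }

    List<Integer> getHitIndices() {
        return hitIndices;
    }

    /** Returns a copy whose indices are shifted by offset; this instance is unchanged. */
    @Override
    public HitResult join(int offset) {
        List<Integer> shifted = new ArrayList<>(hitIndices.size());
        for (int index : hitIndices) {
            shifted.add(index + offset);
        }
        return new HitResult(shifted);
    }

    /** Returns this result followed by other's indices shifted by offset. */
    @Override
    public HitResult join(HitResult other, int offset) {
        List<Integer> merged = new ArrayList<>(hitIndices);
        merged.addAll(other.join(offset).hitIndices);
        return new HitResult(merged);
    }
}
```

Folding slice results with offsets 0, 2 and 4 the way the fork FutureTask below does (first.join(0), then result.join(next, offset)) turns [1], [0, 1] and [0] into [1, 2, 3, 4].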

The abstract batch processor

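import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils; // assumed: any CollectionUtils with isNotEmpty works
import org.springframework.beans.factory.DisposableBean;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
// UpdaterProperties and ThreadUtils are project classes (not shown here)

@Slf4j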
@FieldDefaults(makeFinal = true)
public abstract class AbstractBatchProcessor<F extends Forkable<F>, J extends Joinable<J>> implements DisposableBean {

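    /**
     * Maps each source request to its future
     */
    ConcurrentHashMap<F, FutureTask<J>> requestFutureMap = new ConcurrentHashMap<>();

    /**
     * Processing result of each (sub-)request
     */
    Map<F, J> processResultsMap = new ConcurrentHashMap<>();

    /**
     * Maps a source request to its sub-requests, in fork order
     * <p>Used by the producer to pack the sub-requests and by the source request's future to join their results.
     * ArrayListMultimap itself is not thread-safe and doProcess() runs on many threads, hence the synchronized wrapper
     */
    Multimap<F, F> forkRequestMultimap = Multimaps.synchronizedListMultimap(ArrayListMultimap.create());

    /**
     * Number of unfinished sub-requests per source request
     * <p>After doProcess() forks a source request, it records how many sub-requests are outstanding; consumers decrement
     * the count to tell when all of them are done. Several consumer threads do this concurrently, hence AtomicInteger
     */
    Map<F, AtomicInteger> forkRequestUndoMap = new ConcurrentHashMap<>();

    /**
     * Maps a sub-request back to its source request
     * <p>Consumers only see sub-requests in a batch; this mapping lets them find the source request
     */
    Map<F, F> forkedRequestToSourceRequestMap = new ConcurrentHashMap<>();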

    ExecutorService producerThreadPool;

    ExecutorService consumerThreadPool;

    int batchSize;

    long batchWaitMs;

    int partitionCapacity;

    BlockingQueue<F> requestWaitQueue;

    public AbstractBatchProcessor(UpdaterProperties updaterProperties, String batchProcessThreadPoolThreadNamePrefix) {
        this.batchSize = updaterProperties.getUnicornBatchSize();
        this.batchWaitMs = updaterProperties.getUnicornBatchWaitMs();
        this.partitionCapacity = updaterProperties.getUnicornPartitionCapacity();
        this.requestWaitQueue = new LinkedBlockingQueue<>(); // unbounded for now

        this.producerThreadPool = Executors.newSingleThreadExecutor(new CustomizableThreadFactory(batchProcessThreadPoolThreadNamePrefix + "producer-"));
        Integer unicornProcessThreadSize = updaterProperties.getUnicornProcessThreadSize();
        this.consumerThreadPool = Executors.newFixedThreadPool(unicornProcessThreadSize, new CustomizableThreadFactory(batchProcessThreadPoolThreadNamePrefix + "consumer-"));
        Storage storage = new Storage(unicornProcessThreadSize);
        Producer producer = new Producer(storage);
        this.producerThreadPool.execute(producer);
        for (int i = 0; i < unicornProcessThreadSize; i++) {
            Consumer consumer = new Consumer(storage);
            this.consumerThreadPool.execute(consumer);
        }
    }

    public Future<J> doProcess(F request) {
        FutureTask<J> future;
        List<F> forks = null;
        if (request.shouldFork(this.partitionCapacity)) {
            forks = request.fork(this.partitionCapacity);
            future = new FutureTask<>(() -> {
                J result = null;
                Collection<F> forkRequests = forkRequestMultimap.get(request);
                try {
                    for (F fork : forkRequests) {
                        J response = processResultsMap.get(fork);
                        if (response == null) {
                            log.warn("batch processing failed, no result available for a sub-request");
                        } else {
                            if (result == null) {
                                result = response.join(fork.getOffset());
                            } else {
                                result = result.join(response, fork.getOffset());
                            }
                        }
                        processResultsMap.remove(fork);
                    }
                } finally {
                    requestFutureMap.remove(request);
                    forkRequestMultimap.removeAll(request);
                    forkRequestUndoMap.remove(request);
                }
                return result;
            });
        } else {
            future = new FutureTask<>(() -> {
                J result = processResultsMap.get(request);
                if (result == null) {
                    log.warn("batch processing failed, no result available");
                }
                processResultsMap.remove(request);
                requestFutureMap.remove(request);
                return result;
            });
        }

        Future<J> theRequestIdValue = requestFutureMap.putIfAbsent(request, future);
        // reject a duplicate before touching the fork bookkeeping, so it cannot corrupt the in-flight request's entries
        Preconditions.checkArgument(theRequestIdValue == null, "duplicate request: %s", request);
        if (forks != null) {
            forkRequestMultimap.putAll(request, forks);
            forkRequestUndoMap.put(request, new AtomicInteger(forks.size()));
            forks.forEach(f -> forkedRequestToSourceRequestMap.put(f, request));
            log.info("large request forked into {} sub-requests", forks.size());
        }
        while (!requestWaitQueue.offer(request)) {
            // the queue is unbounded, so in theory this branch is never reached
            log.warn("queue is full, waiting 100 ms");
            ThreadUtils.sleep(100);
        }
        return future;
    }

    @Override
    public void destroy() throws Exception {
        // shutdown() alone never stops the endless worker loops (and the threads are non-daemon),
        // so interrupt them; the loops below exit when interrupted
        producerThreadPool.shutdownNow();
        consumerThreadPool.shutdownNow();
    }

    protected abstract List<J> batchProcess(List<F> requests);

    @RequiredArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class Producer implements Runnable {

        Storage storage;

        BatchController batchController = new BatchController();

        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                try {
                    F firstRequest = requestWaitQueue.take();
                    this.batchController.init();
                    handleRequest(firstRequest);
                    while (this.batchController.sizeInLimit()) {
                        long timeout = this.batchController.getMaxWait() - System.currentTimeMillis();
                        F request = requestWaitQueue.poll(timeout, TimeUnit.MILLISECONDS);
                        if (request != null) {
                            handleRequest(request);
                        } else {
                            break;
                        }
                    }
                    // Check for emptiness again: after a forked request has been handled, init() may have started a new
                    // batch; if no new request arrives while waiting, batchRequests is empty after the timeout and no product is needed
                    if (CollectionUtils.isNotEmpty(this.batchController.getBatchRequests())) {
                        Product product = new Product(this.batchController.getBatchRequests());
                        if (log.isDebugEnabled()) {
                            int size = product.getBatchRequests().stream().mapToInt(F::size).sum();
                            log.debug("timeout reached or batch full, sending to storage, size:{}", size);
                        }
                        storage.push(product);
                    }
                } catch (InterruptedException e) {
                    // interrupted by destroy(): restore the flag so the loop condition ends the thread
                    Thread.currentThread().interrupt();
                } catch (Throwable t) {
                    log.error("producer failed", t);
                }
            }
        }

        private void handleRequest(F request) throws InterruptedException {
            if (forkRequestMultimap.containsKey(request)) {
                for (F fork : forkRequestMultimap.get(request)) {
                    if (!this.batchController.sizeInLimit()) {
                        Product product = new Product(this.batchController.getBatchRequests());
                        if (log.isDebugEnabled()) {
                            int size = product.getBatchRequests().stream().mapToInt(F::size).sum();
                            log.debug("forked sub-requests filled a batch, sending to storage, size:{}", size);
                        }
                        storage.push(product);
                        this.batchController.init();
                    }
                    this.batchController.add(fork);
                }
            } else {
                this.batchController.add(request);
            }
        }
    }

    private class BatchController {

        @Getter
        List<F> batchRequests;

        int currSize;

        @Getter
        long maxWait;

        public void init() {
            this.batchRequests = new ArrayList<>();
            this.currSize = 0;
            this.maxWait = System.currentTimeMillis() + batchWaitMs;
        }

        public void add(F request) {
            this.batchRequests.add(request);
            this.currSize += request.size();
        }

        public boolean sizeInLimit() {
            return this.currSize < batchSize;
        }
    }

    @AllArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class Consumer implements Runnable {
        Storage storage;

        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                List<F> batchRequests = null;
                try {
                    Product product = storage.pop();
                    batchRequests = product.getBatchRequests();
                    log.debug("received a batch from storage");
                    List<J> results = batchProcess(batchRequests);
                    log.debug("batch processed");
                    if (batchRequests.size() != results.size()) {
                        // results cannot be matched to requests; store nothing (ConcurrentHashMap rejects null values),
                        // each future treats a missing entry as a failed result
                        log.warn("analyzeBatch returned the wrong number of results, result size:{}, request size:{}, request:{}", results.size(),
                                batchRequests.size(), batchRequests);
                    } else {
                        for (int i = 0; i < batchRequests.size(); i++) {
                            F request = batchRequests.get(i);
                            J result = results.get(i);
                            if (result != null) { // ConcurrentHashMap rejects null values
                                processResultsMap.put(request, result);
                            }
                        }
                    }
                } catch (InterruptedException e) {
                    // interrupted by destroy(): restore the flag so the loop condition ends the thread
                    Thread.currentThread().interrupt();
                } catch (Throwable t) {
                    log.error("consumer failed", t);
                } finally {
                    if (batchRequests != null) {
                        for (F request : batchRequests) {
                            F sourceRequest = null;
                            if (forkedRequestToSourceRequestMap.containsKey(request)) {
                                sourceRequest = forkedRequestToSourceRequestMap.get(request);
                                int decrementNum = forkRequestUndoMap.get(sourceRequest).decrementAndGet();
                                if (decrementNum > 0) {
                                    // not all sub-requests are done yet: reset sourceRequest to null so the futureTask does not run
                                    sourceRequest = null;
                                }
                                forkedRequestToSourceRequestMap.remove(request);
                            } else {
                                sourceRequest = request;
                            }
                            if (sourceRequest != null) {
                                FutureTask<J> futureTask = requestFutureMap.get(sourceRequest);
                                if (futureTask == null) {
                                    log.error("missing futureTask! request: {}", request);
                                } else {
                                    futureTask.run();
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    @FieldDefaults(makeFinal = true)
    private class Storage {
        BlockingQueue<Product> queues;

        public Storage(int size) {
            this.queues = new LinkedBlockingQueue<>(size);
        }

        public void push(Product p) throws InterruptedException {
            queues.put(p);
        }

        public Product pop() throws InterruptedException {
            return queues.take();
        }
    }

    @Getter
    @AllArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class Product {
        List<F> batchRequests;
    }
}