ShardingSphere Design from the Source Code: The Execution Engine

2020-09-22  蚊子squirrel

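The execution engine is responsible for sending the rewritten SQL to the target databases (as computed by routing) and executing it there. It is built around the callback pattern: for a given collection of input groups, it runs a caller-supplied callback function.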

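Like Spring's JdbcTemplate and TransactionTemplate, ShardingSphere's SQLExecuteTemplate and ExecutorEngine follow this design: the user of the engine supplies a callback implementation. The pattern fits because SQL execution has to support many kinds of SQL, such as DQL, DML, DDL, SQL without parameters, and parameterized SQL. The operating logic differs for each kind, yet the execution engine has to offer one common execution strategy.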
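The template/callback split can be illustrated with a minimal sketch. `ExecuteTemplate` and `UnitCallback` are hypothetical names, not ShardingSphere classes: the template owns the generic loop over input groups, and the caller plugs in the per-unit logic as a callback, so the same loop serves different kinds of work.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TemplateCallbackDemo {

    // Hypothetical callback: the caller decides what executing one unit means.
    @FunctionalInterface
    interface UnitCallback<I, O> {
        O execute(I input);
    }

    // Hypothetical template: owns the generic loop over input groups and knows nothing about SQL kinds.
    static final class ExecuteTemplate {
        <I, O> List<O> execute(List<List<I>> inputGroups, UnitCallback<I, O> callback) {
            List<O> result = new ArrayList<>();
            for (List<I> group : inputGroups) {
                for (I input : group) {
                    result.add(callback.execute(input));
                }
            }
            return result;
        }
    }

    public static void main(String[] args) {
        List<List<String>> groups = Arrays.asList(
                Arrays.asList("SELECT * FROM t_order_0", "SELECT * FROM t_order_1"),
                Arrays.asList("SELECT * FROM t_order_2"));
        ExecuteTemplate template = new ExecuteTemplate();
        // Two callbacks with different logic reuse the same template.
        List<String> described = template.execute(groups, sql -> "executed: " + sql);
        List<Integer> lengths = template.execute(groups, String::length);
        System.out.println(described);
        System.out.println(lengths);
    }
}
```

In ShardingSphere the real callbacks receive a whole group of statement execution units rather than a single string. The point of the sketch is only that the iteration strategy stays in one place.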

Code Execution Analysis

Let's return to the starting point, the ShardingPreparedStatement class:

    @Override
    public ResultSet executeQuery() throws SQLException {
        ResultSet result;
        try {
…
            initPreparedStatementExecutor(); // initialize the PreparedStatement executor
            MergedResult mergedResult = mergeQuery(preparedStatementExecutor.executeQuery());
…
    }

    private void initPreparedStatementExecutor() throws SQLException {
        preparedStatementExecutor.init(executionContext);
        setParametersForStatements(); // set the parameters on each Statement
        replayMethodForStatements(); // replay the recorded method calls on each Statement
    }

    private void setParametersForStatements() {
        for (int i = 0; i < preparedStatementExecutor.getStatements().size(); i++) {
            replaySetParameter((PreparedStatement) preparedStatementExecutor.getStatements().get(i), preparedStatementExecutor.getParameterSets().get(i));
        }
    }
    private void replayMethodForStatements() {
        for (Statement each : preparedStatementExecutor.getStatements()) {
            replayMethodsInvocation(each);
        }
    }
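Why replay at all? The application calls setters such as setFetchSize or setQueryTimeout on the logical ShardingPreparedStatement before routing has decided how many physical statements there will be, so those calls have to be recorded and re-applied to every real statement afterwards. Below is a minimal sketch of that idea, assuming a Consumer-based recorder; `InvocationRecorder` and `FakeStatement` are hypothetical, and ShardingSphere's own implementation differs in detail.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

public class ReplayDemo {

    // Hypothetical stand-in for a real JDBC Statement, so the sketch runs without a database.
    static final class FakeStatement {
        final String sql;
        int fetchSize;
        int queryTimeout;

        FakeStatement(String sql) {
            this.sql = sql;
        }

        void setFetchSize(int rows) {
            this.fetchSize = rows;
        }

        void setQueryTimeout(int seconds) {
            this.queryTimeout = seconds;
        }
    }

    // Hypothetical recorder: keeps calls made before the real statements exist.
    static final class InvocationRecorder<T> {
        private final List<Consumer<T>> invocations = new ArrayList<>();

        void record(Consumer<T> invocation) {
            invocations.add(invocation);
        }

        // Replays every recorded call, in recording order, on one real statement.
        void replay(T target) {
            for (Consumer<T> each : invocations) {
                each.accept(target);
            }
        }
    }

    public static void main(String[] args) {
        InvocationRecorder<FakeStatement> recorder = new InvocationRecorder<>();
        // The application calls these on the logical statement, before routing.
        recorder.record(s -> s.setFetchSize(100));
        recorder.record(s -> s.setQueryTimeout(5));

        // After routing, one real statement exists per rewritten SQL.
        List<FakeStatement> routed = Arrays.asList(
                new FakeStatement("SELECT * FROM t_order_0"),
                new FakeStatement("SELECT * FROM t_order_1"));
        for (FakeStatement each : routed) {
            recorder.replay(each);
        }
        for (FakeStatement each : routed) {
            System.out.println(each.sql + " fetchSize=" + each.fetchSize + " timeout=" + each.queryTimeout);
        }
    }
}
```

Parameters are handled separately (setParametersForStatements), because each routed statement gets its own parameter set, whereas the replayed settings are identical for every statement.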

This initializes preparedStatementExecutor, sets the parameters on each Statement, and replays the recorded method calls. Next, the PreparedStatementExecutor class:
org.apache.shardingsphere.shardingjdbc.executor.PreparedStatementExecutor

/**
 * Prepared statement executor.
 */
public final class PreparedStatementExecutor extends AbstractStatementExecutor {
    
    @Getter
    private final boolean returnGeneratedKeys;
    
    public PreparedStatementExecutor(
            final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys, final ShardingConnection shardingConnection) {
        super(resultSetType, resultSetConcurrency, resultSetHoldability, shardingConnection);
        this.returnGeneratedKeys = returnGeneratedKeys;
    }
    
    /**
     * Initialize executor.
     *
     * @param executionContext execution context
     * @throws SQLException SQL exception
     */
    public void init(final ExecutionContext executionContext) throws SQLException {
        setSqlStatementContext(executionContext.getSqlStatementContext());
        getInputGroups().addAll(obtainExecuteGroups(executionContext.getExecutionUnits())); // build the execution groups
        cacheStatements();
    }

    private Collection<InputGroup<StatementExecuteUnit>> obtainExecuteGroups(final Collection<ExecutionUnit> executionUnits) throws SQLException {
        return getSqlExecutePrepareTemplate().getExecuteUnitGroups(executionUnits, new SQLExecutePrepareCallback() {

            @Override
            // Obtain the requested number of database connections on the given data source
            public List<Connection> getConnections(final ConnectionMode connectionMode, final String dataSourceName, final int connectionSize) throws SQLException {
                return PreparedStatementExecutor.super.getConnection().getConnections(connectionMode, dataSourceName, connectionSize);
            }

            @Override
            // Create a Statement execution unit from the execution unit information
            public StatementExecuteUnit createStatementExecuteUnit(final Connection connection, final ExecutionUnit executionUnit, final ConnectionMode connectionMode) throws SQLException {
                return new StatementExecuteUnit(executionUnit, createPreparedStatement(connection, executionUnit.getSqlUnit().getSql()), connectionMode);
            }
        });
    }

    @SuppressWarnings("MagicConstant")
    private PreparedStatement createPreparedStatement(final Connection connection, final String sql) throws SQLException {
        return returnGeneratedKeys ? connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)
                : connection.prepareStatement(sql, getResultSetType(), getResultSetConcurrency(), getResultSetHoldability());
    }
    
    /**
     * Execute query.
     *
     * @return result set list
     * @throws SQLException SQL exception
     */
    public List<QueryResult> executeQuery() throws SQLException {
        final boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
        SQLExecuteCallback<QueryResult> executeCallback = new SQLExecuteCallback<QueryResult>(getDatabaseType(), isExceptionThrown) {
            
            @Override
            // Execute the SQL on the given Statement and wrap the JDBC result set in a QueryResult (stream-based or memory-based)
            protected QueryResult executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode) throws SQLException {
                return getQueryResult(statement, connectionMode);
            }
        };
        return executeCallback(executeCallback); // execute through the callback
    }
    // Execute the SQL, then convert the result set into a QueryResult
    private QueryResult getQueryResult(final Statement statement, final ConnectionMode connectionMode) throws SQLException {
        PreparedStatement preparedStatement = (PreparedStatement) statement;
        ResultSet resultSet = preparedStatement.executeQuery();
        getResultSets().add(resultSet);
        return ConnectionMode.MEMORY_STRICTLY == connectionMode ? new StreamQueryResult(resultSet) : new MemoryQueryResult(resultSet);
    }
…
}

First, init calls obtainExecuteGroups, which in turn calls the getExecuteUnitGroups method of SQLExecutePrepareTemplate to turn the input ExecutionUnit collection, together with a SQLExecutePrepareCallback, into a collection of InputGroup<StatementExecuteUnit>.
Let's look at getExecuteUnitGroups in the SQLExecutePrepareTemplate class:
org.apache.shardingsphere.sharding.execute.sql.prepare.SQLExecutePrepareTemplate

/**
 * SQL execute prepare template.
 */
@RequiredArgsConstructor
public final class SQLExecutePrepareTemplate {
    
    private final int maxConnectionsSizePerQuery;
    
    /**
     * Get execute unit groups.
     *
     * @param executionUnits execution units
     * @param callback SQL execute prepare callback
     * @return statement execute unit groups
     * @throws SQLException SQL exception
     */
    public Collection<InputGroup<StatementExecuteUnit>> getExecuteUnitGroups(final Collection<ExecutionUnit> executionUnits, final SQLExecutePrepareCallback callback) throws SQLException {
        return getSynchronizedExecuteUnitGroups(executionUnits, callback);
    }

    // Build synchronous execution unit groups
    private Collection<InputGroup<StatementExecuteUnit>> getSynchronizedExecuteUnitGroups(
            final Collection<ExecutionUnit> executionUnits, final SQLExecutePrepareCallback callback) throws SQLException {
        Map<String, List<SQLUnit>> sqlUnitGroups = getSQLUnitGroups(executionUnits); // map each data source to its SQLUnits
        Collection<InputGroup<StatementExecuteUnit>> result = new LinkedList<>();
        for (Entry<String, List<SQLUnit>> entry : sqlUnitGroups.entrySet()) {
            result.addAll(getSQLExecuteGroups(entry.getKey(), entry.getValue(), callback)); // each SQLUnit becomes one StatementExecuteUnit; each partition of SQLUnits becomes one InputGroup
        }
        return result;
    }

    // Group the SQLUnits of the ExecutionUnits by data source
    private Map<String, List<SQLUnit>> getSQLUnitGroups(final Collection<ExecutionUnit> executionUnits) {
        Map<String, List<SQLUnit>> result = new LinkedHashMap<>(executionUnits.size(), 1);
        for (ExecutionUnit each : executionUnits) {
            if (!result.containsKey(each.getDataSourceName())) {
                result.put(each.getDataSourceName(), new LinkedList<>());
            }
            result.get(each.getDataSourceName()).add(each.getSqlUnit());
        }
        return result;
    }
    // Build the SQL execution groups
    private List<InputGroup<StatementExecuteUnit>> getSQLExecuteGroups(final String dataSourceName,
                                                                       final List<SQLUnit> sqlUnits, final SQLExecutePrepareCallback callback) throws SQLException {
        List<InputGroup<StatementExecuteUnit>> result = new LinkedList<>();
        int desiredPartitionSize = Math.max(0 == sqlUnits.size() % maxConnectionsSizePerQuery ? sqlUnits.size() / maxConnectionsSizePerQuery : sqlUnits.size() / maxConnectionsSizePerQuery + 1, 1);
        List<List<SQLUnit>> sqlUnitPartitions = Lists.partition(sqlUnits, desiredPartitionSize);
        ConnectionMode connectionMode = maxConnectionsSizePerQuery < sqlUnits.size() ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
        List<Connection> connections = callback.getConnections(connectionMode, dataSourceName, sqlUnitPartitions.size()); // partition count and connection mode follow from the SQL count and maxConnectionsSizePerQuery; one connection per partition
        int count = 0;
        for (List<SQLUnit> each : sqlUnitPartitions) {
            result.add(getSQLExecuteGroup(connectionMode, connections.get(count++), dataSourceName, each, callback)); // turn this partition's SQLUnits into StatementExecuteUnits on its connection and add the group to the result
        }
        return result;
    }
    
    private InputGroup<StatementExecuteUnit> getSQLExecuteGroup(final ConnectionMode connectionMode, final Connection connection,
                                                                final String dataSourceName, final List<SQLUnit> sqlUnitGroup, final SQLExecutePrepareCallback callback) throws SQLException {
        List<StatementExecuteUnit> result = new LinkedList<>();
        for (SQLUnit each : sqlUnitGroup) {
            result.add(callback.createStatementExecuteUnit(connection, new ExecutionUnit(dataSourceName, each), connectionMode));
        }
        return new InputGroup<>(result);
    }
}

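As shown, SQLExecutePrepareTemplate groups the ExecutionUnit collection and converts it into a collection of InputGroup<StatementExecuteUnit>. Its core logic uses maxConnectionsSizePerQuery (the maximum number of database connections a single query may use on each data source) to work out how many connections the current SQL needs on each data source and which connection mode to use. The property is defined as follows, with a default value of 1: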

    /**
     * Max opened connection size for each query.
     */
    MAX_CONNECTIONS_SIZE_PER_QUERY("max.connections.size.per.query", String.valueOf(1), int.class),
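To make the arithmetic concrete, here is a self-contained sketch of the grouping step. It is a simplification, not the ShardingSphere source: an execution unit is modeled by a hypothetical `Unit` class holding a data source name and a SQL string, Guava's `Lists.partition` is replaced by `subList`, and the connection mode is a plain enum. With 10 SQL units on one data source and maxConnectionsSizePerQuery = 3, the partition size is max(ceil(10/3), 1) = 4, giving partitions of 4, 4 and 2 units, i.e. 3 connections, and the mode is CONNECTION_STRICTLY because 3 < 10. With 2 units and the same limit, each unit gets its own connection and the mode is MEMORY_STRICTLY. That is also why getQueryResult above can stream in MEMORY_STRICTLY mode (one SQL per connection), while in CONNECTION_STRICTLY mode, where one connection serves several SQLs in turn, results are loaded into a MemoryQueryResult.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class GroupingDemo {

    enum ConnectionMode { MEMORY_STRICTLY, CONNECTION_STRICTLY }

    // Hypothetical simplification of ExecutionUnit: data source name plus SQL.
    static final class Unit {
        final String dataSourceName;
        final String sql;

        Unit(String dataSourceName, String sql) {
            this.dataSourceName = dataSourceName;
            this.sql = sql;
        }
    }

    // Group SQL by data source, keeping first-seen order (getSQLUnitGroups also uses a LinkedHashMap).
    static Map<String, List<String>> groupByDataSource(List<Unit> units) {
        Map<String, List<String>> result = new LinkedHashMap<>();
        for (Unit each : units) {
            result.computeIfAbsent(each.dataSourceName, key -> new ArrayList<>()).add(each.sql);
        }
        return result;
    }

    // Same value as the ternary in getSQLExecuteGroups: max(ceil(n / m), 1).
    static int desiredPartitionSize(int sqlCount, int maxConnectionsSizePerQuery) {
        int ceil = (sqlCount + maxConnectionsSizePerQuery - 1) / maxConnectionsSizePerQuery;
        return Math.max(ceil, 1);
    }

    // Consecutive chunks of at most `size` elements, as Guava's Lists.partition produces.
    static <T> List<List<T>> partition(List<T> list, int size) {
        List<List<T>> result = new ArrayList<>();
        for (int from = 0; from < list.size(); from += size) {
            result.add(list.subList(from, Math.min(from + size, list.size())));
        }
        return result;
    }

    static ConnectionMode connectionMode(int sqlCount, int maxConnectionsSizePerQuery) {
        return maxConnectionsSizePerQuery < sqlCount ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
    }

    public static void main(String[] args) {
        List<Unit> units = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            units.add(new Unit("ds_0", "SELECT * FROM t_order_" + i));
        }
        units.add(new Unit("ds_1", "SELECT * FROM t_order_0"));
        int maxConnectionsSizePerQuery = 3;
        for (Map.Entry<String, List<String>> entry : groupByDataSource(units).entrySet()) {
            List<String> sqls = entry.getValue();
            int size = desiredPartitionSize(sqls.size(), maxConnectionsSizePerQuery);
            // One connection is requested per partition.
            List<List<String>> partitions = partition(sqls, size);
            System.out.println(entry.getKey() + ": " + sqls.size() + " SQL, partition size " + size
                    + ", connections " + partitions.size() + ", mode " + connectionMode(sqls.size(), maxConnectionsSizePerQuery));
        }
    }
}
```

Note that the number of partitions, ceil(n / ceil(n/m)), never exceeds m, so the connection limit per data source is respected.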

Next comes SQLExecuteTemplate:
org.apache.shardingsphere.sharding.execute.sql.execute.SQLExecuteTemplate

@RequiredArgsConstructor
public final class SQLExecuteTemplate {
    
    private final ExecutorEngine executorEngine;
    
    private final boolean serial;
    
    /**
     * Execute.
     *
     * @param inputGroups input groups
     * @param callback SQL execute callback
     * @param <T> class type of return value
     * @return execute result
     * @throws SQLException SQL exception
     */
    public <T> List<T> execute(final Collection<InputGroup<? extends StatementExecuteUnit>> inputGroups, final SQLExecuteCallback<T> callback) throws SQLException {
        return execute(inputGroups, null, callback);
    }
    
    /**
     * Execute.
     *
     * @param inputGroups input groups
     * @param firstCallback first SQL execute callback
     * @param callback SQL execute callback
     * @param <T> class type of return value
     * @return execute result
     * @throws SQLException SQL exception
     */
    @SuppressWarnings("unchecked")
    public <T> List<T> execute(final Collection<InputGroup<? extends StatementExecuteUnit>> inputGroups,
                               final SQLExecuteCallback<T> firstCallback, final SQLExecuteCallback<T> callback) throws SQLException {
        try {
            return executorEngine.execute((Collection) inputGroups, firstCallback, callback, serial);
        } catch (final SQLException ex) {
            ExecutorExceptionHandler.handleException(ex);
            return Collections.emptyList();
        }
    }
}
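Besides delegating, SQLExecuteTemplate's only behavior is its catch block: the exception goes to ExecutorExceptionHandler.handleException, and an empty list is returned if that call does not rethrow. The sketch below shows the general idea of such a per-thread switch; `ExceptionSwitch` is a hypothetical class written for illustration, and the real ExecutorExceptionHandler may differ in details such as logging and exception wrapping. A ThreadLocal value set on the calling thread is not visible on the pool's worker threads, which is presumably why PreparedStatementExecutor.executeQuery reads isExceptionThrown() up front and passes it into the SQLExecuteCallback.

```java
import java.sql.SQLException;
import java.util.Collections;
import java.util.List;

public class ExceptionSwitchDemo {

    // Hypothetical per-thread switch deciding whether execution errors propagate.
    static final class ExceptionSwitch {
        private static final ThreadLocal<Boolean> THROWN = ThreadLocal.withInitial(() -> true);

        static void setExceptionThrown(boolean value) {
            THROWN.set(value);
        }

        static boolean isExceptionThrown() {
            return THROWN.get();
        }

        // Rethrows when the switch is on; otherwise swallows the error (a real handler would log it).
        static void handle(SQLException ex) throws SQLException {
            if (isExceptionThrown()) {
                throw ex;
            }
        }
    }

    interface Work<T> {
        List<T> run() throws SQLException;
    }

    // Same shape as SQLExecuteTemplate.execute: delegate, and fall back to an empty list.
    static <T> List<T> execute(Work<T> work) throws SQLException {
        try {
            return work.run();
        } catch (final SQLException ex) {
            ExceptionSwitch.handle(ex);
            return Collections.emptyList();
        }
    }

    public static void main(String[] args) throws SQLException {
        Work<String> failing = () -> { throw new SQLException("boom"); };
        ExceptionSwitch.setExceptionThrown(false);
        System.out.println(execute(failing)); // switch off: the error is swallowed
        ExceptionSwitch.setExceptionThrown(true);
        try {
            execute(failing);
        } catch (SQLException ex) {
            System.out.println("rethrown: " + ex.getMessage());
        }
    }
}
```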

SQLExecuteTemplate does little more than delegate to ExecutorEngine, so let's look at that class:
org.apache.shardingsphere.underlying.executor.engine.ExecutorEngine

/**
 * Executor engine.
 */
public final class ExecutorEngine implements AutoCloseable {
    
    private final ShardingSphereExecutorService executorService;
    
    public ExecutorEngine(final int executorSize) {
        executorService = new ShardingSphereExecutorService(executorSize);
    }
    
    /**
     * Execute.
     *
     * @param inputGroups input groups
     * @param callback grouped callback
     * @param <I> type of input value
     * @param <O> type of return value
     * @return execute result
     * @throws SQLException throw if execute failure
     */
    public <I, O> List<O> execute(final Collection<InputGroup<I>> inputGroups, final GroupedCallback<I, O> callback) throws SQLException {
        return execute(inputGroups, null, callback, false);
    }
    
    /**
     * Execute.
     *
     * @param inputGroups input groups
     * @param firstCallback first grouped callback
     * @param callback other grouped callback
     * @param serial whether using multi thread execute or not
     * @param <I> type of input value
     * @param <O> type of return value
     * @return execute result
     * @throws SQLException throw if execute failure
     */
    public <I, O> List<O> execute(final Collection<InputGroup<I>> inputGroups, 
                                  final GroupedCallback<I, O> firstCallback, final GroupedCallback<I, O> callback, final boolean serial) throws SQLException {
        if (inputGroups.isEmpty()) {
            return Collections.emptyList();
        }
        return serial ? serialExecute(inputGroups, firstCallback, callback) : parallelExecute(inputGroups, firstCallback, callback);
    }

    // Serial execution
    private <I, O> List<O> serialExecute(final Collection<InputGroup<I>> inputGroups, final GroupedCallback<I, O> firstCallback, final GroupedCallback<I, O> callback) throws SQLException {
        Iterator<InputGroup<I>> inputGroupsIterator = inputGroups.iterator();
        InputGroup<I> firstInputs = inputGroupsIterator.next();
        List<O> result = new LinkedList<>(syncExecute(firstInputs, null == firstCallback ? callback : firstCallback));
        for (InputGroup<I> each : Lists.newArrayList(inputGroupsIterator)) {
            result.addAll(syncExecute(each, callback));
        }
        return result;
    }

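    // Parallel execution: the remaining groups are submitted to the thread pool with callback, while the first group runs on the calling thread with firstCallback (or callback if firstCallback is null)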
    private <I, O> List<O> parallelExecute(final Collection<InputGroup<I>> inputGroups, final GroupedCallback<I, O> firstCallback, final GroupedCallback<I, O> callback) throws SQLException {
        Iterator<InputGroup<I>> inputGroupsIterator = inputGroups.iterator();
        InputGroup<I> firstInputs = inputGroupsIterator.next();
        Collection<ListenableFuture<Collection<O>>> restResultFutures = asyncExecute(Lists.newArrayList(inputGroupsIterator), callback);
        return getGroupResults(syncExecute(firstInputs, null == firstCallback ? callback : firstCallback), restResultFutures);
    }

    // Synchronous execution on the calling thread
    private <I, O> Collection<O> syncExecute(final InputGroup<I> inputGroup, final GroupedCallback<I, O> callback) throws SQLException {
        return callback.execute(inputGroup.getInputs(), true, ExecutorDataMap.getValue());
    }

    // Asynchronous execution on the thread pool
    private <I, O> Collection<ListenableFuture<Collection<O>>> asyncExecute(final List<InputGroup<I>> inputGroups, final GroupedCallback<I, O> callback) {
        Collection<ListenableFuture<Collection<O>>> result = new LinkedList<>();
        for (InputGroup<I> each : inputGroups) {
            result.add(asyncExecute(each, callback));
        }
        return result;
    }
    
    private <I, O> ListenableFuture<Collection<O>> asyncExecute(final InputGroup<I> inputGroup, final GroupedCallback<I, O> callback) {
        final Map<String, Object> dataMap = ExecutorDataMap.getValue();
        return executorService.getExecutorService().submit(() -> callback.execute(inputGroup.getInputs(), false, dataMap));
    }
    
    private <O> List<O> getGroupResults(final Collection<O> firstResults, final Collection<ListenableFuture<Collection<O>>> restFutures) throws SQLException {
        List<O> result = new LinkedList<>(firstResults);
        for (ListenableFuture<Collection<O>> each : restFutures) {
            try {
                result.addAll(each.get());
            } catch (final InterruptedException | ExecutionException ex) {
                return throwException(ex);
            }
        }
        return result;
    }
…
}

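The methods of ExecutorEngine fall into two kinds: serial execution (serialExecute) and parallel execution (parallelExecute). Serial execution runs every group synchronously on the current application thread. Parallel execution submits every group except the first to ShardingSphere's built-in thread pool, ShardingSphereExecutorService, and runs the first group synchronously on the current thread. Note that both methods take two callbacks: the first input group is executed with the first callback and all remaining groups with the second. The reason is that some operations only need to happen once; metadata, for example, only needs to be produced while handling the first group and can then be reused.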
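The serial/parallel split and the two-callback rule can be condensed into a small runnable sketch. `MiniExecutorEngine` and `GroupCallback` are hypothetical simplifications: a group is a plain List, futures come from a plain JDK ExecutorService instead of Guava's ListeningExecutorService, and ExecutorDataMap is left out. As in parallelExecute, the remaining groups are submitted first and the first group then runs on the calling thread, so the caller does useful work while the pool is busy; results are collected in group order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MiniExecutorEngine {

    // Hypothetical counterpart of GroupedCallback: processes one whole group of inputs.
    @FunctionalInterface
    interface GroupCallback<I, O> {
        List<O> execute(List<I> inputs, boolean isTrunkThread);
    }

    private final ExecutorService pool;

    MiniExecutorEngine(ExecutorService pool) {
        this.pool = pool;
    }

    <I, O> List<O> execute(List<List<I>> groups, GroupCallback<I, O> firstCallback,
                           GroupCallback<I, O> callback, boolean serial) throws InterruptedException, ExecutionException {
        List<O> result = new ArrayList<>();
        if (groups.isEmpty()) {
            return result;
        }
        Iterator<List<I>> it = groups.iterator();
        List<I> first = it.next();
        // The first group uses firstCallback when one is given.
        GroupCallback<I, O> head = null == firstCallback ? callback : firstCallback;
        if (serial) {
            result.addAll(head.execute(first, true));
            while (it.hasNext()) {
                result.addAll(callback.execute(it.next(), true));
            }
            return result;
        }
        // Parallel: submit the remaining groups first, then run the first group on the calling thread.
        List<Future<List<O>>> futures = new ArrayList<>();
        while (it.hasNext()) {
            List<I> group = it.next();
            futures.add(pool.submit(() -> callback.execute(group, false)));
        }
        result.addAll(head.execute(first, true));
        // Collect in submission order, so the output order matches the group order.
        for (Future<List<O>> each : futures) {
            result.addAll(each.get());
        }
        return result;
    }

    private static List<String> tag(String label, List<String> inputs) {
        List<String> out = new ArrayList<>();
        for (String each : inputs) {
            out.add(label + ":" + each);
        }
        return out;
    }

    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            MiniExecutorEngine engine = new MiniExecutorEngine(pool);
            List<List<String>> groups = Arrays.asList(
                    Arrays.asList("ds_0.sql_a", "ds_0.sql_b"), Arrays.asList("ds_1.sql_a"), Arrays.asList("ds_2.sql_a"));
            GroupCallback<String, String> first = (inputs, trunk) -> tag("first", inputs);
            GroupCallback<String, String> rest = (inputs, trunk) -> tag("rest", inputs);
            System.out.println(engine.execute(groups, first, rest, false));
            System.out.println(engine.execute(groups, first, rest, true));
        } finally {
            pool.shutdown();
        }
    }
}
```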

The two callbacks can be seen in Sharding-Proxy, for example:
org.apache.shardingsphere.shardingproxy.backend.communication.jdbc.execute.JDBCExecuteEngine

    public BackendResponse execute(final ExecutionContext executionContext) throws SQLException {
        SQLStatementContext sqlStatementContext = executionContext.getSqlStatementContext();
        boolean isReturnGeneratedKeys = sqlStatementContext.getSqlStatement() instanceof InsertStatement;
        boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
        Collection<InputGroup<StatementExecuteUnit>> inputGroups = sqlExecutePrepareTemplate.getExecuteUnitGroups(
                executionContext.getExecutionUnits(), new ProxyJDBCExecutePrepareCallback(backendConnection, jdbcExecutorWrapper, isReturnGeneratedKeys));
        Collection<ExecuteResponse> executeResponses = sqlExecuteTemplate.execute((Collection) inputGroups,
                new ProxySQLExecuteCallback(sqlStatementContext, backendConnection, jdbcExecutorWrapper, isExceptionThrown, isReturnGeneratedKeys, true),
                new ProxySQLExecuteCallback(sqlStatementContext, backendConnection, jdbcExecutorWrapper, isExceptionThrown, isReturnGeneratedKeys, false));
        ExecuteResponse executeResponse = executeResponses.iterator().next();
      …
    }
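Here the two ProxySQLExecuteCallback instances differ only in their last constructor argument: true for the callback applied to the first group, false for the rest. Assuming that flag gates the one-time work described earlier (such as fetching metadata), the idea looks roughly like this. `MetaDataAwareCallback`, the column labels, and the fetch counter are hypothetical stand-ins, not the proxy's actual code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class FirstCallbackDemo {

    // Hypothetical callback; only an instance built with fetchMetaData = true does the one-time work.
    static final class MetaDataAwareCallback {
        private final boolean fetchMetaData;
        private final List<String> columnLabels;
        private final AtomicInteger metaDataFetches;

        MetaDataAwareCallback(boolean fetchMetaData, List<String> columnLabels, AtomicInteger metaDataFetches) {
            this.fetchMetaData = fetchMetaData;
            this.columnLabels = columnLabels;
            this.metaDataFetches = metaDataFetches;
        }

        List<String> execute(List<String> sqls) {
            if (fetchMetaData) {
                // Stand-in for reading result set metadata once; the other groups reuse columnLabels.
                metaDataFetches.incrementAndGet();
                columnLabels.addAll(Arrays.asList("order_id", "user_id"));
            }
            List<String> rows = new ArrayList<>();
            for (String each : sqls) {
                rows.add("row from " + each);
            }
            return rows;
        }
    }

    public static void main(String[] args) {
        List<String> columnLabels = new ArrayList<>();
        AtomicInteger fetches = new AtomicInteger();
        // Same constructor arguments except the last flag, as in the proxy code above.
        MetaDataAwareCallback first = new MetaDataAwareCallback(true, columnLabels, fetches);
        MetaDataAwareCallback others = new MetaDataAwareCallback(false, columnLabels, fetches);

        List<List<String>> groups = Arrays.asList(
                Arrays.asList("SELECT * FROM t_order_0"), Arrays.asList("SELECT * FROM t_order_1"), Arrays.asList("SELECT * FROM t_order_2"));
        List<String> rows = new ArrayList<>(first.execute(groups.get(0)));
        for (List<String> each : groups.subList(1, groups.size())) {
            rows.addAll(others.execute(each));
        }
        System.out.println(rows.size() + " rows, columns " + columnLabels + ", metadata fetched " + fetches.get() + " time(s)");
    }
}
```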

Now back to ShardingSphere's custom thread pool:
org.apache.shardingsphere.underlying.executor.engine.impl.ShardingSphereExecutorService

/**
 * ShardingSphere executor service.
 */
@Getter
public final class ShardingSphereExecutorService {
    
    private static final String DEFAULT_NAME_FORMAT = "%d";
    
    private static final ExecutorService SHUTDOWN_EXECUTOR = Executors.newSingleThreadExecutor(ShardingSphereThreadFactoryBuilder.build("Executor-Engine-Closer"));
    
    private ListeningExecutorService executorService;
    
    public ShardingSphereExecutorService(final int executorSize) {
        this(executorSize, DEFAULT_NAME_FORMAT);
    }
    
    public ShardingSphereExecutorService(final int executorSize, final String nameFormat) {
        executorService = MoreExecutors.listeningDecorator(getExecutorService(executorSize, nameFormat));
        MoreExecutors.addDelayedShutdownHook(executorService, 60, TimeUnit.SECONDS);
    }
    
    private ExecutorService getExecutorService(final int executorSize, final String nameFormat) {
        ThreadFactory threadFactory = ShardingSphereThreadFactoryBuilder.build(nameFormat);
        return 0 == executorSize ? Executors.newCachedThreadPool(threadFactory) : Executors.newFixedThreadPool(executorSize, threadFactory);
    }
    …
}

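As you can see, ShardingSphereExecutorService does not use the JDK's native ExecutorService directly; it wraps it with Google Guava's listening ExecutorService (ListeningExecutorService, created via MoreExecutors.listeningDecorator). ShardingSphere does not appear to use the listening capability yet, presumably leaving room for future extension.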
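For readers without Guava at hand, the sketch below reproduces the two ideas in plain JDK terms, as an approximation rather than the ShardingSphere code: the pool choice (executorSize 0 means a cached pool, otherwise a fixed pool, both with named threads) and what "listening" would look like. Guava's ListenableFuture lets a caller register a completion callback instead of blocking on get(); CompletableFuture serves as the JDK analogue here. The thread-name prefix is hypothetical, and the 60-second awaitTermination is only loosely modeled on the delayed shutdown hook.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class ExecutorServiceDemo {

    // Names threads "<prefix>-<n>"; a plain-JDK stand-in for ShardingSphereThreadFactoryBuilder.
    static ThreadFactory namedThreadFactory(String prefix) {
        AtomicInteger counter = new AtomicInteger();
        return runnable -> new Thread(runnable, prefix + "-" + counter.incrementAndGet());
    }

    // Same sizing rule as getExecutorService above: 0 means an unbounded cached pool.
    static ExecutorService createPool(int executorSize, String prefix) {
        ThreadFactory factory = namedThreadFactory(prefix);
        return 0 == executorSize ? Executors.newCachedThreadPool(factory) : Executors.newFixedThreadPool(executorSize, factory);
    }

    public static void main(String[] args) throws Exception {
        ExecutorService pool = createPool(2, "demo-executor");
        try {
            List<String> events = new CopyOnWriteArrayList<>();
            // "Listening": register what happens on completion instead of blocking on the task.
            CompletableFuture<Void> done = CompletableFuture
                    .supplyAsync(() -> "result from " + Thread.currentThread().getName(), pool)
                    .thenAccept(events::add);
            done.get(5, TimeUnit.SECONDS); // only so that main can print the event
            System.out.println(events);
        } finally {
            pool.shutdown();
            pool.awaitTermination(60, TimeUnit.SECONDS);
        }
    }
}
```

ExecutorEngine itself only calls get() on the futures it collects, which matches the observation that the listening feature is currently unused.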

Summary

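Compared with the other engines, the execution engine is fairly simple. It consists of three parts: 1. SQLExecutePrepareTemplate, 2. SQLExecuteTemplate, and 3. ExecutorEngine.
SQLExecutePrepareTemplate builds the execution groups: its input is a Collection<ExecutionUnit> and its output a Collection<InputGroup<StatementExecuteUnit>>. SQLExecuteTemplate executes the actual SQL operations: its input is a Collection<InputGroup<StatementExecuteUnit>> plus a SQLExecuteCallback. Apart from handing exceptions to ExecutorExceptionHandler, it currently has no logic of its own and simply calls ExecutorEngine to execute the SQL. ExecutorEngine is what actually carries out serial and parallel execution.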

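In the 5.x code base at the time of writing, the execution engine classes were renamed: SQLExecuteTemplate became org.apache.shardingsphere.infra.executor.sql.resourced.jdbc.executor.SQLExecutor,
and ExecutorEngine became org.apache.shardingsphere.infra.executor.kernel.ExecutorKernel. Their functionality, however, has not changed much.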

Finally, here is a flowchart of the execution engine:

[Figure: execution engine flowchart]