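ShardingSphere Design from the Source Code: The Execution Engine
The execution engine is responsible for sending the rewritten SQL to the databases chosen by routing and executing it there. It uses a callback design: for a given collection of input groups, it runs a specified callback function.
Like Spring's JdbcTemplate and TransactionTemplate, ShardingSphere's SQLExecuteTemplate and ExecutorEngine follow this design, and the engine's user supplies the callback implementation. The pattern fits because SQL execution has to support many kinds of SQL: DQL, DML, DDL, plain SQL and parameterized SQL. Each kind needs different operation logic, yet the execution engine must provide a single, common execution strategy.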
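To make the callback pattern concrete, here is a minimal sketch of a template that owns the fixed workflow (walk every input group, collect every result) and delegates the SQL-type-specific step to a callback. MiniExecuteTemplate and GroupCallback are hypothetical names for illustration, not ShardingSphere APIs.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

// Hypothetical template: it owns the fixed workflow (iterate every group, collect
// every result) and delegates the variable step to a callback, in the spirit of
// Spring's JdbcTemplate or ShardingSphere's ExecutorEngine.
final class MiniExecuteTemplate {

    <I, O> List<O> execute(final Collection<List<I>> inputGroups, final GroupCallback<I, O> callback) {
        List<O> result = new ArrayList<>();
        for (List<I> group : inputGroups) {
            for (I input : group) {
                result.add(callback.execute(input));
            }
        }
        return result;
    }

    public static void main(final String[] args) {
        MiniExecuteTemplate template = new MiniExecuteTemplate();
        List<List<String>> groups = List.of(List.of("SELECT 1", "SELECT 2"), List.of("UPDATE t SET a = 1"));
        // Two different callbacks plug different per-SQL logic into the same template.
        System.out.println(template.execute(groups, sql -> sql.startsWith("SELECT") ? "query" : "update")); // [query, query, update]
        System.out.println(template.execute(groups, String::length)); // [8, 8, 18]
    }
}

// Hypothetical callback: the caller decides what executing one input means.
interface GroupCallback<I, O> {
    O execute(I input);
}
```

The template never needs to know whether it is running a query or an update; that decision lives entirely in the callback, which is the property the execution engine relies on.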
Code Execution Analysis
Back at the starting point, in the ShardingPreparedStatement class:
@Override
public ResultSet executeQuery() throws SQLException {
    ResultSet result;
    try {
        …
        initPreparedStatementExecutor(); // initialize the PreparedStatement executor
        MergedResult mergedResult = mergeQuery(preparedStatementExecutor.executeQuery());
        …
}

private void initPreparedStatementExecutor() throws SQLException {
    preparedStatementExecutor.init(executionContext);
    setParametersForStatements(); // set parameters on each routed Statement
    replayMethodForStatements(); // replay the recorded method calls on each Statement
}

private void setParametersForStatements() {
    for (int i = 0; i < preparedStatementExecutor.getStatements().size(); i++) {
        replaySetParameter((PreparedStatement) preparedStatementExecutor.getStatements().get(i), preparedStatementExecutor.getParameterSets().get(i));
    }
}

private void replayMethodForStatements() {
    for (Statement each : preparedStatementExecutor.getStatements()) {
        replayMethodsInvocation(each);
    }
}
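The "record once, replay on every routed statement" idea behind setParametersForStatements and replayMethodForStatements can be sketched without a database. In this hypothetical ReplayDemo, FakeStatement stands in for a routed PreparedStatement and only logs the calls it receives; the recording mechanism (a list of Consumer objects) is an assumption for illustration, not how ShardingSphere stores invocations.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical sketch: calls made on the logical statement are recorded first,
// then replayed on every real statement produced by routing.
final class ReplayDemo {

    // Recorded calls, e.g. setQueryTimeout made before the routed statements exist.
    private final List<Consumer<FakeStatement>> recordedInvocations = new ArrayList<>();

    void record(final Consumer<FakeStatement> invocation) {
        recordedInvocations.add(invocation);
    }

    // Each routed statement gets its own parameter set (JDBC indexes start at 1),
    // then every recorded call is replayed on it.
    void initStatements(final List<FakeStatement> statements, final List<? extends List<?>> parameterSets) {
        for (int i = 0; i < statements.size(); i++) {
            FakeStatement statement = statements.get(i);
            List<?> parameters = parameterSets.get(i);
            for (int j = 0; j < parameters.size(); j++) {
                statement.setObject(j + 1, parameters.get(j));
            }
            for (Consumer<FakeStatement> each : recordedInvocations) {
                each.accept(statement);
            }
        }
    }

    public static void main(final String[] args) {
        ReplayDemo logical = new ReplayDemo();
        logical.record(statement -> statement.setQueryTimeout(5));
        List<FakeStatement> routed = List.of(new FakeStatement(), new FakeStatement());
        logical.initStatements(routed, List.of(List.of(10), List.of(20)));
        routed.forEach(each -> System.out.println(each.calls));
    }
}

// Hypothetical stand-in for a routed PreparedStatement that just logs its calls.
final class FakeStatement {
    final List<String> calls = new ArrayList<>();

    void setObject(final int parameterIndex, final Object value) {
        calls.add("setObject(" + parameterIndex + ", " + value + ")");
    }

    void setQueryTimeout(final int seconds) {
        calls.add("setQueryTimeout(" + seconds + ")");
    }
}
```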
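In other words, executeQuery first initializes preparedStatementExecutor, then sets the Statement parameters and replays the recorded method calls. Next, let's step into the PreparedStatementExecutor class: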
org.apache.shardingsphere.shardingjdbc.executor.PreparedStatementExecutor
/**
 * Prepared statement executor.
 */
public final class PreparedStatementExecutor extends AbstractStatementExecutor {

    @Getter
    private final boolean returnGeneratedKeys;

    public PreparedStatementExecutor(
            final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys, final ShardingConnection shardingConnection) {
        super(resultSetType, resultSetConcurrency, resultSetHoldability, shardingConnection);
        this.returnGeneratedKeys = returnGeneratedKeys;
    }

    /**
     * Initialize executor.
     *
     * @param executionContext execution context
     * @throws SQLException SQL exception
     */
    public void init(final ExecutionContext executionContext) throws SQLException {
        setSqlStatementContext(executionContext.getSqlStatementContext());
        getInputGroups().addAll(obtainExecuteGroups(executionContext.getExecutionUnits())); // build the execution groups
        cacheStatements();
    }

    private Collection<InputGroup<StatementExecuteUnit>> obtainExecuteGroups(final Collection<ExecutionUnit> executionUnits) throws SQLException {
        return getSqlExecutePrepareTemplate().getExecuteUnitGroups(executionUnits, new SQLExecutePrepareCallback() {

            // Open the requested number of database connections on the given data source
            @Override
            public List<Connection> getConnections(final ConnectionMode connectionMode, final String dataSourceName, final int connectionSize) throws SQLException {
                return PreparedStatementExecutor.super.getConnection().getConnections(connectionMode, dataSourceName, connectionSize);
            }

            // Create a StatementExecuteUnit from the execution unit
            @Override
            public StatementExecuteUnit createStatementExecuteUnit(final Connection connection, final ExecutionUnit executionUnit, final ConnectionMode connectionMode) throws SQLException {
                return new StatementExecuteUnit(executionUnit, createPreparedStatement(connection, executionUnit.getSqlUnit().getSql()), connectionMode);
            }
        });
    }

    @SuppressWarnings("MagicConstant")
    private PreparedStatement createPreparedStatement(final Connection connection, final String sql) throws SQLException {
        return returnGeneratedKeys ? connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)
                : connection.prepareStatement(sql, getResultSetType(), getResultSetConcurrency(), getResultSetHoldability());
    }

    /**
     * Execute query.
     *
     * @return result set list
     * @throws SQLException SQL exception
     */
    public List<QueryResult> executeQuery() throws SQLException {
        final boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
        SQLExecuteCallback<QueryResult> executeCallback = new SQLExecuteCallback<QueryResult>(getDatabaseType(), isExceptionThrown) {

            // Execute the SQL on the given Statement and wrap the JDBC ResultSet as a QueryResult (stream-based or memory-based)
            @Override
            protected QueryResult executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode) throws SQLException {
                return getQueryResult(statement, connectionMode);
            }
        };
        return executeCallback(executeCallback); // hand the callback to executeCallback for execution
    }

    // Execute the SQL, then wrap the ResultSet as a QueryResult:
    // MEMORY_STRICTLY -> StreamQueryResult (reads rows from the open cursor on demand),
    // CONNECTION_STRICTLY -> MemoryQueryResult (loads the rows into memory)
    private QueryResult getQueryResult(final Statement statement, final ConnectionMode connectionMode) throws SQLException {
        PreparedStatement preparedStatement = (PreparedStatement) statement;
        ResultSet resultSet = preparedStatement.executeQuery();
        getResultSets().add(resultSet);
        return ConnectionMode.MEMORY_STRICTLY == connectionMode ? new StreamQueryResult(resultSet) : new MemoryQueryResult(resultSet);
    }

    …
}
First, init calls obtainExecuteGroups, which in turn calls SQLExecutePrepareTemplate.getExecuteUnitGroups to turn the input ExecutionUnit collection, together with a SQLExecutePrepareCallback, into a collection of InputGroup<StatementExecuteUnit>.
Now let's look at getExecuteUnitGroups in the SQLExecutePrepareTemplate class:
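The division of labor between the template and the prepare callback can be sketched in plain Java. In this hypothetical PrepareDemo, connections and statement units are just strings, the callback interface is a simplified stand-in for SQLExecutePrepareCallback, and every data source gets a single connection, which is what happens with the default maxConnectionsSizePerQuery of 1.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: the template groups units by data source; the callback
// supplies connections and builds statement units (plain strings here).
final class PrepareDemo {

    // Minimal execution unit: target data source plus rewritten SQL.
    static final class ExecUnit {
        final String dataSourceName;
        final String sql;

        ExecUnit(final String dataSourceName, final String sql) {
            this.dataSourceName = dataSourceName;
            this.sql = sql;
        }
    }

    // Simplified stand-in for SQLExecutePrepareCallback.
    interface PrepareCallback {
        List<String> getConnections(String dataSourceName, int connectionSize);

        String createStatementUnit(String connection, String sql);
    }

    static List<List<String>> prepare(final List<ExecUnit> units, final PrepareCallback callback) {
        // LinkedHashMap keeps data sources in first-seen order.
        Map<String, List<String>> sqlByDataSource = new LinkedHashMap<>();
        for (ExecUnit each : units) {
            sqlByDataSource.computeIfAbsent(each.dataSourceName, key -> new ArrayList<>()).add(each.sql);
        }
        List<List<String>> inputGroups = new ArrayList<>();
        for (Map.Entry<String, List<String>> entry : sqlByDataSource.entrySet()) {
            // One connection per data source (the maxConnectionsSizePerQuery = 1 case).
            String connection = callback.getConnections(entry.getKey(), 1).get(0);
            List<String> group = new ArrayList<>();
            for (String sql : entry.getValue()) {
                group.add(callback.createStatementUnit(connection, sql));
            }
            inputGroups.add(group);
        }
        return inputGroups;
    }

    // A callback that fakes connections as "<dataSource>#conn<i>".
    static PrepareCallback demoCallback() {
        return new PrepareCallback() {
            @Override
            public List<String> getConnections(final String dataSourceName, final int connectionSize) {
                List<String> connections = new ArrayList<>();
                for (int i = 0; i < connectionSize; i++) {
                    connections.add(dataSourceName + "#conn" + i);
                }
                return connections;
            }

            @Override
            public String createStatementUnit(final String connection, final String sql) {
                return connection + ":" + sql;
            }
        };
    }

    public static void main(final String[] args) {
        List<ExecUnit> units = List.of(
                new ExecUnit("ds_0", "SELECT * FROM t_order_0"),
                new ExecUnit("ds_1", "SELECT * FROM t_order_0"),
                new ExecUnit("ds_0", "SELECT * FROM t_order_1"));
        System.out.println(prepare(units, demoCallback()));
    }
}
```

Here ds_0's two SQL units end up in one group sharing ds_0#conn0, and ds_1 forms a second group; the template never opens a connection itself.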
org.apache.shardingsphere.sharding.execute.sql.prepare.SQLExecutePrepareTemplate
/**
 * SQL execute prepare template.
 */
@RequiredArgsConstructor
public final class SQLExecutePrepareTemplate {

    private final int maxConnectionsSizePerQuery;

    /**
     * Get execute unit groups.
     *
     * @param executionUnits execution units
     * @param callback SQL execute prepare callback
     * @return statement execute unit groups
     * @throws SQLException SQL exception
     */
    public Collection<InputGroup<StatementExecuteUnit>> getExecuteUnitGroups(final Collection<ExecutionUnit> executionUnits, final SQLExecutePrepareCallback callback) throws SQLException {
        return getSynchronizedExecuteUnitGroups(executionUnits, callback);
    }

    // Build the synchronized execution unit groups
    private Collection<InputGroup<StatementExecuteUnit>> getSynchronizedExecuteUnitGroups(
            final Collection<ExecutionUnit> executionUnits, final SQLExecutePrepareCallback callback) throws SQLException {
        Map<String, List<SQLUnit>> sqlUnitGroups = getSQLUnitGroups(executionUnits); // map each data source to its SQLUnits
        Collection<InputGroup<StatementExecuteUnit>> result = new LinkedList<>();
        for (Entry<String, List<SQLUnit>> entry : sqlUnitGroups.entrySet()) {
            // convert this data source's SQLUnits into InputGroup<StatementExecuteUnit>s (one StatementExecuteUnit per SQLUnit)
            result.addAll(getSQLExecuteGroups(entry.getKey(), entry.getValue(), callback));
        }
        return result;
    }

    // Group the SQLUnits of the ExecutionUnits by data source
    private Map<String, List<SQLUnit>> getSQLUnitGroups(final Collection<ExecutionUnit> executionUnits) {
        Map<String, List<SQLUnit>> result = new LinkedHashMap<>(executionUnits.size(), 1);
        for (ExecutionUnit each : executionUnits) {
            if (!result.containsKey(each.getDataSourceName())) {
                result.put(each.getDataSourceName(), new LinkedList<>());
            }
            result.get(each.getDataSourceName()).add(each.getSqlUnit());
        }
        return result;
    }

    // Build the SQL execution groups for one data source
    private List<InputGroup<StatementExecuteUnit>> getSQLExecuteGroups(final String dataSourceName,
            final List<SQLUnit> sqlUnits, final SQLExecutePrepareCallback callback) throws SQLException {
        List<InputGroup<StatementExecuteUnit>> result = new LinkedList<>();
        int desiredPartitionSize = Math.max(0 == sqlUnits.size() % maxConnectionsSizePerQuery ? sqlUnits.size() / maxConnectionsSizePerQuery : sqlUnits.size() / maxConnectionsSizePerQuery + 1, 1);
        List<List<SQLUnit>> sqlUnitPartitions = Lists.partition(sqlUnits, desiredPartitionSize);
        ConnectionMode connectionMode = maxConnectionsSizePerQuery < sqlUnits.size() ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
        // one connection per partition; the partition count follows from the SQL count and maxConnectionsSizePerQuery
        List<Connection> connections = callback.getConnections(connectionMode, dataSourceName, sqlUnitPartitions.size());
        int count = 0;
        for (List<SQLUnit> each : sqlUnitPartitions) {
            // build a StatementExecuteUnit for each SQLUnit in the partition and add the group to the result
            result.add(getSQLExecuteGroup(connectionMode, connections.get(count++), dataSourceName, each, callback));
        }
        return result;
    }

    private InputGroup<StatementExecuteUnit> getSQLExecuteGroup(final ConnectionMode connectionMode, final Connection connection,
            final String dataSourceName, final List<SQLUnit> sqlUnitGroup, final SQLExecutePrepareCallback callback) throws SQLException {
        List<StatementExecuteUnit> result = new LinkedList<>();
        for (SQLUnit each : sqlUnitGroup) {
            result.add(callback.createStatementExecuteUnit(connection, new ExecutionUnit(dataSourceName, each), connectionMode));
        }
        return new InputGroup<>(result);
    }
}
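As you can see, SQLExecutePrepareTemplate groups the ExecutionUnit collection and converts it into a collection of InputGroup<StatementExecuteUnit>. Its core logic uses maxConnectionsSizePerQuery (the maximum number of database connections a single query may use on each data source) to work out how many connections the current SQL needs on each data source. The same comparison also picks the connection mode: if a data source has more SQL units than maxConnectionsSizePerQuery, CONNECTION_STRICTLY is used; otherwise MEMORY_STRICTLY. The setting is defined by the following property (default 1):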
/**
* Max opened connection size for each query.
*/
MAX_CONNECTIONS_SIZE_PER_QUERY("max.connections.size.per.query", String.valueOf(1), int.class),
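The partition arithmetic in getSQLExecuteGroups is easy to misread, so here is a worked sketch of it in plain Java. PartitionDemo and its methods are hypothetical names, and Guava's Lists.partition is replaced by a hand-written partition helper so the example needs only the JDK. For 5 SQL units on one data source and maxConnectionsSizePerQuery = 2, the partition size is ceil(5 / 2) = 3, giving two partitions of 3 and 2 units, hence two connections, and since 2 < 5 the mode is CONNECTION_STRICTLY.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the partition and connection-mode rules in getSQLExecuteGroups.
final class PartitionDemo {

    enum ConnectionMode { MEMORY_STRICTLY, CONNECTION_STRICTLY }

    // ceil(sqlCount / maxConnectionsSizePerQuery), never below 1 (same formula as the source).
    static int desiredPartitionSize(final int sqlCount, final int maxConnectionsSizePerQuery) {
        int size = 0 == sqlCount % maxConnectionsSizePerQuery
                ? sqlCount / maxConnectionsSizePerQuery
                : sqlCount / maxConnectionsSizePerQuery + 1;
        return Math.max(size, 1);
    }

    static ConnectionMode connectionMode(final int sqlCount, final int maxConnectionsSizePerQuery) {
        return maxConnectionsSizePerQuery < sqlCount ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
    }

    // JDK-only replacement for Lists.partition: consecutive sublists of `size` elements, the last may be shorter.
    static <T> List<List<T>> partition(final List<T> list, final int size) {
        List<List<T>> result = new ArrayList<>();
        for (int from = 0; from < list.size(); from += size) {
            result.add(list.subList(from, Math.min(from + size, list.size())));
        }
        return result;
    }

    public static void main(final String[] args) {
        List<String> sqlUnits = List.of("sql0", "sql1", "sql2", "sql3", "sql4");
        int maxConnectionsSizePerQuery = 2;
        int partitionSize = desiredPartitionSize(sqlUnits.size(), maxConnectionsSizePerQuery); // ceil(5 / 2) = 3
        List<List<String>> partitions = partition(sqlUnits, partitionSize); // one connection per partition
        System.out.println(partitionSize + " " + partitions + " " + connectionMode(sqlUnits.size(), maxConnectionsSizePerQuery));
    }
}
```

With the default of 1, all five units land in one partition on one connection (CONNECTION_STRICTLY); with 10, the partition size drops to 1, so five connections are used and the mode becomes MEMORY_STRICTLY. In general the partition count, ceil(n / ceil(n / max)), never exceeds maxConnectionsSizePerQuery, which is how the setting caps connections per data source.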
Next comes SQLExecuteTemplate, which executes the grouped statements:
org.apache.shardingsphere.sharding.execute.sql.execute.SQLExecuteTemplate
public final class SQLExecuteTemplate {

    private final ExecutorEngine executorEngine;

    private final boolean serial;

    /**
     * Execute.
     *
     * @param inputGroups input groups
     * @param callback SQL execute callback
     * @param <T> class type of return value
     * @return execute result
     * @throws SQLException SQL exception
     */
    public <T> List<T> execute(final Collection<InputGroup<? extends StatementExecuteUnit>> inputGroups, final SQLExecuteCallback<T> callback) throws SQLException {
        return execute(inputGroups, null, callback);
    }

    /**
     * Execute.
     *
     * @param inputGroups input groups
     * @param firstCallback first SQL execute callback
     * @param callback SQL execute callback
     * @param <T> class type of return value
     * @return execute result
     * @throws SQLException SQL exception
     */
    @SuppressWarnings("unchecked")
    public <T> List<T> execute(final Collection<InputGroup<? extends StatementExecuteUnit>> inputGroups,
            final SQLExecuteCallback<T> firstCallback, final SQLExecuteCallback<T> callback) throws SQLException {
        try {
            return executorEngine.execute((Collection) inputGroups, firstCallback, callback, serial);
        } catch (final SQLException ex) {
            ExecutorExceptionHandler.handleException(ex);
            return Collections.emptyList();
        }
    }
}
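Besides delegating, the only thing SQLExecuteTemplate adds is the catch block: the exception goes to ExecutorExceptionHandler.handleException, and if that returns normally the caller gets an empty list. The sketch below models that contract under an assumption: that the handler rethrows when a per-thread flag is set and otherwise only logs. The flag, EngineCall and ExceptionHandlingDemo are hypothetical. Because a ThreadLocal value is not visible to pool threads, this would also explain why executeQuery reads ExecutorExceptionHandler.isExceptionThrown() on the caller thread and passes it into the callback.

```java
import java.sql.SQLException;
import java.util.Collections;
import java.util.List;

// Hypothetical sketch of a "rethrow or swallow" policy around an engine call.
final class ExceptionHandlingDemo {

    // Assumed per-thread switch; true (the initial value here) means errors propagate.
    private static final ThreadLocal<Boolean> EXCEPTION_THROWN = ThreadLocal.withInitial(() -> true);

    // Hypothetical functional type for "the call into the executor engine".
    interface EngineCall<T> {
        List<T> call() throws SQLException;
    }

    static void setExceptionThrown(final boolean exceptionThrown) {
        EXCEPTION_THROWN.set(exceptionThrown);
    }

    static boolean isExceptionThrown() {
        return EXCEPTION_THROWN.get();
    }

    static void handleException(final SQLException ex) throws SQLException {
        if (isExceptionThrown()) {
            throw ex;
        }
        System.err.println("SQL execution failed and was ignored: " + ex.getMessage());
    }

    // Same shape as SQLExecuteTemplate.execute: delegate, return an empty list if the error is swallowed.
    static <T> List<T> execute(final EngineCall<T> engineCall) throws SQLException {
        try {
            return engineCall.call();
        } catch (final SQLException ex) {
            handleException(ex);
            return Collections.emptyList();
        }
    }

    public static void main(final String[] args) throws Exception {
        setExceptionThrown(false);
        // A new thread sees the ThreadLocal's initial value (true), not the caller's false.
        Thread worker = new Thread(() -> System.out.println("worker sees isExceptionThrown = " + isExceptionThrown()));
        worker.start();
        worker.join();
        List<Integer> swallowed = execute(() -> {
            throw new SQLException("boom");
        });
        System.out.println("caller got " + swallowed); // caller got []
    }
}
```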
Internally, the work is delegated to ExecutorEngine. Let's look at that class:
org.apache.shardingsphere.underlying.executor.engine.ExecutorEngine
/**
 * Executor engine.
 */
public final class ExecutorEngine implements AutoCloseable {

    private final ShardingSphereExecutorService executorService;

    public ExecutorEngine(final int executorSize) {
        executorService = new ShardingSphereExecutorService(executorSize);
    }

    /**
     * Execute.
     *
     * @param inputGroups input groups
     * @param callback grouped callback
     * @param <I> type of input value
     * @param <O> type of return value
     * @return execute result
     * @throws SQLException throw if execute failure
     */
    public <I, O> List<O> execute(final Collection<InputGroup<I>> inputGroups, final GroupedCallback<I, O> callback) throws SQLException {
        return execute(inputGroups, null, callback, false);
    }

    /**
     * Execute.
     *
     * @param inputGroups input groups
     * @param firstCallback first grouped callback
     * @param callback other grouped callback
     * @param serial whether using multi thread execute or not
     * @param <I> type of input value
     * @param <O> type of return value
     * @return execute result
     * @throws SQLException throw if execute failure
     */
    public <I, O> List<O> execute(final Collection<InputGroup<I>> inputGroups,
            final GroupedCallback<I, O> firstCallback, final GroupedCallback<I, O> callback, final boolean serial) throws SQLException {
        if (inputGroups.isEmpty()) {
            return Collections.emptyList();
        }
        return serial ? serialExecute(inputGroups, firstCallback, callback) : parallelExecute(inputGroups, firstCallback, callback);
    }

    // Serial execution: every group runs on the calling thread
    private <I, O> List<O> serialExecute(final Collection<InputGroup<I>> inputGroups, final GroupedCallback<I, O> firstCallback, final GroupedCallback<I, O> callback) throws SQLException {
        Iterator<InputGroup<I>> inputGroupsIterator = inputGroups.iterator();
        InputGroup<I> firstInputs = inputGroupsIterator.next();
        List<O> result = new LinkedList<>(syncExecute(firstInputs, null == firstCallback ? callback : firstCallback));
        for (InputGroup<I> each : Lists.newArrayList(inputGroupsIterator)) {
            result.addAll(syncExecute(each, callback));
        }
        return result;
    }

    // Parallel execution with two callbacks: the first input group runs with the first callback, the remaining groups with the second
    private <I, O> List<O> parallelExecute(final Collection<InputGroup<I>> inputGroups, final GroupedCallback<I, O> firstCallback, final GroupedCallback<I, O> callback) throws SQLException {
        Iterator<InputGroup<I>> inputGroupsIterator = inputGroups.iterator();
        InputGroup<I> firstInputs = inputGroupsIterator.next();
        Collection<ListenableFuture<Collection<O>>> restResultFutures = asyncExecute(Lists.newArrayList(inputGroupsIterator), callback);
        return getGroupResults(syncExecute(firstInputs, null == firstCallback ? callback : firstCallback), restResultFutures);
    }

    // Synchronous execution on the calling thread
    private <I, O> Collection<O> syncExecute(final InputGroup<I> inputGroup, final GroupedCallback<I, O> callback) throws SQLException {
        return callback.execute(inputGroup.getInputs(), true, ExecutorDataMap.getValue());
    }

    // Asynchronous execution on the thread pool
    private <I, O> Collection<ListenableFuture<Collection<O>>> asyncExecute(final List<InputGroup<I>> inputGroups, final GroupedCallback<I, O> callback) {
        Collection<ListenableFuture<Collection<O>>> result = new LinkedList<>();
        for (InputGroup<I> each : inputGroups) {
            result.add(asyncExecute(each, callback));
        }
        return result;
    }

    private <I, O> ListenableFuture<Collection<O>> asyncExecute(final InputGroup<I> inputGroup, final GroupedCallback<I, O> callback) {
        final Map<String, Object> dataMap = ExecutorDataMap.getValue();
        return executorService.getExecutorService().submit(() -> callback.execute(inputGroup.getInputs(), false, dataMap));
    }

    private <O> List<O> getGroupResults(final Collection<O> firstResults, final Collection<ListenableFuture<Collection<O>>> restFutures) throws SQLException {
        List<O> result = new LinkedList<>(firstResults);
        for (ListenableFuture<Collection<O>> each : restFutures) {
            try {
                result.addAll(each.get());
            } catch (final InterruptedException | ExecutionException ex) {
                return throwException(ex);
            }
        }
        return result;
    }

    …
}
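ExecutorEngine offers two execution paths: serial execution (serialExecute) and parallel execution (parallelExecute). serialExecute runs every input group synchronously on the calling application thread. parallelExecute submits every group except the first to ShardingSphere's built-in thread pool, ShardingSphereExecutorService, and runs the first group synchronously on the calling thread. Note that both paths take two callback parameters: the first input group is executed with the first callback and all other groups with the second. The reason is that some work only needs to happen once; metadata, for example, only has to be obtained while processing the first group and can then be reused.
For example, in Sharding-Proxy: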
org.apache.shardingsphere.shardingproxy.backend.communication.jdbc.execute.JDBCExecuteEngine
public BackendResponse execute(final ExecutionContext executionContext) throws SQLException {
    SQLStatementContext sqlStatementContext = executionContext.getSqlStatementContext();
    boolean isReturnGeneratedKeys = sqlStatementContext.getSqlStatement() instanceof InsertStatement;
    boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
    Collection<InputGroup<StatementExecuteUnit>> inputGroups = sqlExecutePrepareTemplate.getExecuteUnitGroups(
            executionContext.getExecutionUnits(), new ProxyJDBCExecutePrepareCallback(backendConnection, jdbcExecutorWrapper, isReturnGeneratedKeys));
    Collection<ExecuteResponse> executeResponses = sqlExecuteTemplate.execute((Collection) inputGroups,
            // callback for the first input group
            new ProxySQLExecuteCallback(sqlStatementContext, backendConnection, jdbcExecutorWrapper, isExceptionThrown, isReturnGeneratedKeys, true),
            // callback for the remaining input groups
            new ProxySQLExecuteCallback(sqlStatementContext, backendConnection, jdbcExecutorWrapper, isExceptionThrown, isReturnGeneratedKeys, false));
    ExecuteResponse executeResponse = executeResponses.iterator().next();
    …
}
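The serial/parallel split and the two-callback convention can be reproduced with nothing but the JDK. The sketch below is a simplified, hypothetical MiniExecutorEngine: it uses a plain ExecutorService and Future instead of Guava's ListeningExecutorService, drops ExecutorDataMap, and its GroupedCallback is a simplified stand-in rather than ShardingSphere's interface. As in parallelExecute, the remaining groups are submitted before the first group runs on the calling thread, so the two overlap.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical, simplified model of ExecutorEngine (JDK only).
final class MiniExecutorEngine implements AutoCloseable {

    // Simplified stand-in for GroupedCallback: handles one whole input group.
    interface GroupedCallback<I, O> {
        List<O> execute(List<I> inputs, boolean isTrunkThread);
    }

    private final ExecutorService executorService = Executors.newFixedThreadPool(4);

    <I, O> List<O> execute(final List<List<I>> inputGroups, final GroupedCallback<I, O> firstCallback,
                           final GroupedCallback<I, O> callback, final boolean serial) throws InterruptedException, ExecutionException {
        List<O> result = new ArrayList<>();
        if (inputGroups.isEmpty()) {
            return result;
        }
        Iterator<List<I>> iterator = inputGroups.iterator();
        List<I> firstInputs = iterator.next();
        GroupedCallback<I, O> first = null == firstCallback ? callback : firstCallback;
        if (serial) {
            // Serial: every group runs on the calling thread, in order.
            result.addAll(first.execute(firstInputs, true));
            while (iterator.hasNext()) {
                result.addAll(callback.execute(iterator.next(), true));
            }
            return result;
        }
        // Parallel: submit the remaining groups first so they overlap with the first group.
        List<Future<List<O>>> futures = new ArrayList<>();
        while (iterator.hasNext()) {
            List<I> inputs = iterator.next();
            futures.add(executorService.submit(() -> callback.execute(inputs, false)));
        }
        // The first group runs on the calling thread with the first callback.
        result.addAll(first.execute(firstInputs, true));
        // Collect the rest in submission order, so the result order is deterministic.
        for (Future<List<O>> each : futures) {
            result.addAll(each.get());
        }
        return result;
    }

    @Override
    public void close() {
        executorService.shutdown();
    }

    public static void main(final String[] args) throws Exception {
        List<List<String>> groups = List.of(List.of("ds_0 SQL"), List.of("ds_1 SQL"), List.of("ds_2 SQL"));
        // e.g. only the first group's callback would also build result metadata
        GroupedCallback<String, String> firstCallback = (inputs, trunk) -> List.of("first(" + inputs.get(0) + ")");
        GroupedCallback<String, String> otherCallback = (inputs, trunk) -> List.of("other(" + inputs.get(0) + ")");
        try (MiniExecutorEngine engine = new MiniExecutorEngine()) {
            System.out.println(engine.execute(groups, firstCallback, otherCallback, false));
            System.out.println(engine.execute(groups, firstCallback, otherCallback, true));
        }
    }
}
```

Both calls print the same list, [first(ds_0 SQL), other(ds_1 SQL), other(ds_2 SQL)]; only the threads that did the work differ.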
Now let's go back and look at ShardingSphere's custom thread pool:
org.apache.shardingsphere.underlying.executor.engine.impl.ShardingSphereExecutorService
/**
 * ShardingSphere executor service.
 */
@Getter
public final class ShardingSphereExecutorService {

    private static final String DEFAULT_NAME_FORMAT = "%d";

    private static final ExecutorService SHUTDOWN_EXECUTOR = Executors.newSingleThreadExecutor(ShardingSphereThreadFactoryBuilder.build("Executor-Engine-Closer"));

    private ListeningExecutorService executorService;

    public ShardingSphereExecutorService(final int executorSize) {
        this(executorSize, DEFAULT_NAME_FORMAT);
    }

    public ShardingSphereExecutorService(final int executorSize, final String nameFormat) {
        executorService = MoreExecutors.listeningDecorator(getExecutorService(executorSize, nameFormat));
        MoreExecutors.addDelayedShutdownHook(executorService, 60, TimeUnit.SECONDS);
    }

    private ExecutorService getExecutorService(final int executorSize, final String nameFormat) {
        ThreadFactory threadFactory = ShardingSphereThreadFactoryBuilder.build(nameFormat);
        return 0 == executorSize ? Executors.newCachedThreadPool(threadFactory) : Executors.newFixedThreadPool(executorSize, threadFactory);
    }

    …
}
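ShardingSphereExecutorService does not hand out a bare JDK ExecutorService. It creates a JDK pool (a cached thread pool when executorSize is 0, otherwise a fixed pool of executorSize threads), wraps it with Google Guava's MoreExecutors.listeningDecorator to get a ListeningExecutorService, and registers a delayed shutdown hook with a 60-second timeout. However, I have not found any place in ShardingSphere that uses the listening capability yet; it is presumably there for future extension.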
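The pool setup can be mirrored with the JDK alone. In this hypothetical PoolDemo, the "<prefix>-<n>" thread naming and the daemon flag are assumptions (ShardingSphereThreadFactoryBuilder's exact behavior is not shown above), and addDelayedShutdownHook is a rough JDK analogue of Guava's MoreExecutors.addDelayedShutdownHook, not the Guava method itself. The sizing rule is the one from getExecutorService.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical JDK-only model of ShardingSphereExecutorService's pool setup.
final class PoolDemo {

    // Assumed naming scheme "<prefix>-<n>"; daemon threads so the JVM can exit even if the pool is left open.
    static ThreadFactory namedThreadFactory(final String prefix) {
        AtomicInteger counter = new AtomicInteger();
        return runnable -> {
            Thread thread = new Thread(runnable, prefix + "-" + counter.getAndIncrement());
            thread.setDaemon(true);
            return thread;
        };
    }

    // Same sizing rule as getExecutorService: 0 means an unbounded cached pool, otherwise a fixed-size pool.
    static ExecutorService create(final int executorSize, final String prefix) {
        ThreadFactory threadFactory = namedThreadFactory(prefix);
        return 0 == executorSize ? Executors.newCachedThreadPool(threadFactory) : Executors.newFixedThreadPool(executorSize, threadFactory);
    }

    // Rough analogue of MoreExecutors.addDelayedShutdownHook: at JVM exit, shut down and wait up to the timeout.
    static void addDelayedShutdownHook(final ExecutorService service, final long timeout, final TimeUnit unit) {
        Runtime.getRuntime().addShutdownHook(new Thread(() -> {
            service.shutdown();
            try {
                service.awaitTermination(timeout, unit);
            } catch (final InterruptedException ex) {
                Thread.currentThread().interrupt();
            }
        }));
    }

    public static void main(final String[] args) throws Exception {
        ExecutorService fixed = create(4, "ShardingSphere");
        addDelayedShutdownHook(fixed, 60, TimeUnit.SECONDS);
        ExecutorService cached = create(0, "ShardingSphere-cached");
        System.out.println(((ThreadPoolExecutor) fixed).getMaximumPoolSize()); // 4
        System.out.println(((ThreadPoolExecutor) cached).getMaximumPoolSize() == Integer.MAX_VALUE); // true
        System.out.println(fixed.submit(() -> Thread.currentThread().getName()).get()); // ShardingSphere-0
        cached.shutdown();
    }
}
```

A cached pool (executorSize = 0) grows without bound under load, so a fixed size is usually the safer choice when many queries fan out to many shards at once.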
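Summary
Compared with the other engines, the execution engine is fairly simple. It consists of three parts: 1. SQLExecutePrepareTemplate, 2. SQLExecuteTemplate, and 3. ExecutorEngine.
SQLExecutePrepareTemplate builds the execution groups: its input is a Collection<ExecutionUnit> and its output a Collection<InputGroup<StatementExecuteUnit>>. SQLExecuteTemplate executes the actual SQL: its inputs are a Collection<InputGroup<StatementExecuteUnit>> and a SQLExecuteCallback. Apart from exception handling it has no logic of its own and simply delegates to ExecutorEngine. ExecutorEngine is what actually carries out the serial and parallel SQL execution.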
In 5.x the execution engine classes were renamed: SQLExecuteTemplate became org.apache.shardingsphere.infra.executor.sql.resourced.jdbc.executor.SQLExecutor and ExecutorEngine became org.apache.shardingsphere.infra.executor.kernel.ExecutorKernel, but the implementation did not change much.
Finally, a flowchart of the execution engine:
[Figure: execution engine flowchart]