A Look at How ShardingSphere Rewrites SQL
Preface
This article walks through how ShardingSphere rewrites SQL.
prepareStatement
org/apache/shardingsphere/driver/jdbc/core/connection/ShardingSphereConnection.java
public final class ShardingSphereConnection extends AbstractConnectionAdapter {
@Override
public PreparedStatement prepareStatement(final String sql) throws SQLException {
return new ShardingSpherePreparedStatement(this, sql);
}
//......
}
ShardingSphereConnection's prepareStatement creates a ShardingSpherePreparedStatement.
ShardingSpherePreparedStatement
org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java
public final class ShardingSpherePreparedStatement extends AbstractPreparedStatementAdapter {
@Getter
private final ShardingSphereConnection connection;
public ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql) throws SQLException {
this(connection, sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, ResultSet.HOLD_CURSORS_OVER_COMMIT, false, null);
}
private ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql,
final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys,
final String[] columns) throws SQLException {
if (Strings.isNullOrEmpty(sql)) {
throw new EmptySQLException().toSQLException();
}
this.connection = connection;
metaDataContexts = connection.getContextManager().getMetaDataContexts();
SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
hintValueContext = sqlParserRule.isSqlCommentParseEnabled() ? new HintValueContext() : SQLHintUtils.extractHint(sql).orElseGet(HintValueContext::new);
this.sql = sqlParserRule.isSqlCommentParseEnabled() ? sql : SQLHintUtils.removeHint(sql);
statements = new ArrayList<>();
parameterSets = new ArrayList<>();
SQLParserEngine sqlParserEngine = sqlParserRule.getSQLParserEngine(
DatabaseTypeEngine.getTrunkDatabaseTypeName(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType()));
sqlStatement = sqlParserEngine.parse(this.sql, true);
sqlStatementContext = SQLStatementContextFactory.newInstance(metaDataContexts.getMetaData(), sqlStatement, connection.getDatabaseName());
parameterMetaData = new ShardingSphereParameterMetaData(sqlStatement);
statementOption = returnGeneratedKeys ? new StatementOption(true, columns) : new StatementOption(resultSetType, resultSetConcurrency, resultSetHoldability);
executor = new DriverExecutor(connection);
JDBCExecutor jdbcExecutor = new JDBCExecutor(connection.getContextManager().getExecutorEngine(), connection.getDatabaseConnectionManager().getConnectionContext());
batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(metaDataContexts, jdbcExecutor, connection.getDatabaseName());
kernelProcessor = new KernelProcessor();
statementsCacheable = isStatementsCacheable(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getRuleMetaData());
trafficRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(TrafficRule.class);
selectContainsEnhancedTable = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsEnhancedTable();
statementManager = new StatementManager();
}
//......
}
ShardingSpherePreparedStatement extends AbstractPreparedStatementAdapter. Its constructor mainly parses the SQL through SQLParserEngine to obtain a SQLStatement, and creates the DriverExecutor, BatchPreparedStatementExecutor, KernelProcessor and StatementManager. Even with useServerPrepStmts=true, nothing here triggers a prepare operation on the MySQL server.
executeUpdate
public int executeUpdate() throws SQLException {
try {
if (statementsCacheable && !statements.isEmpty()) {
resetParameters();
return statements.iterator().next().executeUpdate();
}
clearPrevious();
QueryContext queryContext = createQueryContext();
trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate());
}
executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return accumulate(executeResults);
}
return isNeedImplicitCommitTransaction(connection, executionContext) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate();
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
handleExceptionInTransaction(connection, metaDataContexts);
throw SQLExceptionTransformEngine.toSQLException(ex, metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType().getType());
} finally {
clearBatch();
}
}
private void clearPrevious() {
statements.clear();
parameterSets.clear();
generatedValues.clear();
}
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
ShardingSphereRuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName());
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
ExecutionContext result = kernelProcessor.generateExecutionContext(
queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
return result;
}
Unless cached statements can be reused, executeUpdate first calls clearPrevious, which clears statements, parameterSets and generatedValues, and then calls createExecutionContext; one of the steps there is kernelProcessor.generateExecutionContext.
KernelProcessor
generateExecutionContext
shardingsphere-infra-context-5.4.0-sources.jar!/org/apache/shardingsphere/infra/connection/kernel/KernelProcessor.java
public ExecutionContext generateExecutionContext(final QueryContext queryContext, final ShardingSphereDatabase database, final ShardingSphereRuleMetaData globalRuleMetaData,
final ConfigurationProperties props, final ConnectionContext connectionContext) {
RouteContext routeContext = route(queryContext, database, globalRuleMetaData, props, connectionContext);
SQLRewriteResult rewriteResult = rewrite(queryContext, database, globalRuleMetaData, props, routeContext, connectionContext);
ExecutionContext result = createExecutionContext(queryContext, database, routeContext, rewriteResult);
logSQL(queryContext, props, result);
return result;
}
KernelProcessor's generateExecutionContext first creates the routeContext, then runs rewrite, and finally calls createExecutionContext.
rewrite
private SQLRewriteResult rewrite(final QueryContext queryContext, final ShardingSphereDatabase database, final ShardingSphereRuleMetaData globalRuleMetaData,
final ConfigurationProperties props, final RouteContext routeContext, final ConnectionContext connectionContext) {
SQLRewriteEntry sqlRewriteEntry = new SQLRewriteEntry(database, globalRuleMetaData, props);
return sqlRewriteEntry.rewrite(queryContext.getSql(), queryContext.getParameters(), queryContext.getSqlStatementContext(), routeContext, connectionContext, queryContext.getHintValueContext());
}
rewrite mainly delegates to SQLRewriteEntry's rewrite method.
SQLRewriteEntry
shardingsphere-infra-rewrite-5.4.0-sources.jar!/org/apache/shardingsphere/infra/rewrite/SQLRewriteEntry.java
/**
* Rewrite.
*
* @param sql SQL
* @param params SQL parameters
* @param sqlStatementContext SQL statement context
* @param routeContext route context
* @param connectionContext connection context
* @param hintValueContext hint value context
*
* @return route unit and SQL rewrite result map
*/
public SQLRewriteResult rewrite(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext);
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
DatabaseType protocolType = database.getProtocolType();
Map<String, DatabaseType> storageTypes = database.getResourceMetaData().getStorageTypes();
return routeContext.getRouteUnits().isEmpty()
? new GenericSQLRewriteEngine(rule, protocolType, storageTypes).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, protocolType, storageTypes).rewrite(sqlRewriteContext, routeContext);
}
private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
final RouteContext routeContext, final ConnectionContext connectionContext, final HintValueContext hintValueContext) {
SQLRewriteContext result = new SQLRewriteContext(database.getName(), database.getSchemas(), sqlStatementContext, sql, params, connectionContext, hintValueContext);
decorate(decorators, result, routeContext, hintValueContext);
result.generateSQLTokens();
return result;
}
private void decorate(final Map<ShardingSphereRule, SQLRewriteContextDecorator> decorators, final SQLRewriteContext sqlRewriteContext,
final RouteContext routeContext, final HintValueContext hintValueContext) {
if (hintValueContext.isSkipSQLRewrite()) {
return;
}
for (Entry<ShardingSphereRule, SQLRewriteContextDecorator> entry : decorators.entrySet()) {
entry.getValue().decorate(entry.getKey(), props, sqlRewriteContext, routeContext);
}
}
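SQLRewriteEntry's rewrite method first builds a SQLRewriteContext through createSQLRewriteContext. There, decorate iterates over decorators and calls each SQLRewriteContextDecorator's decorate method in turn (this step is skipped when hintValueContext.isSkipSQLRewrite() is true). The actual rewrite is then done by GenericSQLRewriteEngine when there are no route units, or by RouteSQLRewriteEngine otherwise.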
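The decorate step can be pictured with a small stand-alone sketch. Everything in it (SketchRewriteContext, SketchDecorator, DecorateSketch, and the generator names) is hypothetical and only mirrors the shape of createSQLRewriteContext described above; it is not ShardingSphere's own API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for SQLRewriteContext: it only collects the names of
// the token generators that decorators contribute.
final class SketchRewriteContext {
    final List<String> tokenGenerators = new ArrayList<>();
}

// Hypothetical stand-in for SQLRewriteContextDecorator.
interface SketchDecorator {
    void decorate(SketchRewriteContext context);
}

final class DecorateSketch {

    // Build the context, then let every decorator contribute,
    // unless the caller asked to skip SQL rewriting altogether.
    static SketchRewriteContext createContext(final List<SketchDecorator> decorators, final boolean skipSQLRewrite) {
        SketchRewriteContext result = new SketchRewriteContext();
        if (!skipSQLRewrite) {
            for (SketchDecorator each : decorators) {
                each.decorate(result);
            }
        }
        return result;
    }

    public static void main(final String[] args) {
        List<SketchDecorator> decorators = Arrays.asList(
                context -> context.tokenGenerators.add("sharding-table-token"),
                context -> context.tokenGenerators.add("encrypt-column-token"));
        System.out.println(createContext(decorators, false).tokenGenerators);
        System.out.println(createContext(decorators, true).tokenGenerators);
    }
}
```

Each feature (sharding, encryption and so on) only adds its own generators to the shared context, so the features stay independent of one another.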
SQLRewriteContextDecorator
org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContextDecorator.java
@SingletonSPI
public interface SQLRewriteContextDecorator<T extends ShardingSphereRule> extends OrderedSPI<T> {
/**
* Decorate SQL rewrite context.
*
* @param rule rule
* @param props ShardingSphere properties
* @param sqlRewriteContext SQL rewrite context to be decorated
* @param routeContext route context
*/
void decorate(T rule, ConfigurationProperties props, SQLRewriteContext sqlRewriteContext, RouteContext routeContext);
}
SQLRewriteContextDecorator defines the decorate method; its implementations include ShardingSQLRewriteContextDecorator and EncryptSQLRewriteContextDecorator.
EncryptSQLRewriteContextDecorator
org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java
/**
* SQL rewrite context decorator for encrypt.
*/
public final class EncryptSQLRewriteContextDecorator implements SQLRewriteContextDecorator<EncryptRule> {
@Override
public void decorate(final EncryptRule encryptRule, final ConfigurationProperties props, final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
if (!containsEncryptTable(encryptRule, sqlStatementContext)) {
return;
}
Collection<EncryptCondition> encryptConditions = createEncryptConditions(encryptRule, sqlRewriteContext);
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters = new EncryptParameterRewriterBuilder(encryptRule,
sqlRewriteContext.getDatabaseName(), sqlRewriteContext.getSchemas(), sqlStatementContext, encryptConditions).getParameterRewriters();
rewriteParameters(sqlRewriteContext, parameterRewriters);
}
Collection<SQLTokenGenerator> sqlTokenGenerators = new EncryptTokenGenerateBuilder(encryptRule,
sqlStatementContext, encryptConditions, sqlRewriteContext.getDatabaseName()).getSQLTokenGenerators();
sqlRewriteContext.addSQLTokenGenerators(sqlTokenGenerators);
}
private Collection<EncryptCondition> createEncryptConditions(final EncryptRule encryptRule, final SQLRewriteContext sqlRewriteContext) {
SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
if (!(sqlStatementContext instanceof WhereAvailable)) {
return Collections.emptyList();
}
Collection<WhereSegment> whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments();
Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
return new EncryptConditionEngine(encryptRule, sqlRewriteContext.getSchemas())
.createEncryptConditions(whereSegments, columnSegments, sqlStatementContext, sqlRewriteContext.getDatabaseName());
}
private boolean containsEncryptTable(final EncryptRule encryptRule, final SQLStatementContext sqlStatementContext) {
for (String each : sqlStatementContext.getTablesContext().getTableNames()) {
if (encryptRule.findEncryptTable(each).isPresent()) {
return true;
}
}
return false;
}
private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
for (ParameterRewriter each : parameterRewriters) {
each.rewrite(sqlRewriteContext.getParameterBuilder(), sqlRewriteContext.getSqlStatementContext(), sqlRewriteContext.getParameters());
}
}
@Override
public int getOrder() {
return EncryptOrder.ORDER;
}
@Override
public Class<EncryptRule> getTypeClass() {
return EncryptRule.class;
}
}
rewriteParameters performs its rewrite through ParameterRewriter, which mainly modifies the ParameterBuilder; changes to the SQL text itself are made through the sqlTokenGenerators.
SQLToken
@RequiredArgsConstructor
@Getter
public abstract class SQLToken implements Comparable<SQLToken> {
private final int startIndex;
@Override
public final int compareTo(final SQLToken sqlToken) {
return startIndex - sqlToken.startIndex;
}
}
SQLToken is ordered by startIndex and has implementations such as InsertValuesToken, SubstitutableColumnNameToken and InsertColumnsToken.
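To make the token idea concrete, here is a minimal sketch of position-based rewriting. It assumes each token also carries a stop index and a replacement text; SketchToken and TokenRewriteSketch are hypothetical names, and the real SQL builder in ShardingSphere may differ in detail.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical token: replaces the characters in [startIndex, stopIndex]
// of the original SQL with `text`.
final class SketchToken implements Comparable<SketchToken> {
    final int startIndex;
    final int stopIndex;
    final String text;

    SketchToken(final int startIndex, final int stopIndex, final String text) {
        this.startIndex = startIndex;
        this.stopIndex = stopIndex;
        this.text = text;
    }

    @Override
    public int compareTo(final SketchToken other) {
        return Integer.compare(startIndex, other.startIndex);
    }
}

final class TokenRewriteSketch {

    // Sort tokens by position, then copy the untouched SQL between tokens
    // and emit each token's replacement text in place of the original span.
    static String rewrite(final String sql, final List<SketchToken> tokens) {
        List<SketchToken> sorted = new ArrayList<>(tokens);
        Collections.sort(sorted);
        StringBuilder result = new StringBuilder();
        int cursor = 0;
        for (SketchToken each : sorted) {
            result.append(sql, cursor, each.startIndex);
            result.append(each.text);
            cursor = each.stopIndex + 1;
        }
        result.append(sql.substring(cursor));
        return result.toString();
    }

    public static void main(final String[] args) {
        String sql = "SELECT * FROM t_user WHERE phone = ?";
        int table = sql.indexOf("t_user");
        int column = sql.indexOf("phone");
        // Tokens are deliberately passed out of order; rewrite sorts them.
        List<SketchToken> tokens = Arrays.asList(
                new SketchToken(column, column + "phone".length() - 1, "phone_cipher"),
                new SketchToken(table, table + "t_user".length() - 1, "t_user_0"));
        System.out.println(rewrite(sql, tokens));
    }
}
```

Working from parsed positions rather than string search is what allows several independent rules to rewrite the same statement without stepping on each other.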
RouteSQLRewriteEngine
/**
* Rewrite SQL and parameters.
*
* @param sqlRewriteContext SQL rewrite context
* @param routeContext route context
* @return SQL rewrite result
*/
public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1F);
for (Entry<String, Collection<RouteUnit>> entry : aggregateRouteUnitGroups(routeContext.getRouteUnits()).entrySet()) {
Collection<RouteUnit> routeUnits = entry.getValue();
if (isNeedAggregateRewrite(sqlRewriteContext.getSqlStatementContext(), routeUnits)) {
sqlRewriteUnits.put(routeUnits.iterator().next(), createSQLRewriteUnit(sqlRewriteContext, routeContext, routeUnits));
} else {
addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
}
}
return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext().getSqlStatement(), sqlRewriteUnits));
}
private void addSQLRewriteUnits(final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits, final SQLRewriteContext sqlRewriteContext,
final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
for (RouteUnit each : routeUnits) {
sqlRewriteUnits.put(each, new SQLRewriteUnit(new RouteSQLBuilder(sqlRewriteContext, each).toSQL(), getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each)));
}
}
private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatement sqlStatement, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
for (Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
DatabaseType storageType = storageTypes.get(entry.getKey().getDataSourceMapper().getActualName());
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatement, protocolType, storageType);
SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters());
result.put(entry.getKey(), sqlRewriteUnit);
}
return result;
}
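addSQLRewriteUnits puts one SQLRewriteUnit into sqlRewriteUnits for each route unit, and translate then builds the final SQLRewriteUnit for each entry. A SQLRewriteUnit holds the rewritten SQL together with the correspondingly rewritten parameters.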
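The "one rewrite unit per route unit" idea can be sketched as follows. SketchRouteUnit and RouteRewriteSketch are hypothetical, and the naive String.replace merely stands in for the token-based builder; it would not be safe for real SQL (for instance when one table name is a prefix of another).

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical route unit: which data source and which actual table a statement is sent to.
final class SketchRouteUnit {
    final String dataSource;
    final String actualTable;

    SketchRouteUnit(final String dataSource, final String actualTable) {
        this.dataSource = dataSource;
        this.actualTable = actualTable;
    }
}

final class RouteRewriteSketch {

    // Produce one rewritten SQL per route unit, keyed by "dataSource.actualTable".
    // LinkedHashMap keeps the route order, as the original code does.
    static Map<String, String> rewrite(final String sql, final String logicTable, final List<SketchRouteUnit> routeUnits) {
        Map<String, String> result = new LinkedHashMap<>();
        for (SketchRouteUnit each : routeUnits) {
            // Simplification: plain text replacement instead of token-based rewriting.
            result.put(each.dataSource + "." + each.actualTable, sql.replace(logicTable, each.actualTable));
        }
        return result;
    }

    public static void main(final String[] args) {
        List<SketchRouteUnit> routeUnits = Arrays.asList(
                new SketchRouteUnit("ds_0", "t_order_0"),
                new SketchRouteUnit("ds_1", "t_order_1"));
        Map<String, String> units = rewrite("UPDATE t_order SET status = ? WHERE order_id = ?", "t_order", routeUnits);
        for (Map.Entry<String, String> entry : units.entrySet()) {
            System.out.println(entry.getKey() + " -> " + entry.getValue());
        }
    }
}
```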
useDriverToExecuteUpdate
org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java
private int useDriverToExecuteUpdate() throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().executeUpdate(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
}
private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext() throws SQLException {
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine();
return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(connection.getDatabaseName()));
}
private void cacheStatements(final Collection<ExecutionGroup<JDBCExecutionUnit>> executionGroups) throws SQLException {
for (ExecutionGroup<JDBCExecutionUnit> each : executionGroups) {
each.getInputs().forEach(eachInput -> {
statements.add((PreparedStatement) eachInput.getStorageResource());
parameterSets.add(eachInput.getExecutionUnit().getSqlUnit().getParameters());
});
}
replay();
}
private void replay() throws SQLException {
replaySetParameter();
for (Statement each : statements) {
getMethodInvocationRecorder().replay(each);
}
}
private void replaySetParameter() throws SQLException {
for (int i = 0; i < statements.size(); i++) {
replaySetParameter(statements.get(i), parameterSets.get(i));
}
}
protected final void replaySetParameter(final PreparedStatement preparedStatement, final List<Object> params) throws SQLException {
setParameterMethodInvocations.clear();
addParameters(params);
for (PreparedStatementInvocationReplayer each : setParameterMethodInvocations) {
each.replayOn(preparedStatement);
}
}
private void addParameters(final List<Object> params) {
int i = 0;
for (Object each : params) {
int index = ++i;
setParameterMethodInvocations.add(preparedStatement -> preparedStatement.setObject(index, each));
}
}
useDriverToExecuteUpdate calls createExecutionGroupContext (which runs the prepare method). cacheStatements then stores the real PreparedStatement returned by eachInput.getStorageResource() in ShardingSpherePreparedStatement's statements field and stores eachInput.getExecutionUnit().getSqlUnit().getParameters() in parameterSets, after which replay uses PreparedStatementInvocationReplayer to replay the rewritten parameters onto the real PreparedStatement.
Finally, the method delegates to executor.getRegularExecutor().executeUpdate; its last argument is the callback created by createExecuteUpdateCallback.
DriverExecutionPrepareEngine.prepare
org/apache/shardingsphere/infra/executor/sql/prepare/AbstractExecutionPrepareEngine.java
public final ExecutionGroupContext<T> prepare(final RouteContext routeContext, final Collection<ExecutionUnit> executionUnits,
final ExecutionGroupReportContext reportContext) throws SQLException {
return prepare(routeContext, Collections.emptyMap(), executionUnits, reportContext);
}
public final ExecutionGroupContext<T> prepare(final RouteContext routeContext, final Map<String, Integer> connectionOffsets, final Collection<ExecutionUnit> executionUnits,
final ExecutionGroupReportContext reportContext) throws SQLException {
Collection<ExecutionGroup<T>> result = new LinkedList<>();
for (Entry<String, List<SQLUnit>> entry : aggregateSQLUnitGroups(executionUnits).entrySet()) {
String dataSourceName = entry.getKey();
List<SQLUnit> sqlUnits = entry.getValue();
List<List<SQLUnit>> sqlUnitGroups = group(sqlUnits);
ConnectionMode connectionMode = maxConnectionsSizePerQuery < sqlUnits.size() ? ConnectionMode.CONNECTION_STRICTLY : ConnectionMode.MEMORY_STRICTLY;
result.addAll(group(dataSourceName, connectionOffsets.getOrDefault(dataSourceName, 0), sqlUnitGroups, connectionMode));
}
return decorate(routeContext, result, reportContext);
}
protected List<ExecutionGroup<T>> group(final String dataSourceName, final int connectionOffset, final List<List<SQLUnit>> sqlUnitGroups, final ConnectionMode connectionMode) throws SQLException {
List<ExecutionGroup<T>> result = new LinkedList<>();
List<C> connections = databaseConnectionManager.getConnections(dataSourceName, connectionOffset, sqlUnitGroups.size(), connectionMode);
int count = 0;
for (List<SQLUnit> each : sqlUnitGroups) {
result.add(createExecutionGroup(dataSourceName, each, connections.get(count++), connectionMode));
}
return result;
}
private ExecutionGroup<T> createExecutionGroup(final String dataSourceName, final List<SQLUnit> sqlUnits, final C connection, final ConnectionMode connectionMode) throws SQLException {
List<T> result = new LinkedList<>();
for (SQLUnit each : sqlUnits) {
result.add((T) sqlExecutionUnitBuilder.build(new ExecutionUnit(dataSourceName, each), statementManager, connection, connectionMode, option, databaseTypes.get(dataSourceName)));
}
return new ExecutionGroup<>(result);
}
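The group method iterates over the SQLUnit groups and calls createExecutionGroup for each one, which in turn calls sqlExecutionUnitBuilder.build for every SQLUnit. The connections returned by databaseConnectionManager.getConnections here are obtained from the real driver (for example com.mysql.jdbc.Driver).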
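The connection mode decision in prepare can be illustrated with a short sketch. chooseMode follows the comparison shown in the code above; partition is only an assumption about what group(sqlUnits) does (splitting into at most maxConnectionsSizePerQuery groups), since that method is not shown here. The enum and class names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

enum SketchConnectionMode { MEMORY_STRICTLY, CONNECTION_STRICTLY }

final class ConnectionModeSketch {

    // Same comparison as in prepare: more SQL units than allowed connections
    // means connections must be shared, hence CONNECTION_STRICTLY.
    static SketchConnectionMode chooseMode(final int maxConnectionsSizePerQuery, final int sqlUnitCount) {
        return maxConnectionsSizePerQuery < sqlUnitCount
                ? SketchConnectionMode.CONNECTION_STRICTLY
                : SketchConnectionMode.MEMORY_STRICTLY;
    }

    // Assumption: split the units into at most maxConnectionsSizePerQuery groups,
    // each group later being executed on a single connection.
    static <T> List<List<T>> partition(final List<T> sqlUnits, final int maxConnectionsSizePerQuery) {
        int groupSize = Math.max((sqlUnits.size() + maxConnectionsSizePerQuery - 1) / maxConnectionsSizePerQuery, 1);
        List<List<T>> result = new ArrayList<>();
        for (int i = 0; i < sqlUnits.size(); i += groupSize) {
            result.add(new ArrayList<>(sqlUnits.subList(i, Math.min(i + groupSize, sqlUnits.size()))));
        }
        return result;
    }

    public static void main(final String[] args) {
        List<String> sqlUnits = Arrays.asList("sql_0", "sql_1", "sql_2", "sql_3", "sql_4");
        System.out.println(chooseMode(2, sqlUnits.size()));
        System.out.println(partition(sqlUnits, 2));
    }
}
```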
PreparedStatementExecutionUnitBuilder
org/apache/shardingsphere/infra/executor/sql/prepare/driver/jdbc/builder/PreparedStatementExecutionUnitBuilder.java
public JDBCExecutionUnit build(final ExecutionUnit executionUnit, final ExecutorJDBCStatementManager statementManager,
final Connection connection, final ConnectionMode connectionMode, final StatementOption option, final DatabaseType databaseType) throws SQLException {
PreparedStatement preparedStatement = createPreparedStatement(
executionUnit, statementManager, connection, connectionMode, option, databaseType);
return new JDBCExecutionUnit(executionUnit, connectionMode, preparedStatement);
}
private PreparedStatement createPreparedStatement(final ExecutionUnit executionUnit, final ExecutorJDBCStatementManager statementManager, final Connection connection,
final ConnectionMode connectionMode, final StatementOption option, final DatabaseType databaseType) throws SQLException {
return (PreparedStatement) statementManager.createStorageResource(executionUnit, connection, connectionMode, option, databaseType);
}
It is only here, in PreparedStatementExecutionUnitBuilder's build method, that the real PreparedStatement gets created.
StatementManager
org/apache/shardingsphere/driver/jdbc/core/statement/StatementManager.java
public Statement createStorageResource(final ExecutionUnit executionUnit, final Connection connection, final ConnectionMode connectionMode, final StatementOption option,
final DatabaseType databaseType) throws SQLException {
Statement result = cachedStatements.get(new CacheKey(executionUnit, connectionMode));
if (null == result || result.isClosed() || result.getConnection().isClosed()) {
String sql = executionUnit.getSqlUnit().getSql();
if (option.isReturnGeneratedKeys()) {
result = null == option.getColumns() || 0 == option.getColumns().length
? connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)
: connection.prepareStatement(sql, option.getColumns());
} else {
result = connection.prepareStatement(sql, option.getResultSetType(), option.getResultSetConcurrency(), option.getResultSetHoldability());
}
cachedStatements.put(new CacheKey(executionUnit, connectionMode), result);
}
return result;
}
createStorageResource creates the real PreparedStatement through connection.prepareStatement (reusing a cached one if it is still open), and the SQL passed in at this point is the rewritten SQL.
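The get-or-create behaviour of createStorageResource can be reduced to the sketch below. SketchStatement and StatementCacheSketch are hypothetical, and the string key is a simplification of the CacheKey(executionUnit, connectionMode) used in the original code.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for a driver-level statement that can be closed.
final class SketchStatement {
    final String sql;
    boolean closed;

    SketchStatement(final String sql) {
        this.sql = sql;
    }
}

final class StatementCacheSketch {

    private final Map<String, SketchStatement> cachedStatements = new HashMap<>();

    // Counts how many "real" statements were created, for demonstration only.
    int created;

    // Reuse the cached statement when it is still open; otherwise create a new one and cache it.
    SketchStatement getOrCreate(final String dataSource, final String rewrittenSql) {
        String key = dataSource + "|" + rewrittenSql;
        SketchStatement result = cachedStatements.get(key);
        if (null == result || result.closed) {
            result = new SketchStatement(rewrittenSql);
            created++;
            cachedStatements.put(key, result);
        }
        return result;
    }

    public static void main(final String[] args) {
        StatementCacheSketch cache = new StatementCacheSketch();
        SketchStatement first = cache.getOrCreate("ds_0", "UPDATE t_order_0 SET status = ?");
        SketchStatement again = cache.getOrCreate("ds_0", "UPDATE t_order_0 SET status = ?");
        System.out.println(first == again);
        first.closed = true;
        System.out.println(first == cache.getOrCreate("ds_0", "UPDATE t_order_0 SET status = ?"));
    }
}
```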
createExecuteUpdateCallback
org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java
private JDBCExecutorCallback<Integer> createExecuteUpdateCallback() {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
return new JDBCExecutorCallback<Integer>(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType(),
metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getResourceMetaData().getStorageTypes(), sqlStatement, isExceptionThrown) {
@Override
protected Integer executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
return ((PreparedStatement) statement).executeUpdate();
}
@Override
protected Optional<Integer> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
return Optional.empty();
}
};
}
The JDBCExecutorCallback created by createExecuteUpdateCallback implements executeSQL as ((PreparedStatement) statement).executeUpdate(), that is, it delegates to the real PreparedStatement.
Summary
- ShardingSphereConnection's prepareStatement creates a ShardingSpherePreparedStatement. SQL rewriting happens when its executeUpdate is called, followed by prepare; the actual execution goes through JDBCExecutorCallback, whose executeSQL method calls ((PreparedStatement) statement).executeUpdate(), that is, it delegates to the real PreparedStatement.
- rewriteParameters performs its rewrite through ParameterRewriter, which mainly modifies the ParameterBuilder; changes to the SQL text itself are made through the sqlTokenGenerators.
- The real PreparedStatement is only created in PreparedStatementExecutionUnitBuilder's build method: it calls StatementManager.createStorageResource, which creates the real PreparedStatement through connection.prepareStatement, and the SQL passed in at that point is the rewritten SQL.
- useDriverToExecuteUpdate calls createExecutionGroupContext (which runs the prepare method). cacheStatements then stores the real PreparedStatement returned by eachInput.getStorageResource() in ShardingSpherePreparedStatement's statements field and stores eachInput.getExecutionUnit().getSqlUnit().getParameters() in parameterSets, after which replay uses PreparedStatementInvocationReplayer to replay the rewritten parameters onto the real PreparedStatement.
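ShardingSpherePreparedStatement implements the java.sql.PreparedStatement interface. Its sql field holds the SQL passed in by the user, which has not been rewritten. When execute is actually called, SQL rewriting is triggered (covering both the SQL text and the parameters), the real PreparedStatement is created through connection.prepareStatement with the rewritten SQL, a replay step applies the rewritten parameters to that real PreparedStatement, and finally ((PreparedStatement) statement).executeUpdate() triggers the execution.
This gives us a basic recipe for SQL rewriting: implement the java.sql.PreparedStatement interface to provide a stand-in PreparedStatement class, buffer its creation and parameter-setting calls in memory, and then at execute time rewrite the SQL, create the real PreparedStatement, replay the parameters, and call its execute method.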
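The recipe can be condensed into a sketch that does not need a database. SketchRealStatement stands in for the driver's PreparedStatement, and DeferredStatementSketch for the stand-in statement; both classes, as well as the rewriter functions passed in, are hypothetical and far simpler than ShardingSphere's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.UnaryOperator;

// Hypothetical stand-in for the real driver's PreparedStatement.
final class SketchRealStatement {
    final String sql;
    final List<Object> boundParameters = new ArrayList<>();

    SketchRealStatement(final String sql) {
        this.sql = sql;
    }

    // Parameter indexes are 1-based, as in JDBC.
    void setObject(final int index, final Object value) {
        while (boundParameters.size() < index) {
            boundParameters.add(null);
        }
        boundParameters.set(index - 1, value);
    }

    String execute() {
        return sql + " " + boundParameters;
    }
}

// Stand-in statement: buffers parameters and defers creating the real statement until execute.
final class DeferredStatementSketch {
    private final String logicSql;
    private final UnaryOperator<String> sqlRewriter;
    private final UnaryOperator<Object> parameterRewriter;
    private final Map<Integer, Object> recordedParameters = new TreeMap<>();

    DeferredStatementSketch(final String logicSql, final UnaryOperator<String> sqlRewriter, final UnaryOperator<Object> parameterRewriter) {
        this.logicSql = logicSql;
        this.sqlRewriter = sqlRewriter;
        this.parameterRewriter = parameterRewriter;
    }

    // Only recorded in memory; nothing reaches the real statement yet.
    void setObject(final int index, final Object value) {
        recordedParameters.put(index, value);
    }

    // Rewrite the SQL, create the real statement, replay the rewritten parameters, then execute.
    String execute() {
        SketchRealStatement real = new SketchRealStatement(sqlRewriter.apply(logicSql));
        for (Map.Entry<Integer, Object> entry : recordedParameters.entrySet()) {
            real.setObject(entry.getKey(), parameterRewriter.apply(entry.getValue()));
        }
        return real.execute();
    }

    public static void main(final String[] args) {
        DeferredStatementSketch statement = new DeferredStatementSketch(
                "INSERT INTO t_user (phone) VALUES (?)",
                sql -> sql.replace("t_user", "t_user_0").replace("phone", "phone_cipher"),
                value -> "enc(" + value + ")");
        statement.setObject(1, "13800000000");
        System.out.println(statement.execute());
    }
}
```

Because the real statement is only created inside execute, the rewriter sees both the final SQL and all recorded parameters before anything is sent to the database.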