sharding-jdbc 执行流程源码分析-sql 路由
上一节在 sql 提取过程中,将提取出的 sql 分段封装成了 SQLStatement。承接上一节,本节讲解 sql 路由。
/**
 * Routes the logic SQL: parses it on first use, applies sharding routing, then master-slave routing.
 *
 * @param parameters SQL parameters
 * @return the route result after master-slave routing
 */
public SQLRouteResult route(final List<Object> parameters) {
    // Parse lazily; the parsed statement is stored in sqlStatement and reused on later calls.
    if (sqlStatement == null) {
        sqlStatement = shardingRouter.parse(logicSQL, true);
    }
    SQLRouteResult shardingRouteResult = shardingRouter.route(logicSQL, parameters, sqlStatement);
    return masterSlaveRouter.route(shardingRouteResult);
}
sql 路由的入口是 shardingRouter.route(logicSQL, parameters, sqlStatement),主要分为以下四步:
1、校验 sql 是否正确
2、将 sqlStatement 进一步分解,根据 sql 类型创建 sql 的描述:表名 + 列名
3、根据 where 查询条件或 insert/update 的 sharding key 和入参做映射: columnName + tableName + value
4、根据路由引擎,找到分片策略并且执行
ShardingRouter.java
/**
 * Routes a logic SQL statement to its sharding targets.
 *
 * <p>Steps: (1) validate the statement against the sharding rule when a validator exists for its type;
 * (2) break the statement down into an {@code SQLStatementContext} (tables + columns);
 * (3) map sharding columns to values from the WHERE clause or INSERT values and the parameters;
 * (4) let the routing engine find the sharding strategy and execute it.
 *
 * @param logicSQL logic SQL, i.e. the original SQL text
 * @param parameters SQL parameters
 * @param sqlStatement SQL statement produced by parsing and extraction
 * @return route result holding the statement context, sharding conditions, generated key and routing result
 * @throws IllegalStateException (via Preconditions.checkState) when sharding values had to be merged
 *         but routing produced more than one routing unit
 */
public SQLRouteResult route(final String logicSQL, final List<Object> parameters, final SQLStatement sqlStatement) {
    Optional<ShardingStatementValidator> shardingStatementValidator = ShardingStatementValidatorFactory.newInstance(sqlStatement);
    if (shardingStatementValidator.isPresent()) {
        // Validate the SQL against the sharding rule (validators exist only for INSERT and UPDATE).
        shardingStatementValidator.get().validate(shardingRule, sqlStatement, parameters);
    }
    // Break the statement down further by SQL type: table names + column names.
    // SELECT: additionally the GROUP BY and ORDER BY columns.
    // INSERT: table name + column names + the value lists.
    SQLStatementContext sqlStatementContext = SQLStatementContextFactory.newInstance(metaData.getRelationMetas(), logicSQL, parameters, sqlStatement);
    // Only INSERT statements can carry a generated key.
    Optional<GeneratedKey> generatedKey = sqlStatement instanceof InsertStatement
            ? GeneratedKey.getGenerateKey(shardingRule, metaData.getTables(), parameters, (InsertStatement) sqlStatement) : Optional.<GeneratedKey>absent();
    // Map the sharding keys from the WHERE clause or INSERT/UPDATE values to the parameters: columnName + tableName + value.
    ShardingConditions shardingConditions = getShardingConditions(parameters, sqlStatementContext, generatedKey.orNull(), metaData.getRelationMetas());
    boolean needMergeShardingValues = isNeedMergeShardingValues(sqlStatementContext);
    if (sqlStatementContext.getSqlStatement() instanceof DMLStatement && needMergeShardingValues) {
        // Per the original author: the conditions must not be empty here, except for Hint routing.
        checkSubqueryShardingValues(sqlStatementContext, shardingConditions);
        mergeShardingConditions(shardingConditions);
    }
    RoutingEngine routingEngine = RoutingEngineFactory.newInstance(shardingRule, metaData, sqlStatementContext, shardingConditions);
    // Find the sharding strategy through the routing engine and execute it; per the original author the
    // result maps logic names to actual names.
    RoutingResult routingResult = routingEngine.route();
    if (needMergeShardingValues) {
        Preconditions.checkState(1 == routingResult.getRoutingUnits().size(), "Must have one sharding with subquery.");
    }
    SQLRouteResult result = new SQLRouteResult(sqlStatementContext, shardingConditions, generatedKey.orNull());
    result.setRoutingResult(routingResult);
    if (sqlStatementContext instanceof InsertSQLStatementContext) {
        setGeneratedValues(result);
    }
    return result;
}
1、校验 sql 是否正确
1、从校验工厂类中获取 ShardingStatementValidator
2、如果需要校验,那么执行 validate
工厂类只对 insert 和 update 类型的 sql 返回相应的校验类,其他类型的 sql 不做校验。
public final class ShardingStatementValidatorFactory {
    
    /**
     * Picks the sharding validator for a statement.
     *
     * @param sqlStatement parsed SQL statement
     * @return an INSERT or UPDATE validator, or absent for every other statement type
     */
    public static Optional<ShardingStatementValidator> newInstance(final SQLStatement sqlStatement) {
        ShardingStatementValidator validator = null;
        if (sqlStatement instanceof InsertStatement) {
            validator = new ShardingInsertStatementValidator();
        } else if (sqlStatement instanceof UpdateStatement) {
            validator = new ShardingUpdateStatementValidator();
        }
        return Optional.fromNullable(validator);
    }
}
ShardingInsertStatementValidator 只校验 INSERT INTO .... ON DUPLICATE KEY UPDATE 语句:其中的 UPDATE 子句不能修改 sharding key。
public final class ShardingInsertStatementValidator implements ShardingStatementValidator<InsertStatement> {
    
    /**
     * Rejects an {@code INSERT ... ON DUPLICATE KEY UPDATE} whose UPDATE part assigns a sharding column.
     *
     * @param shardingRule sharding rule
     * @param sqlStatement INSERT statement
     * @param parameters SQL parameters (not used by this check)
     * @throws ShardingException when the ON DUPLICATE KEY UPDATE part touches a sharding column
     */
    @Override
    public void validate(final ShardingRule shardingRule, final InsertStatement sqlStatement, final List<Object> parameters) {
        Optional<OnDuplicateKeyColumnsSegment> duplicateKeyColumns = sqlStatement.findSQLSegment(OnDuplicateKeyColumnsSegment.class);
        if (!duplicateKeyColumns.isPresent()) {
            return;
        }
        String logicTableName = sqlStatement.getTable().getTableName();
        if (isUpdateShardingKey(shardingRule, duplicateKeyColumns.get(), logicTableName)) {
            throw new ShardingException("INSERT INTO .... ON DUPLICATE KEY UPDATE can not support update for sharding column.");
        }
    }
    
    // True as soon as one assigned column is a sharding column of the table.
    private boolean isUpdateShardingKey(final ShardingRule shardingRule, final OnDuplicateKeyColumnsSegment onDuplicateKeyColumnsSegment, final String tableName) {
        boolean updatesShardingColumn = false;
        for (ColumnSegment column : onDuplicateKeyColumnsSegment.getColumns()) {
            updatesShardingColumn = shardingRule.isShardingColumn(column.getName(), tableName);
            if (updatesShardingColumn) {
                break;
            }
        }
        return updatesShardingColumn;
    }
}
ShardingUpdateStatementValidator
update 语句不能修改 sharding key,除非 update 的 where 条件中也包含该 sharding key,并且其值与 set 中的新值相等;否则抛出 ShardingException。下面看一下具体的实现:
1、循环 update 的 set 段
2、判断 set Column 是否是 sharding key
3、获取 set Column 的用户入参值
4、根据 sharding key 获取 where 参数值
5、如果 set 修改的值和 where 条件的值相等,那么说明也是正确的
6、否则校验失败了
//ShardingUpdateStatementValidator.java
/**
 * Rejects UPDATE statements that change a sharding column.
 *
 * <p>Assigning a sharding column is tolerated only when the WHERE clause pins that column to the very
 * same value, i.e. the assignment does not actually move the row to another shard.
 *
 * @param shardingRule sharding rule
 * @param sqlStatement UPDATE statement
 * @param parameters SQL parameters
 * @throws ShardingException when a sharding column would be changed
 */
@Override
public void validate(final ShardingRule shardingRule, final UpdateStatement sqlStatement, final List<Object> parameters) {
    String tableName = new TablesContext(sqlStatement).getSingleTableName();
    // The WHERE clause is the same for every assignment; look it up once.
    Optional<WhereSegment> whereSegmentOptional = sqlStatement.getWhere();
    for (AssignmentSegment each : sqlStatement.getSetAssignment().getAssignments()) {
        String shardingColumn = each.getColumn().getName();
        if (!shardingRule.isShardingColumn(shardingColumn, tableName)) {
            continue;
        }
        // Value written by SET (resolved from a ? placeholder or a literal).
        Optional<Object> shardingColumnSetAssignmentValue = getShardingColumnSetAssignmentValue(each, parameters);
        Optional<Object> shardingValue = Optional.absent();
        if (whereSegmentOptional.isPresent()) {
            // Value the WHERE clause pins the sharding column to, if any.
            shardingValue = getShardingValue(whereSegmentOptional.get(), parameters, shardingColumn);
        }
        // Setting the column to the value it is already filtered on does not move the row.
        if (shardingColumnSetAssignmentValue.isPresent() && shardingValue.isPresent() && shardingColumnSetAssignmentValue.get().equals(shardingValue.get())) {
            continue;
        }
        // Report the column name; formatting the AssignmentSegment object does not name the column.
        throw new ShardingException("Can not update sharding key, logic table: [%s], column: [%s].", tableName, shardingColumn);
    }
}
获取 set Column 的用户输入参数值
1、如果值是占位符 ?,说明由用户传入,那么根据占位符的位置从 parameters 中获取
2、如果值是字面量(例如 1 或 'abc'),那么直接从 segment.getLiterals() 中提取
//ShardingUpdateStatementValidator.java
/**
 * Resolves the value a SET assignment writes into its column.
 *
 * <p>A literal is returned as-is. A {@code ?} placeholder is resolved from {@code parameters} by its
 * marker index. Any other expression cannot be evaluated here and yields absent.
 *
 * @param assignmentSegment SET assignment, e.g. {@code col = ?}
 * @param parameters SQL parameters
 * @return the assigned value, or absent when it is unknown, out of range or null
 */
private Optional<Object> getShardingColumnSetAssignmentValue(final AssignmentSegment assignmentSegment, final List<Object> parameters) {
    ExpressionSegment segment = assignmentSegment.getValue();
    if (segment instanceof LiteralExpressionSegment) {
        // fromNullable instead of of: a null value must not surface as a NullPointerException.
        return Optional.fromNullable(((LiteralExpressionSegment) segment).getLiterals());
    }
    if (segment instanceof ParameterMarkerExpressionSegment) {
        int parameterIndex = ((ParameterMarkerExpressionSegment) segment).getParameterMarkerIndex();
        // Guard against every out-of-range index, not only the -1 sentinel.
        if (parameterIndex < 0 || parameterIndex >= parameters.size()) {
            return Optional.absent();
        }
        return Optional.fromNullable(parameters.get(parameterIndex));
    }
    return Optional.absent();
}
根据 sharding key 获取 where 参数值
1、循环每一个 and 查询条件,遍历 and 条件中的每一个单元,找到 sharding key
2、找到参数表达式,取出对应的值
/**
 * Resolves the single value the WHERE clause pins the sharding column to.
 *
 * <p>Every {@code AndPredicate} of the WHERE clause is examined. Presumably each one is an
 * OR-separated branch; confirm against the parser. A value is returned only when every branch pins
 * the column and all branches agree. Looking at the first branch only would let, e.g.,
 * {@code UPDATE ... SET id = 1 WHERE id = 1 OR id = 2} pass although rows with {@code id = 2} change shard.
 *
 * @param whereSegment WHERE clause of the UPDATE
 * @param parameters SQL parameters
 * @param shardingColumn sharding column name
 * @return the common sharding value, or absent when it cannot be determined unambiguously
 */
private Optional<Object> getShardingValue(final WhereSegment whereSegment, final List<Object> parameters, final String shardingColumn) {
    Optional<Object> result = Optional.absent();
    for (AndPredicate each : whereSegment.getAndPredicates()) {
        Optional<Object> branchValue = getShardingValue(each, parameters, shardingColumn);
        // A branch that does not pin the column, or pins a different value, may select rows whose key would change.
        if (!branchValue.isPresent() || (result.isPresent() && !result.get().equals(branchValue.get()))) {
            return Optional.absent();
        }
        result = branchValue;
    }
    return result;
}
/**
 * Resolves the value one AND group assigns to the sharding column.
 *
 * <p>The first predicate on the sharding column (case-insensitive name match) whose right-hand side is
 * a comparison or an IN list decides the result; predicates of other shapes are skipped.
 * NOTE(review): the comparison operator is not inspected, so {@code id > ?} is treated like {@code id = ?}.
 *
 * @param andPredicate one AND group of the WHERE clause
 * @param parameters SQL parameters
 * @param shardingColumn sharding column name
 * @return the resolved value, or absent when no usable predicate exists
 */
private Optional<Object> getShardingValue(final AndPredicate andPredicate, final List<Object> parameters, final String shardingColumn) {
    for (PredicateSegment predicate : andPredicate.getPredicates()) {
        boolean onShardingColumn = shardingColumn.equalsIgnoreCase(predicate.getColumn().getName());
        if (onShardingColumn) {
            PredicateRightValue rightValue = predicate.getRightValue();
            if (rightValue instanceof PredicateCompareRightValue) {
                return getPredicateCompareShardingValue(((PredicateCompareRightValue) rightValue).getExpression(), parameters, shardingColumn);
            }
            if (rightValue instanceof PredicateInRightValue) {
                return getPredicateInShardingValue(((PredicateInRightValue) rightValue).getSqlExpressions(), parameters, shardingColumn);
            }
        }
    }
    return Optional.absent();
}
/**
 * Resolves the right-hand side of a comparison predicate.
 *
 * @param segment right-hand side expression
 * @param parameters SQL parameters
 * @param shardingColumn sharding column name (not used; kept for signature compatibility)
 * @return the value of a {@code ?} placeholder or literal, or absent when it is unknown, out of range or null
 */
private Optional<Object> getPredicateCompareShardingValue(final ExpressionSegment segment, final List<Object> parameters, final String shardingColumn) {
    if (segment instanceof ParameterMarkerExpressionSegment) {
        int parameterIndex = ((ParameterMarkerExpressionSegment) segment).getParameterMarkerIndex();
        if (parameterIndex < 0 || parameterIndex >= parameters.size()) {
            return Optional.absent();
        }
        // fromNullable instead of of: a null-bound parameter must not throw NullPointerException.
        return Optional.fromNullable(parameters.get(parameterIndex));
    }
    if (segment instanceof LiteralExpressionSegment) {
        return Optional.fromNullable(((LiteralExpressionSegment) segment).getLiterals());
    }
    return Optional.absent();
}
/**
 * Resolves the single value an IN list pins the sharding column to.
 *
 * <p>A value is returned only when every element resolves (a {@code ?} placeholder or a literal) to the
 * same non-null value. Returning just the first element would let {@code SET id = 1 ... WHERE id IN (1, 2)}
 * pass validation although rows with {@code id = 2} change shard.
 *
 * @param segments expressions of the IN list
 * @param parameters SQL parameters
 * @param shardingColumn sharding column name (not used; kept for signature compatibility)
 * @return the common value, or absent when the list is empty, ambiguous or not fully resolvable
 */
private Optional<Object> getPredicateInShardingValue(final Collection<ExpressionSegment> segments, final List<Object> parameters, final String shardingColumn) {
    Object result = null;
    for (ExpressionSegment each : segments) {
        Object value;
        if (each instanceof ParameterMarkerExpressionSegment) {
            int parameterIndex = ((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex();
            if (parameterIndex < 0 || parameterIndex >= parameters.size()) {
                return Optional.absent();
            }
            value = parameters.get(parameterIndex);
        } else if (each instanceof LiteralExpressionSegment) {
            value = ((LiteralExpressionSegment) each).getLiterals();
        } else {
            // An element that cannot be evaluated here might select any key.
            return Optional.absent();
        }
        if (null == value || (null != result && !result.equals(value))) {
            return Optional.absent();
        }
        result = value;
    }
    return Optional.fromNullable(result);
}
2、将 sqlStatement 进一步分解,根据 sql 类型创建 sql 的描述
根据 sql 类型分三种 StatementContext
1、select 类型 sql 创建 SelectSQLStatementContext
2、insert 类型 sql 创建 InsertSQLStatementContext
3、其他类型 sql 创建 CommonSQLStatementContext
public final class SQLStatementContextFactory {
    
    /**
     * Creates the statement context matching the statement type.
     *
     * @param relationMetas table column metadata
     * @param sql logic SQL
     * @param parameters SQL parameters
     * @param sqlStatement parsed SQL statement
     * @return a SELECT context, an INSERT context, or a common context for every other statement type
     */
    public static SQLStatementContext newInstance(final RelationMetas relationMetas, final String sql, final List<Object> parameters, final SQLStatement sqlStatement) {
        SQLStatementContext context;
        if (sqlStatement instanceof SelectStatement) {
            context = new SelectSQLStatementContext(relationMetas, sql, parameters, (SelectStatement) sqlStatement);
        } else if (sqlStatement instanceof InsertStatement) {
            context = new InsertSQLStatementContext(relationMetas, parameters, (InsertStatement) sqlStatement);
        } else {
            context = new CommonSQLStatementContext(sqlStatement);
        }
        return context;
    }
}
select 类型 sql 创建 SelectSQLStatementContext
1、调用父类构造方法 CommonSQLStatementContext
2、根据 sqlStatement 中的 GroupBySegment 和 OrderBySegment 创建 GroupByContext 和 OrderByContext
3、获取 sql 语句中出现的所有列,封装成 ProjectionsContext,
但是需要去重,因为 select 中的列可能与 group by 或 order by 中的列重复。
/**
 * Builds the SELECT statement context. Each step consumes the result of the previous one, so the order matters.
 *
 * @param relationMetas table column metadata
 * @param sql logic SQL
 * @param parameters SQL parameters
 * @param sqlStatement SELECT statement
 */
public SelectSQLStatementContext(final RelationMetas relationMetas, final String sql, final List<Object> parameters, final SelectStatement sqlStatement) {
    super(sqlStatement);
    // In SQL syntax GROUP BY precedes ORDER BY; the ORDER BY context is built from the GROUP BY context.
    groupByContext = new GroupByContextEngine().createGroupByContext(sqlStatement);
    orderByContext = new OrderByContextEngine().createOrderBy(sqlStatement, groupByContext);
    // Collect all columns of the statement, including GROUP BY and ORDER BY columns, without duplicates.
    projectionsContext = new ProjectionsContextEngine(relationMetas).createProjectionsContext(sql, sqlStatement, groupByContext, orderByContext);
    paginationContext = new PaginationContextEngine().createPaginationContext(sqlStatement, projectionsContext, parameters);
    containsSubquery = containsSubquery();
}
调用父类构造方法 CommonSQLStatementContext
根据 sqlStatement 抽取表名封装成 TablesContext
public class CommonSQLStatementContext implements SQLStatementContext {
    
    // The wrapped SQL statement.
    private final SQLStatement sqlStatement;
    
    // Table names and aliases extracted from the statement.
    private final TablesContext tablesContext;
    
    /**
     * Wraps a statement and extracts its tables.
     *
     * @param sqlStatement SQL statement
     */
    public CommonSQLStatementContext(final SQLStatement sqlStatement) {
        this.tablesContext = new TablesContext(sqlStatement);
        this.sqlStatement = sqlStatement;
    }
}
/**
 * Tables context: the table names and aliases referenced by a SQL statement, plus its schema
 * (per the original author: the user id on Oracle, the database name on MySQL).
 *
 * <p>Catalog/schema meaning per database, as recorded by the original author:
 * <pre>
 * Database   | catalog                     | schema
 * -----------+-----------------------------+-----------------------------------------------
 * Oracle     | not supported, leave empty  | equivalent to the Oracle user id
 * DB2        | supported, leave empty      | equivalent to the database owner
 * Sybase     | the database name           | equivalent to the database owner
 * MSSQL      | the database name           | must equal the catalog or the database owner
 * Informix   | not supported, leave empty  | not needed
 * PointBase  | not supported, leave empty  | equivalent to the database name
 * MySQL      | not supported               | the database name
 * </pre>
 *
 * @author zhangliang
 */
@ToString
public final class TablesContext {
    
    // Tables (name + optional alias) referenced by the statement.
    private final Collection<Table> tables = new ArrayList<>();
    
    // Schema of the statement (see the class comment for its per-database meaning).
    private String schema;
    
    public TablesContext(final SQLStatement sqlStatement) {
        Collection<String> aliases = new HashSet<>();
        // Pass 1: collect every declared alias. Table segments are found through TableAvailable
        // (per the original author, TableSegment implements that interface).
        for (TableAvailable each : sqlStatement.findSQLSegments(TableAvailable.class)) {
            Optional<String> alias = getAlias(each);
            if (alias.isPresent()) {
                aliases.add(alias.get());
            }
        }
        for (TableAvailable each : sqlStatement.findSQLSegments(TableAvailable.class)) {
            Optional<String> alias = getAlias(each);
            // Pass 2: a segment whose name is a known alias and that declares no alias itself refers to
            // an alias rather than a real table (presumably an owner prefix such as "o" in "o.id"), so skip it.
            if (aliases.contains(each.getTableName()) && !alias.isPresent()) {
                continue;
            }
            tables.add(new Table(each.getTableName(), alias.orNull()));
            if (each instanceof TableSegment) {
                setSchema((TableSegment) each);
            }
        }
    }
    ... 略
}
根据 sqlStatement 中的 GroupBySegment 和 OrderBySegment 创建 GroupByContext 和 OrderByContext
GroupByContext 的创建过程和 OrderByContext 类似,因为 antlr 中 group by 复用了 orderByItem,这里只解释 GroupByContext:
1、从 sqlStatement 中获取 group by 部分,为其中每一项创建 OrderByItem 并 add 到 groupByItems,然后创建 GroupByContext。
封装成 GroupByContext 的目的,是为下一步收集 sql 中的全部列(去重或补充派生列)做准备。
public final class GroupByContextEngine {
    
    /**
     * Creates the GROUP BY context of a SELECT statement.
     *
     * @param selectStatement SELECT statement
     * @return context holding one OrderByItem per GROUP BY item and the clause's stop index;
     *         an empty context with stop index 0 when there is no GROUP BY
     */
    public GroupByContext createGroupByContext(final SelectStatement selectStatement) {
        if (selectStatement.getGroupBy().isPresent()) {
            // GROUP BY may list several columns; the grammar reuses the ORDER BY item for them.
            Collection<OrderByItem> items = new LinkedList<>();
            for (OrderByItemSegment segment : selectStatement.getGroupBy().get().getGroupByItems()) {
                items.add(createOrderByItem(segment));
            }
            return new GroupByContext(items, selectStatement.getGroupBy().get().getStopIndex());
        }
        return new GroupByContext(new LinkedList<OrderByItem>(), 0);
    }
    
    // Wraps one GROUP BY item, carrying over the column index for positional items such as "GROUP BY 1".
    private OrderByItem createOrderByItem(final OrderByItemSegment segment) {
        OrderByItem item = new OrderByItem(segment);
        if (segment instanceof IndexOrderByItemSegment) {
            item.setIndex(((IndexOrderByItemSegment) segment).getColumnIndex());
        }
        return item;
    }
}
获取sql语句中出现的所有的列封装成 ProjectionsContext
1、select 语句中查询的每一个字段创建映射
2、添加 Group 中的列到 ProjectionsContext . projections,但是要和 selectItems 中的去重
3、添加 OrderBy 中的列到 ProjectionsContext . projections,但是要和 selectItems 中的去重
public final class ProjectionsContextEngine {
    
    // Column metadata of all tables (per the original author, obtained from the connection).
    private final RelationMetas relationMetas;
    
    private final ProjectionEngine selectItemEngine = new ProjectionEngine();
    
    /**
     * Creates the projections context: the selected items plus derived GROUP BY / ORDER BY columns
     * that are not already selected.
     *
     * @param sql logic SQL
     * @param selectStatement SELECT statement
     * @param groupByContext GROUP BY context
     * @param orderByContext ORDER BY context
     * @return projections context
     */
    public ProjectionsContext createProjectionsContext(final String sql, final SelectStatement selectStatement, final GroupByContext groupByContext, final OrderByContext orderByContext) {
        SelectItemsSegment selectItemsSegment = selectStatement.getSelectItems();
        // One projection per item of the select list.
        Collection<Projection> projections = getProjections(sql, selectItemsSegment);
        ProjectionsContext result = new ProjectionsContext(
                selectItemsSegment.getStartIndex(), selectItemsSegment.getStopIndex(), selectItemsSegment.isDistinctRow(), projections, getColumnLabels(selectStatement.getTables(), projections));
        TablesContext tablesContext = new TablesContext(selectStatement);
        // Add GROUP BY columns that the select list does not already contain.
        // NOTE(review): whether result.getProjections() is the same collection as `projections` (which would
        // make the ORDER BY check below also see derived GROUP BY columns) depends on ProjectionsContext; confirm.
        result.getProjections().addAll(getDerivedGroupByColumns(tablesContext, projections, groupByContext));
        // Add ORDER BY columns that the select list does not already contain.
        result.getProjections().addAll(getDerivedOrderByColumns(tablesContext, projections, orderByContext));
        return result;
    }
    ...略
}
select 语句中查询的每一个字段创建映射
1、循环每个 select 需要查询的字段:例如 select a.*,id 那么循环两次
2、根据 SelectItemSegment 的类型创建 Projection。Projection 用于描述 sql 语句中出现的数据库字段(或表达式),有多种类型,例如聚合函数、表达式、缩写 *、普通字段。
//ProjectionsContextEngine.java
/**
 * Creates one projection per item of the select list; items with no projection type yet are skipped.
 *
 * @param sql logic SQL
 * @param selectItemsSegment select list, e.g. {@code a.*, id} yields two items
 * @return projections in select-list order
 */
private Collection<Projection> getProjections(final String sql, final SelectItemsSegment selectItemsSegment) {
    Collection<Projection> projections = new LinkedList<>();
    for (SelectItemSegment selectItem : selectItemsSegment.getSelectItems()) {
        Optional<Projection> projection = selectItemEngine.createProjection(sql, selectItem);
        if (!projection.isPresent()) {
            continue;
        }
        projections.add(projection.get());
    }
    return projections;
}
//ProjectionEngine.java
/**
 * Creates the projection describing one select item.
 *
 * <p>The type checks run in a fixed order: shorthand ({@code *}), plain column, expression
 * ({@code id + 1}), aggregation with DISTINCT ({@code SUM(DISTINCT id)}), then plain aggregation. Keep the
 * DISTINCT check before the plain aggregation check in case one segment type extends the other.
 *
 * @param sql logic SQL
 * @param selectItemSegment select item
 * @return the projection, or absent for unsupported items (e.g. subqueries)
 */
public Optional<Projection> createProjection(final String sql, final SelectItemSegment selectItemSegment) {
    Projection projection;
    if (selectItemSegment instanceof ShorthandSelectItemSegment) {
        projection = createProjection((ShorthandSelectItemSegment) selectItemSegment);
    } else if (selectItemSegment instanceof ColumnSelectItemSegment) {
        projection = createProjection((ColumnSelectItemSegment) selectItemSegment);
    } else if (selectItemSegment instanceof ExpressionSelectItemSegment) {
        projection = createProjection((ExpressionSelectItemSegment) selectItemSegment);
    } else if (selectItemSegment instanceof AggregationDistinctSelectItemSegment) {
        projection = createProjection(sql, (AggregationDistinctSelectItemSegment) selectItemSegment);
    } else if (selectItemSegment instanceof AggregationSelectItemSegment) {
        projection = createProjection(sql, (AggregationSelectItemSegment) selectItemSegment);
    } else {
        // TODO subquery
        return Optional.absent();
    }
    return Optional.of(projection);
}
添加 Group 中的列到 ProjectionsContext . projections,但是要和 selectItems 中的去重
1、循环 groupByContext 中所有的列 OrderByItem
2、判断 select 的列中是否已包含该列,不包含那么创建新的 Column 映射(DerivedProjection)
// Derives the GROUP BY columns missing from the select list. Called from createProjectionsContext as:
// result.getProjections().addAll(getDerivedGroupByColumns(tablesContext, projections, groupByContext));
private Collection<Projection> getDerivedGroupByColumns(final TablesContext tablesContext, final Collection<Projection> selectItems, final GroupByContext groupByContext) {
    Collection<OrderByItem> groupByItems = groupByContext.getItems();
    return getDerivedOrderColumns(tablesContext, selectItems, groupByItems, DerivedColumn.GROUP_BY_ALIAS);
}
/**
 * Creates a derived projection for every ordering item that the select list does not already contain.
 *
 * @param tablesContext tables of the statement
 * @param selectItems projections of the select list
 * @param orderItems GROUP BY or ORDER BY items
 * @param derivedColumn kind of derived column, which supplies the generated alias
 * @return derived projections, aliased with consecutive offsets starting at 0
 */
private Collection<Projection> getDerivedOrderColumns(final TablesContext tablesContext,
        final Collection<Projection> selectItems, final Collection<OrderByItem> orderItems, final DerivedColumn derivedColumn) {
    Collection<Projection> derived = new LinkedList<>();
    for (OrderByItem item : orderItems) {
        if (containsProjection(tablesContext, selectItems, item.getSegment())) {
            continue;
        }
        // The alias offset equals the number of projections derived so far.
        derived.add(new DerivedProjection(((TextOrderByItemSegment) item.getSegment()).getText(), derivedColumn.getDerivedColumnAlias(derived.size())));
    }
    return derived;
}
insert 类型 sql 创建 InsertSQLStatementContext
创建 InsertSQLStatementContext 比 select 类型简单不少,首先看构造函数:
1、调用父类构造方法 CommonSQLStatementContext 和上文一样的都是根据 sqlStatement 抽取表名封装成 TablesContext
2、sqlStatement 是否用了默认的 Columns (columns.isEmpty() && null == setAssignment;) 如果用了,那么用表的元信息列(从数据库连接中取得)
3、创建 InsertValueContext
public final class InsertSQLStatementContext extends CommonSQLStatementContext {
    
    // Column names of "INSERT INTO t (...)".
    private final List<String> columnNames;
    
    // One context per value group of "VALUES (...), (...)".
    private final List<InsertValueContext> insertValueContexts;
    
    /**
     * Builds the INSERT statement context.
     *
     * @param relationMetas table column metadata, used when the statement lists no columns
     * @param parameters SQL parameters
     * @param sqlStatement INSERT statement
     */
    public InsertSQLStatementContext(final RelationMetas relationMetas, final List<Object> parameters, final InsertStatement sqlStatement) {
        super(sqlStatement);
        // Without an explicit column list, fall back to all columns of the table from the metadata.
        columnNames = sqlStatement.useDefaultColumns() ? relationMetas.getAllColumnNames(getTablesContext().getSingleTableName()) : sqlStatement.getColumnNames();
        insertValueContexts = getInsertValueContexts(parameters);
    }
    ...略
}
创建 InsertValueContext
1、因为有 insert into values(),(),() 的情况所以需要循环每一个值项
2、为 insert into values(),(),() 中的一个 value() 创建 InsertValueContext
3、记录 value() 对应的参数在 parameters 中的偏移位置
//InsertSQLStatementContext.java
/**
 * Creates one InsertValueContext per value group of {@code VALUES (...), (...), ...}.
 *
 * @param parameters all SQL parameters of the statement
 * @return contexts in value-group order; each receives the offset of its first parameter
 */
private List<InsertValueContext> getInsertValueContexts(final List<Object> parameters) {
    List<InsertValueContext> contexts = new LinkedList<>();
    int offset = 0;
    for (Collection<ExpressionSegment> valueExpressions : ((InsertStatement) getSqlStatement()).getAllValueExpressions()) {
        InsertValueContext context = new InsertValueContext(valueExpressions, parameters, offset);
        // The next group's parameters start right after this group's parameters.
        offset += context.getParametersCount();
        contexts.add(context);
    }
    return contexts;
}
为 insert into values(),(),() 中的一个 value() 创建 InsertValueContext
1、获取 value() 中的参数数量,就是统计 ? 出现的数量
2、获取 value() 中的字段 valueExpressions
3、从参数 parameters 中截取 InsertValueContext 的那部分值
public final class InsertValueContext {
    
    // Number of parameters used by this value group (per the original author: the count of ? markers).
    private final int parametersCount;
    
    // Expressions of one value group, e.g. each item of ('', '', '').
    private final List<ExpressionSegment> valueExpressions;
    
    // The slice of the statement's parameters that belongs to this value group.
    private final List<Object> parameters;
    
    /**
     * Builds the context of one value group.
     *
     * @param assignments expressions of the value group
     * @param parameters all SQL parameters of the statement
     * @param parametersOffset index of this group's first parameter within {@code parameters}
     */
    public InsertValueContext(final Collection<ExpressionSegment> assignments, final List<Object> parameters, final int parametersOffset) {
        parametersCount = calculateParametersCount(assignments);
        valueExpressions = getValueExpressions(assignments);
        // NOTE(review): getParameters presumably uses parametersCount set above; keep this assignment last.
        this.parameters = getParameters(parameters, parametersOffset);
    }
    ...略
}
至此,将 sqlStatement 进一步分解,根据 sql 类型创建 sql 的描述,已经结束,再看一下这部分代码的 uml 图
image.png