MyBatis 分表插件之并发安全(三)
前言
这是Mybatis缓存插件系列的第三篇文章,不熟悉的同学可以看下上两篇。缓存篇(二), 原理入门(一)
上一篇中介绍了给分表插件添加缓存,主要是用来处理解析SQL获取原始表名称。文章提到,在单应用debug跟踪的时候,是没有问题的。但是实际环境中,由于没有协调对缓存的访问和修改,会造成取到错误的结果,导致请求失败。虽然当时使用了ConrrentHashMap
,但是解析的结果是保存在一个对象中的,因此这里操作map是线程安全的,但是保存的对象字段仍然可以修改,并发下会有问题,这就是最终的原因。所以在测试环境部署后,会看到表名称没有替换,导致SQL执行失败。今天这篇文章就是介绍如何保证,SQL解析缓存的线程安全。
思路
首先,可以知道线程安全的问题主要是并发修改和访问导致的,我们可以在修改时锁定对象,这里对应过来就是:在map的更改操作加锁;除了这个还有一种思路。上面提到,虽然使用了ConcurrentHashMap
,保证了操作解析对象的线程安全,但是对象的字段修改却不是线程安全的,因此可以在操作解析对象指定字段的时候加锁,这样来协调多线程并发的修改和访问。 最终,这边采用了给整个修改操作加锁的方式,可以看下面的代码和解析。
代码
解析结果对象
@Data
private static final class ShardEntity {
//执行SQL,开始考虑缓存,Mapper中一个方法对应一条方法,但是对于in等可变参数的SQL无法支持。
private String statement;
//原始表名称
private String originTableName;
//方法拆分键
private String shardKey;
}
分表插件类
@Intercepts(@Signature(
type = StatementHandler.class,
method = "prepare",
args = {Connection.class, Integer.class}
))
@Component
public class ShardInterceptor implements Interceptor, Ordered {
private static final ReflectorFactory defaultReflectorFactory = new DefaultReflectorFactory();
private static final HashMap<String, ShardEntity> MAPPER_SHARD_CACHE = new HashMap<>();
@Resource
private Properties shardConfigProperties;
@Override
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
MetaObject metaObject = MetaObject.forObject(statementHandler,
SystemMetaObject.DEFAULT_OBJECT_FACTORY,
SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY,
defaultReflectorFactory
);
MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
String id = mappedStatement.getId();
BoundSql boundSql = statementHandler.getBoundSql();
HashMap<String, Object> parameterObject = (HashMap<String, Object>) boundSql.getParameterObject();
//对整个修改map的操作加同步,但这样锁的粒度较大。
synchronized (MAPPER_SHARD_CACHE) {
//这里直接使用的普通HashMap
ShardEntity shardEntity = MAPPER_SHARD_CACHE.get(id);
String sql = boundSql.getSql();
if (null != shardEntity) {
if (null == shardEntity.getShardKey() || null == shardEntity.getOriginTableName()) {
return invocation.proceed();
}
Long value = (Long) parameterObject.get(shardEntity.getShardKey());
String originTable = shardEntity.getOriginTableName();
renameTable(boundSql, sql, value, originTable);
return invocation.proceed();
} else {
shardEntity = new ShardEntity();
shardEntity.setStatement(sql);
MAPPER_SHARD_CACHE.put(id, shardEntity);
}
String dao = id.substring(0, id.lastIndexOf("."));
String methodName = id.substring(id.lastIndexOf(".") + 1);
Class clazz = Class.forName(dao);
for (Method method : clazz.getMethods()) {
if (method.getName().equals(methodName)) {
Annotation[][] parameterAnnotations = method.getParameterAnnotations();
int idx = 0;
for (Annotation[] pa : parameterAnnotations) {
for (Annotation a : pa) {
if (a instanceof ShardBy) {
String shardKey = method.getParameters()[idx].getName();
Long value = (Long) parameterObject.get(shardKey);
String originTable = getTableName(sql);
renameTable(boundSql, sql, value, originTable);
ShardEntity entity = MAPPER_SHARD_CACHE.get(id);
if (null == entity) {
shardEntity.setOriginTableName(originTable);
shardEntity.setShardKey(shardKey);
} else {
entity.setOriginTableName(originTable);
entity.setShardKey(shardKey);
}
return invocation.proceed();
}
}
idx++;
}
}
}
}
return invocation.proceed();
}
private void renameTable(BoundSql boundSql, String sql, Long value, String originTable) throws NoSuchFieldException, IllegalAccessException {
String forwardTable = shard(originTable, value);
Field field = boundSql.getClass().getDeclaredField("sql");
field.setAccessible(true);
field.set(boundSql, sql.replace(originTable, forwardTable));
}
private String shard(String tableName, Long value) {
return tableName + "_" + value % Integer.parseInt(shardConfigProperties.getProperty("mod"));
}
private String getTableName(String sql) throws Throwable {
SQLParseInfo parseInfo = SQLParseInfo.getParseInfo(sql);
if (parseInfo.getTables() == null || parseInfo.getTables().length != 1) {
throw new Throwable("表数目不为1");
}
return parseInfo.getTables()[0].getName();
}
}
说明
以上就是修改后的代码,具体更多的细节可以看之前的两篇文章。这里是对整个缓存MAP的所有操作加了同步Sychronize
,可见锁定的范围是比较大的。这里至少要保存拆分键和原始表名,因此解析对象建议还是保留,如果要降低锁的粒度,可以尝试在ShardEntity
操作时加同步策略。
总结
分表插件起源于老项目迁移SpringBoot+MyBatis
,目前完成了上线,单机QPS在50
左右,高峰超过100
,在这个过程中,自己开发了项目迁移脚本(转换生成Mapper.xml),自动批量验证DAO层接口(可以看之前的文章)。总的来说,循序渐进完成了项目的迁移工作,当然其中也遇到了很多问题并随之解决。
说点题外的吧,可能觉得有点重复造轮子,但是由于各种原因,很多现成的东西并不能直接拿过来用,而且有时候执行迁移的成本也很高。如果有遇到类似能做深入开发的机会,最好在能有把握完成的前提下自己开动,也是一种锻炼和提升的机会。一来业务开发中,相关的机会现实中并不多,再说多数情况下,想法容易,但要实际落地可用是比较难的。
感谢阅读~
更新 (2020-04-20)
更新了最新版本,去掉了synchronized
.
@Intercepts(@Signature(
type = StatementHandler.class,
method = "prepare",
args = {Connection.class, Integer.class}
))
@Component
public class ShardInterceptor implements Interceptor, Ordered {
private static final ReflectorFactory defaultReflectorFactory = new DefaultReflectorFactory();
private static final ConcurrentHashMap<String, ShardEntity> MAPPER_SHARD_CACHE = new ConcurrentHashMap<>();
@Resource
private Properties shardConfigProperties;
private int mod;
@PostConstruct
public void init() {
mod = Integer.parseInt(shardConfigProperties.getProperty("mod"));
}
@Override
public Object intercept(Invocation invocation) throws Throwable {
long time = System.currentTimeMillis();
try {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
MetaObject metaObject = MetaObject.forObject(statementHandler,
SystemMetaObject.DEFAULT_OBJECT_FACTORY,
SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY,
defaultReflectorFactory
);
MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
String id = mappedStatement.getId();
BoundSql boundSql = statementHandler.getBoundSql();
HashMap<String, Object> parameterObject = (HashMap<String, Object>) boundSql.getParameterObject();
ShardEntity shardEntity = MAPPER_SHARD_CACHE.get(id);
String sql = boundSql.getSql();
if (null != shardEntity) {
if (null == shardEntity.getShardKey()) {
return invocation.proceed();
}
Long value = (Long) parameterObject.get(shardEntity.getShardKey());
String originTable = shardEntity.getOriginTableName();
renameTable(boundSql, sql, value, originTable);
return invocation.proceed();
}
String dao = id.substring(0, id.lastIndexOf("."));
String methodName = id.substring(id.lastIndexOf(".") + 1);
Class clazz = Class.forName(dao);
for (Method method : clazz.getMethods()) {
if (method.getName().equals(methodName)) {
Annotation[][] parameterAnnotations = method.getParameterAnnotations();
int idx = 0;
for (Annotation[] pa : parameterAnnotations) {
for (Annotation a : pa) {
if (a instanceof ShardBy) {
String shardKey = method.getParameters()[idx].getName();
Long value = (Long) parameterObject.get(shardKey);
String originTable = getTableName(sql);
renameTable(boundSql, sql, value, originTable);
shardEntity = new ShardEntity();
shardEntity.setStatement(sql);
shardEntity.setOriginTableName(originTable);
shardEntity.setShardKey(shardKey);
MAPPER_SHARD_CACHE.put(id, shardEntity);
return invocation.proceed();
}
}
idx++;
}
}
}
//非拆分表
shardEntity = new ShardEntity();
MAPPER_SHARD_CACHE.put(id, shardEntity);
return invocation.proceed();
} finally {
Counter.count("shard_mapper_millis", System.currentTimeMillis() - time);
}
}