MyBatis注解支持SQL片断/宏
2019-10-08 本文已影响0人
无醉_1866
对于喜欢使用注解而不是xml的同学来说,通过<sql>这种xml的方式将常用的sql片断抽取成公用的sql会非常痛苦,我们想使用无xml配置的方法,有没有办法通过注解实现呢?办法就是使用mybatis插件
希望达到如下使用方式的目的(省略无关代码):
@Macro(name = "columns", content = "id, name, cel_phone")
public interface UserMapper {
@Select("select @columns from user where id = #{id}")
User queryById(@Param("id")int id);
}
其中,注解@Macro用于定义宏,即用于对SQL做文本替换,上面的代码中,@Macro(name = "columns", content = "id, name, cel_phone")表示的意思是定义了一个名为columns的宏,当SQL中出现@columns时,将其替换成id, name, cel_phone,下面是@Macro的定义:
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 用于在mybatis的映射接口上定义宏,适用于纯注解的使用方式。
* <p>
* 一个映射接口上可以标记多个@Macro,jdk8支持多个注解实例
*
* @author gaohang
*/
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Macro {
/**
* 宏的名称
*/
String name();
/**
* 用于做替换的文本内容,使用数组可以避免因内容太长导致字符串非常难以阅读,拆分成多个string即可
*/
String[] content();
}
最后聊聊此插件如何实现,基本原理:
- 识别SQL中的@符号,取出@符号后面的第一个标识符,比如@columns,则取出columns,columns是正在使用的宏的名称
- 将通过宏名称找到映射接口上的@Macro声明,取出其中的content
- 将取出的content替换原始sql中的@以及宏名称
结合mybatis的实现,我们可以看到MappedStatement对象有一个getSqlSource方法,返回SqlSource对象,而SqlSource对象的getBoundSql方法返回BoundSql,我们需要处理的正是BoundSql的getSql方法返回的sql字符串,这个时候,最容易想到的方式是使用代理模式,将SqlSource和BoundSql对象做代理,当调用到boundSql.getSql()时,处理sql中所使用到的宏,这里需要注意以下几点:
- SqlSource是一个接口,且其中只有一个方法getBoundSql,使用静态代理或者匿名内部类即可
- BoundSql是一个具体类,想要代理此类的对象,需要使用cglib
- MappedStatement对象是全局唯一的,不要直接修改此类的对象的属性,当需要修改时,选择复制一个并为需要修改的字段赋上新的值
废话不多说了,代码实现如下:
import com.google.common.collect.Sets;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cglib.proxy.Enhancer;
import org.springframework.cglib.proxy.MethodInterceptor;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
/**
* 处理{@link Macro}注解,对SQL做字符串替换,将SQL语句中的@xxx替换成{@link Macro#name()}为xxx的注解中的{@link Macro#content()},例如:
*
* <pre>
* {@code
* @Macro(name = "columns", content = "id, name, cel_phone")
* public interface UserMapper {
* @Select("select @columns from user where id = #{id})
* User queryById(@Param("id")int id);
* }
*
* }
* </pre>
* <p>
* 注意,宏的使用必须要跟着逗号或空白字符,或者是SQL的结尾,否则识别不了
*
* @author gaohang
*/
@Intercepts({
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
public class MacroInterceptor implements Interceptor {
private static final Logger logger = LoggerFactory.getLogger(MacroInterceptor.class);
/**
* 记录不需要处理{@link Macro}的映射接口
*/
private final Set<String> namespaceNotContainsMacro = Sets.newConcurrentHashSet();
@Override
public Object intercept(Invocation invocation) throws Throwable {
final Object[] args = invocation.getArgs();
final MappedStatement ms = (MappedStatement) args[0];
//这里的sqlSource是一个代理对象
final SqlSource sqlSource = new SqlSourceDelegate(ms);
//复制MapedStatement对象
final MappedStatement mappedStatement = MappedStatementUtils.copyMappedStatement(ms, sqlSource);
args[0] = mappedStatement;
return invocation.proceed();
}
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
}
/**
* 对{@link SqlSource}的代理,{@link #getBoundSql(Object)}方法会返回一个{@link BoundSql}
* 对象的cglib代理,在代理对象中处理宏的替换
*/
private final class SqlSourceDelegate implements SqlSource {
private final MappedStatement ms;
private SqlSourceDelegate(MappedStatement ms) {
this.ms = ms;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
//获取原始的BoundSql
final BoundSql boundSql = ms.getSqlSource().getBoundSql(parameterObject);
final String namespace = getNamespace(ms);
if (namespace == null) {
return boundSql;
}
//拦截一次
if (namespaceNotContainsMacro.contains(namespace)) {
return boundSql;
}
try {
final Class<?> mapperInterface = Class.forName(namespace);
final Macro[] macros = mapperInterface.getAnnotationsByType(Macro.class);
//没有定义Macro,则返回原始的boundSql
if (macros == null || macros.length == 0) {
namespaceNotContainsMacro.add(namespace);
return boundSql;
}
//对BoundSql做代理,处理getSql()返回的sql语句
return proxyBoundSql(ms.getConfiguration(), boundSql, macros);
} catch (ClassNotFoundException e) {
logger.debug("load namespace class failed, maybe namespace {} is not a class", namespace, e);
namespaceNotContainsMacro.add(namespace);
}
return boundSql;
}
}
/**
* 创建{@link BoundSql}的代理对象
*/
private BoundSql proxyBoundSql(Configuration configuration, BoundSql boundSql, Macro[] macros) {
return getProxy(configuration, boundSql, (proxy, method, args, methodProxy) -> {
if (!Objects.equals("getSql", method.getName())) {
return methodProxy.invoke(boundSql, args);
}
//处理宏
return replaceMacroContent(macros, boundSql.getSql());
});
}
private BoundSql getProxy(Configuration configuration, BoundSql target, MethodInterceptor methodInterceptor) {
Enhancer enhancer = new Enhancer();
enhancer.setSuperclass(target.getClass());
enhancer.setCallback(methodInterceptor);
final Class<?>[] parameterTypes = {Configuration.class, String.class, List.class, Object.class};
final Object[] params = new Object[]{
configuration, target.getSql(), target.getParameterMappings(), target.getParameterObject()
};
return (BoundSql) enhancer.create(parameterTypes, params);
}
/**
* 对sql做宏的替换
*/
private String replaceMacroContent(Macro[] macros, String sql) {
//识别sql字符串中的@符号,并取出@符号之后的宏名
int lastAppended = 0;
final StringBuilder sqlAppender = new StringBuilder();
for (int i = sql.length(); lastAppended < i; ) {
final int macroPlaceStart = sql.indexOf('@', lastAppended);
if (macroPlaceStart < 0) {
//没有了,结束
break;
}
final int macroNameStart = macroPlaceStart + 1;
if (macroNameStart == i) {
break;
}
int k = macroNameStart;
for (; k < i; k++) {
final char ch = sql.charAt(k);
//@macroName后面必须要跟着逗号或者空白字符,或者是SQL的结尾
if (Character.isWhitespace(ch) || ch == ',') {
break;
}
}
sqlAppender.append(sql, lastAppended, macroPlaceStart);
//(macroPlaceStart, k)是宏结束
final String macroName = sql.substring(macroNameStart, k);
for (Macro macro : macros) {
if (Objects.equals(macro.name(), macroName)) {
final String[] contents = macro.content();
for (String content : contents) {
sqlAppender.append(content).append(' ');
}
}
}
lastAppended = k;
}
if (sqlAppender.length() == 0) {
return sql;
}
if (lastAppended < sql.length()) {
sqlAppender.append(sql.substring(lastAppended));
}
return sqlAppender.toString();
}
private String getNamespace(MappedStatement ms) {
final String statementId = ms.getId();
final int i = statementId.lastIndexOf('.');
if (i <= 0) {
return null;
}
return statementId.substring(0, i);
}
}