Mybatis 程序员

Mybatis拦截器如何进行分表

2019-04-15  本文已影响266人  杭宇_8ba6

MyBatis 拦截器的设计跟 Spring 的拦截器基本一致:它的设计初衷是让用户在特定时机插入自己的逻辑,而不必改动 MyBatis 固有的逻辑。在拦截器的各种用法中,分页插件应该是用得最多的;分表的实现思路与之类似。

   <dependency>
            <groupId>org.mybatis</groupId>
            <artifactId>mybatis-spring</artifactId>
            <version>1.2.3</version>
        </dependency>
        <dependency>
            <groupId>org.mybatis</groupId>
            <artifactId>mybatis</artifactId>
            <version>3.3.0</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>druid</artifactId>
            <version>1.0.27</version>
        </dependency>
/**
 * MyBatis plugin that rewrites table names for mapper interfaces annotated with {@link TableSeg}.
 *
 * <p>It intercepts {@code StatementHandler.prepare(Connection)}. When the mapper interface that
 * owns the statement carries {@code @TableSeg}, the bound SQL is re-rendered through the Druid
 * MySQL parser. Every table listed by the strategy's {@code getShardTableList()} is mapped to
 * {@code <table>_<shardIndex>}. Not suitable for batch statements.
 *
 * @author hangyu
 * @since 2019/4/15
 * @version 1.0
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})})
public class TableSegInterceptor implements Interceptor {

    private final Log log = LogFactory.getLog(getClass());

    private final static String BOUNDSQL_SQL_NAME = "delegate.boundSql.sql";

    private final static String BOUNDSQL_NAME = "delegate.boundSql";

    private final static String MAPPEDSTATEMENT_NAME = "delegate.mappedStatement";

    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();

    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();

    private final static ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();


    /**
     * Rewrites the SQL about to be prepared when its mapper is annotated with {@link TableSeg}.
     *
     * @param invocation the intercepted {@code prepare} call
     * @return whatever the intercepted call returns
     * @throws Throwable anything raised by SQL parsing, sharding, or the proceeding call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();

        // Reflective view over the handler so the nested delegate fields can be read and written.
        MetaObject metaObject = MetaObject.forObject(statementHandler, DEFAULT_OBJECT_FACTORY,
                DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);

        // The SQL MyBatis is about to prepare.
        String originalSql = (String) metaObject.getValue(BOUNDSQL_SQL_NAME);

        if (StringUtils.isNotEmpty(originalSql)) {
            MappedStatement mappedStatement = (MappedStatement) metaObject.getValue(MAPPEDSTATEMENT_NAME);
            TableSeg tableSeg = resolveTableSeg(mappedStatement.getId());

            if (tableSeg != null) {
                BoundSql boundSql = (BoundSql) metaObject.getValue(BOUNDSQL_NAME);
                Map<String, Object> parameter = getParameterFromMappedStatement(mappedStatement, boundSql);
                shardTable(metaObject, parameter, tableSeg, originalSql);
            }
        }

        return invocation.proceed();
    }

    /**
     * Looks up {@link TableSeg} on the mapper interface that owns a statement.
     *
     * @param statementId MyBatis statement id, i.e. {@code <namespace>.<statement>}
     * @return the annotation, or {@code null} if the namespace is not a loadable class or is not annotated
     */
    private TableSeg resolveTableSeg(String statementId) {
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        String className = statementId.substring(0, lastDot);
        try {
            return Class.forName(className).getAnnotation(TableSeg.class);
        } catch (ClassNotFoundException e) {
            // XML-only namespaces have no backing Java type. Such statements are simply never
            // sharded, instead of failing every query that passes through this plugin.
            if (log.isDebugEnabled()) {
                log.debug("No mapper class for namespace " + className + ", skipping table sharding");
            }
            return null;
        }
    }

    /**
     * Flattens the statement's parameter object into a name -> value map.
     *
     * @param ms       the mapped statement (used to reach the type-handler registry)
     * @param boundSql the bound SQL holding the parameter object and placeholder mappings
     * @return a new, mutable map; empty when the statement has no parameter
     */
    @SuppressWarnings("unchecked")
    private Map<String, Object> getParameterFromMappedStatement(MappedStatement ms, BoundSql boundSql) {
        Map<String, Object> paramMap = new HashMap<String, Object>();
        Object parameterObject = boundSql.getParameterObject();

        if (parameterObject == null) {
            return paramMap;
        }
        if (parameterObject instanceof Map) {
            // Named/multiple mapper arguments already arrive as a map.
            paramMap.putAll((Map<String, Object>) parameterObject);
            return paramMap;
        }

        boolean hasTypeHandler = ms.getConfiguration().getTypeHandlerRegistry()
                                   .hasTypeHandler(parameterObject.getClass());

        if (!hasTypeHandler) {
            // Bean-like parameter: expose every readable property by name.
            MetaObject metaObject = SystemMetaObject.forObject(parameterObject);
            for (String name : metaObject.getGetterNames()) {
                paramMap.put(name, metaObject.getValue(name));
            }
        }

        // A single simple-typed argument (e.g. a String) has no property names. Bind it to the
        // first placeholder that refers to it.
        if (boundSql.getParameterMappings() != null) {
            for (ParameterMapping parameterMapping : boundSql.getParameterMappings()) {
                String name = parameterMapping.getProperty();
                if (paramMap.get(name) == null
                        && (hasTypeHandler || parameterMapping.getJavaType().equals(parameterObject.getClass()))) {
                    paramMap.put(name, parameterObject);
                    break;
                }
            }
        }

        return paramMap;
    }

    /**
     * Re-renders {@code originalSql} with the sharded table names and stores it back on the handler.
     * Not usable for batch statements.
     *
     * @param metaObject  reflective view over the intercepted StatementHandler
     * @param parameter   flattened mapper parameters (must contain the shard key)
     * @param tableSeg    sharding configuration from the mapper
     * @param originalSql SQL to rewrite
     */
    private void shardTable(MetaObject metaObject, Map<String, Object> parameter,
                            TableSeg tableSeg, String originalSql) {
        MySqlStatementParser parser = new MySqlStatementParser(originalSql);
        SQLStatement statement = parser.parseStatement();

        StringBuilder newSql = new StringBuilder();
        SQLASTOutputVisitor visitor = SQLUtils.createOutputVisitor(newSql, JdbcConstants.MYSQL);

        // Register logical -> physical table renames. The visitor applies them while printing.
        for (Map.Entry<String, String> entry : getShardTableName(tableSeg, parameter).entrySet()) {
            visitor.addTableMapping(entry.getKey(), entry.getValue());
        }

        statement.accept(visitor);
        // Write the rewritten SQL back so the subsequent prepare() uses it.
        metaObject.setValue(BOUNDSQL_SQL_NAME, newSql.toString());
    }

    /**
     * Builds the logical -> physical table-name mapping for one statement.
     *
     * @param seg       sharding configuration
     * @param parameter flattened mapper parameters
     * @return map from each table in the strategy to {@code <table>_<shardIndex>}
     * @throws IllegalArgumentException if {@code shardNum <= 0} or the shard-key parameter is missing
     */
    private Map<String, String> getShardTableName(TableSeg seg, Map<String, Object> parameter) {
        TableShardStrategy tableShardStrategy = seg.shardBy();
        String shardCode = tableShardStrategy.getShardCode();
        int shardNum = seg.shardNum();
        if (shardNum <= 0) {
            throw new IllegalArgumentException("@TableSeg.shardNum must be positive but was " + shardNum);
        }

        Object shardValue = parameter.get(shardCode);
        if (shardValue == null) {
            throw new IllegalArgumentException("Missing sharding parameter '" + shardCode
                    + "' required by strategy " + tableShardStrategy);
        }
        // The suffix depends only on the shard key, so compute it once for all tables.
        String suffix = String.valueOf(shardIndex(shardValue.toString(), shardNum));

        // key: logical table name, value: physical (sharded) table name
        Map<String, String> oldTableNewTableNameMap = new HashMap<>();
        for (String toShardTable : tableShardStrategy.getShardTableList()) {
            oldTableNewTableNameMap.put(toShardTable, toShardTable + "_" + suffix);
        }
        return oldTableNewTableNameMap;
    }

    /**
     * Maps a shard key to an index in {@code [0, shardNum)}.
     *
     * <p>Numeric keys use their value modulo {@code shardNum}, which matches the previous behaviour
     * for non-negative keys. Non-numeric keys (e.g. an openId) use {@link String#hashCode()}, whose
     * algorithm is fixed by its specification. The remainder is shifted into range because Java's
     * {@code %} keeps the sign of the dividend.
     */
    private static long shardIndex(String shardKey, int shardNum) {
        long raw;
        try {
            raw = Long.parseLong(shardKey) % shardNum;
        } catch (NumberFormatException e) {
            raw = shardKey.hashCode() % shardNum;
        }
        return raw < 0 ? raw + shardNum : raw;
    }

    /**
     * Wraps only StatementHandler targets so other MyBatis objects are not proxied needlessly.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }

    /** No configurable properties. */
    @Override
    public void setProperties(Properties properties) {

    }
}
/**
 * Marks a mapper type whose SQL should be rewritten to sharded table names.
 *
 * <p>Retained at runtime so that interceptors can read it reflectively from the mapper type.
 *
 * @author hangyu
 * @since 2019/4/15
 * @version 1.0
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface TableSeg {

    /**
     * Strategy naming the parameter that carries the shard key and the tables to rewrite.
     *
     * @return the sharding strategy
     */
    TableShardStrategy shardBy();

    /**
     * Number of shards; the shard index is taken modulo this value (e.g. 4 means remainder of /4).
     * There is no default, so every usage must set it.
     *
     * @return the shard count
     */
    int shardNum();
}

/**
 * Sharding strategies usable as {@code TableSeg.shardBy()}.
 *
 * <p>Each constant names the mapper parameter whose value selects the shard ({@link #getShardCode()})
 * and the logical tables to rewrite ({@link #getShardTableList()}). Enum constants are shared,
 * JVM-wide singletons, so the table array is copied on the way in and out. Callers therefore
 * cannot corrupt it by mutating an array they hold.
 *
 * @author hangyu
 * @since 2019/4/15
 * @version 1.0
 */
public enum TableShardStrategy {

    /** Shard by the {@code openId} parameter; rewrites the {@code member} table. */
    OPEN_ID("openId", new String[]{"member"});

    // Name of the mapper parameter holding the shard key.
    private String shardCode;

    // Logical table names to shard. Never handed out directly (see getters/setters).
    private String[] shardTableList;

    TableShardStrategy(String shardCode, String[] shardTableList) {
        this.shardCode = shardCode;
        this.shardTableList = copyOf(shardTableList);
    }

    /**
     * @return the name of the mapper parameter used as the shard key
     */
    public String getShardCode() {
        return shardCode;
    }

    /**
     * Replaces the shard-key parameter name.
     * NOTE: this mutates a global enum constant and so affects every mapper using it.
     */
    public void setShardCode(String shardCode) {
        this.shardCode = shardCode;
    }

    /**
     * @return a copy of the logical table names to shard (or {@code null} if set to null)
     */
    public String[] getShardTableList() {
        return copyOf(shardTableList);
    }

    /**
     * Replaces the table list with a copy of {@code shardTableList}.
     * NOTE: this mutates a global enum constant and so affects every mapper using it.
     */
    public void setShardTableList(String[] shardTableList) {
        this.shardTableList = copyOf(shardTableList);
    }

    // Null-preserving defensive copy.
    private static String[] copyOf(String[] source) {
        return source == null ? null : source.clone();
    }
}
/**
 * Data holder returned by and passed to {@code MemberDao}: a numeric member id plus an openId.
 * Both fields are nullable and start out {@code null}.
 */
public class Member implements Serializable {

    private Long memberId;

    private String openId;

    /** @return the openId, or {@code null} if unset */
    public String getOpenId() {
        return this.openId;
    }

    /** @param openId the openId to store */
    public void setOpenId(String openId) {
        this.openId = openId;
    }

    /** @return the member id, or {@code null} if unset */
    public Long getMemberId() {
        return this.memberId;
    }

    /** @param memberId the member id to store */
    public void setMemberId(Long memberId) {
        this.memberId = memberId;
    }
}


/**
 * MyBatis mapper for {@link Member}.
 *
 * <p>Annotated with {@code @TableSeg} (strategy {@code OPEN_ID}, 100 shards), so the sharding
 * interceptors rewrite the table names in its SQL.
 *
 * @author hangyu
 * @since 2019/4/15
 * @version 1.0
 */
@TableSeg(shardBy = TableShardStrategy.OPEN_ID, shardNum = 100)
@Repository
public interface MemberDao {

    /**
     * Inserts a member.
     *
     * @param member the row to insert
     * @return presumably the affected-row count reported by MyBatis — confirm against the mapper XML
     */
    int insert(Member member);

    /**
     * Looks up a member by openId.
     *
     * @param openId the openId (also the shard key for this mapper)
     * @return the matching member
     */
    Member getMember(String openId);
}
/**
 * MyBatis plugin that routes statements of {@link TableSeg}-annotated mappers to per-day tables.
 *
 * <p>Every table listed by the mapper's {@link TableShardStrategy} is rewritten to
 * {@code <table>_yyyyMMdd} (current date), by whole-word text substitution on the bound SQL.
 *
 * @author hangyu
 * @since 2019/4/15
 * @version 1.0
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})})
public class TableShareInterceptor implements Interceptor {

    private final Log log = LogFactory.getLog(getClass());


    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();

    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();

    private final static ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();

    private static final String DATE_PATTERN = "yyyyMMdd";

    /**
     * Rewrites the bound SQL to date-suffixed table names when the mapper carries {@link TableSeg}.
     *
     * @param invocation the intercepted {@code prepare} call
     * @return whatever the intercepted call returns
     * @throws Throwable anything raised by the proceeding call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // Reflective view over the handler so nested delegate fields can be read and written.
        MetaObject metaObject = MetaObject.forObject(statementHandler, DEFAULT_OBJECT_FACTORY,
                DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);

        MappedStatement mappedStatement = (MappedStatement)
                metaObject.getValue("delegate.mappedStatement");

        TableSeg tableShard = resolveTableSeg(mappedStatement.getId());
        if (tableShard != null) {
            String sql = (String) metaObject.getValue("delegate.boundSql.sql");
            // One timestamp per statement so every table in it lands on the same day.
            Date now = new Date();
            for (String tableName : tableShard.shardBy().getShardTableList()) {
                sql = replaceTableName(sql, tableName, tableShard(tableName, now));
            }
            // Replace the old SQL with the rewritten one.
            metaObject.setValue("delegate.boundSql.sql", sql);
        }

        // Pass on to the next interceptor / the real prepare().
        return invocation.proceed();
    }

    /**
     * Looks up {@link TableSeg} on the mapper interface that owns a statement.
     *
     * @param statementId MyBatis statement id, i.e. {@code <namespace>.<statement>}
     * @return the annotation, or {@code null} if the namespace is not a loadable class or is not annotated
     */
    private TableSeg resolveTableSeg(String statementId) {
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        String className = statementId.substring(0, lastDot);
        try {
            Class<?> clazz = Class.forName(className);
            return clazz.getAnnotation(TableSeg.class);
        } catch (ClassNotFoundException e) {
            // XML-only namespaces have no backing Java type; such statements are never sharded.
            if (log.isDebugEnabled()) {
                log.debug("No mapper class for namespace " + className + ", skipping table sharding");
            }
            return null;
        }
    }

    /**
     * Replaces whole-word occurrences of {@code tableName} in {@code sql}.
     *
     * <p>Word boundaries keep identifiers that merely contain the name (e.g. {@code member_id})
     * intact. Both arguments are quoted, so regex metacharacters and {@code $} are taken literally.
     */
    private static String replaceTableName(String sql, String tableName, String newTableName) {
        if (sql == null || tableName == null || tableName.isEmpty()) {
            return sql;
        }
        java.util.regex.Pattern pattern = java.util.regex.Pattern.compile(
                "\\b" + java.util.regex.Pattern.quote(tableName) + "\\b");
        return pattern.matcher(sql).replaceAll(java.util.regex.Matcher.quoteReplacement(newTableName));
    }

    /**
     * Returns {@code tableName_yyyyMMdd} for the current date.
     *
     * @param tableName logical table name
     * @return the physical table name for today
     */
    public String tableShard(String tableName) {
        return tableShard(tableName, new Date());
    }

    // Formats the day suffix; a new SimpleDateFormat per call because it is not thread-safe.
    private String tableShard(String tableName, Date day) {
        SimpleDateFormat sdf = new SimpleDateFormat(DATE_PATTERN);
        return tableName + "_" + sdf.format(day);
    }

    /**
     * Wraps StatementHandler targets and returns every other target unchanged.
     * MyBatis replaces its target with whatever this returns, so it must never return {@code null}.
     */
    @Override
    public Object plugin(Object o) {
        if (o instanceof StatementHandler) {
            return Plugin.wrap(o, this);
        }
        return o;
    }

    /** No configurable properties. */
    @Override
    public void setProperties(Properties properties) {

    }
}
上一篇下一篇

猜你喜欢

热点阅读