Mybatis 拦截器自定义使用添加通用字段

2023-04-20  本文已影响0人  同吢欢楽

** 最近,在开发项目中,由于数据结构设计的需要和业务需要,数据表中,会存在一些公共字段的插入,每个insert和update语句的实体,都会塞入公共字段,比如:tenantId、appId,于是,便想着通过自定义mybatis拦截器来实现mapper层的实体字段注入。 **

一、Mybatis拦截器是什么?
MyBatis 拦截器(Interceptor)是 MyBatis 框架提供的一个功能,用于在 SQL 执行过程中拦截某些方法的调用,从而实现自定义的功能扩展。拦截器可以拦截 MyBatis 的四大对象:Executor、StatementHandler、ParameterHandler 和 ResultSetHandler。
MyBatis 拦截器的主要用途是为了实现一些自定义的功能,例如:

1.SQL 统计:可以通过拦截器在查询 SQL 执行前后,统计 SQL 执行时间、查询结果数量等信息,以便对 SQL 查询性能进行优化。

2.分页:可以通过拦截器实现对查询结果进行分页处理,从而实现分页查询功能。

3.数据库路由:可以通过拦截器在执行 SQL 时,根据不同的条件动态选择不同的数据源,实现数据库的读写分离、分库分表等功能。

4.SQL 注入检测:可以通过拦截器在执行 SQL 前,对 SQL 参数进行检查,防止 SQL 注入攻击。

5.数据加密脱敏:可以通过拦截器对敏感信息如:密码、地址、电话等进行脱敏操作。

拦截器可以通过实现 MyBatis 的 Interceptor 接口来实现。在实现拦截器时,需要重写 intercept 方法,该方法会在 MyBatis 对象的方法调用前后进行拦截,并在需要的时候调用 Invocation.proceed() 方法来继续执行原方法。通过在 intercept 方法中添加自定义的逻辑,就可以实现拦截器的功能扩展。

MyBatis 拦截器是 MyBatis 框架中非常重要和实用的一个功能,可以通过拦截器扩展 MyBatis 的功能,提高应用程序的性能和可维护性。

二、如何实现?

新建一个类,实现mybatis的Interceptor接口,并实现它的intercept。

package com.streamax.human.interceptor;

import com.streamax.platform.common.util.BaseApplicationContext;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.jetbrains.annotations.NotNull;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.util.*;


@Intercepts({@Signature(
        type = Executor.class,
        method = "update",
        args = {MappedStatement.class, Object.class}
)})
@Component
public class MapperUpdateInterceptor implements Interceptor {

    private final String createUser = "createUser";
    private final String updateUser = "updateUser";
    private final String updateTime = "updateTime";
    private final String createTime = "createTime";
    private final String appId = "appId";
    private final String tenantId = "tenantId";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        //代理类方法参数,该拦截器拦截的update方法有两个参数args = {MappedStatement.class, Object.class}
        Object[] args = invocation.getArgs();
        //获取方法参数
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object parameter = args[1];
        if (parameter instanceof Map) {
            //mapper参数通过@param("“) 走这个
            Map<String, Map<String, Object>> paramMap = (Map) parameter;
            final Set<Map.Entry<String, Map<String, Object>>> entries = paramMap.entrySet();
            final Iterator<Map.Entry<String, Map<String, Object>>> iterator = entries.iterator();
            if (iterator.hasNext()) {
                final Object value = iterator.next().getValue();
                setValues(value);
            }
        } else {
            //参数直接用实体,走这里
            setValues(parameter);

        }
        //获取操作类型,crud
        return invocation.proceed();
    }

    private void setValues(Object parameter) throws IllegalAccessException {
        if (parameter != null) {
            if (parameter instanceof List) {
                List<Object> list = (List<Object>) parameter;
                for (Object obj : list) {
                    setValuesCore(obj);
                }
            } else {
                setValuesCore(parameter);
            }

        }
    }

    private void setValuesCore(Object parameter) throws IllegalAccessException {
        Field[] fields = getAllFields(parameter);
        Long userId = BaseApplicationContext.getUserId() == null ? 0 : BaseApplicationContext.getUserId();
        Integer tenantId1 = BaseApplicationContext.getTenantId() == null ? 0 : BaseApplicationContext.getTenantId();
        Integer appId1 = BaseApplicationContext.getAppId() == null ? 0 : BaseApplicationContext.getAppId();
        long time = System.currentTimeMillis();
        for (Field field : fields) {
            String name = field.getName();
            field.setAccessible(true);
            final Object o = field.get(parameter);
            if (o != null) {
                continue;
            }
            switch (name) {
                case createUser:
                case updateUser:
                    field.set(parameter, userId);
                    break;
                case createTime:
                case updateTime:
                    field.set(parameter, time);
                    break;
                case tenantId:
                    field.set(parameter, tenantId1);
                    break;
                case appId:
                    field.set(parameter, appId1);
                    break;
                default:
                    break;
            }
        }

    }

    @NotNull
    private static Field[] getAllFields(Object parameter) {
        Class<?> clazz = parameter.getClass();
        List<Field> fieldList = new ArrayList<>();
        while (clazz != null) {
            fieldList.addAll(new ArrayList<>(Arrays.asList(clazz.getDeclaredFields())));
            clazz = clazz.getSuperclass();
        }
        Field[] fields = new Field[fieldList.size()];
        fieldList.toArray(fields);
        return fields;
    }
}

三、代码解析

首先,这个是要交给spring容器来管理的,所以必须要加上@Component注解,这个就不多说了。
然后是@Intercepts注解,这个注解主要配置的是@Signature里面的参数,这就是 你要想拦截的地方入口。

package org.apache.ibatis.plugin;

import java.lang.annotation.Documented;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * The annotation that indicate the method signature.
 *
 * @see Intercepts
 * @author Clinton Begin
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({})
public @interface Signature {
  /**
   * Returns the java type.
   * 要拦截的类
   * @return the java type
   */
  Class<?> type();

  /**
   * Returns the method name.
   * 要拦截的方法名
   * @return the method name
   */
  String method();

  /**
   * Returns java types for method argument.
   * 拦截方法的参数
   * @return java types for method argument
   */
  Class<?>[] args();
}

我们这里拦截的是Executor接口,方法是update,如下源码:

/*
 *    Copyright 2009-2021 the original author or authors.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */
package org.apache.ibatis.executor;

import java.sql.SQLException;
import java.util.List;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.cursor.Cursor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.transaction.Transaction;

/**
 * @author Clinton Begin
 */
public interface Executor {

  ResultHandler NO_RESULT_HANDLER = null;
 /**
*
*我们拦截的是这个入口。
*/
  int update(MappedStatement ms, Object parameter) throws SQLException;

  <E> List<E> query(MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) throws SQLException;

  <E> List<E> query(MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler) throws SQLException;

  <E> Cursor<E> queryCursor(MappedStatement ms, Object parameter, RowBounds rowBounds) throws SQLException;

  List<BatchResult> flushStatements() throws SQLException;

  void commit(boolean required) throws SQLException;

  void rollback(boolean required) throws SQLException;

  CacheKey createCacheKey(MappedStatement ms, Object parameterObject, RowBounds rowBounds, BoundSql boundSql);

  boolean isCached(MappedStatement ms, CacheKey key);

  void clearLocalCache();

  void deferLoad(MappedStatement ms, MetaObject resultObject, String property, CacheKey key, Class<?> targetType);

  Transaction getTransaction();

  void close(boolean forceRollback);

  boolean isClosed();

  void setExecutorWrapper(Executor executor);

}

代码走查时,有同事问,为什么method = update,为什么没有insert。这里给你答案了。method参数是Executor的方法名,我们在接口里面看到,根本没有insert方法的。其实看了源码,就可以知道,insert 语句执行的,也是update方法。

image.png
上一篇 下一篇

猜你喜欢

热点阅读