mybatis-sql拦截

2021-11-17  本文已影响0人  幻如常
package com.zvos.bmp.risk.data.infrastructure.config;

import com.zvos.bmp.risk.data.application.service.SchedulerService;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;

import javax.persistence.Column;
import javax.persistence.Entity;
import java.lang.reflect.Field;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;


/**
 * MyBatis plugin that fills selected entity fields right before a write statement runs.
 *
 * <p>It intercepts {@code Executor#update(MappedStatement, Object)}. For statements whose
 * {@link SqlCommandType} is {@code INSERT} or {@code UPDATE}, the statement parameter is flattened
 * (collections, object arrays and map values are traversed recursively). Each resulting object is
 * then inspected for fields annotated with {@link Column}. Statements of any other command type
 * are passed through unchanged.
 *
 * <p>Currently only {@code @Column(name = "batch_no")} fields are populated, and only on
 * {@code INSERT}. Such a field is set to {@code SchedulerService.batchNo} when its current value is
 * {@code null}; non-null values are never overwritten. No fields are registered for {@code UPDATE}.
 *
 * <p>The per-class field lookup is computed once and cached in a static {@link ConcurrentHashMap}.
 */
@Intercepts({@Signature(type = Executor.class, method = "update",
        args = {MappedStatement.class, Object.class})})
public class BaseUpdateInterceptor implements Interceptor {

    private static final Logger LOG = Logger.getLogger(BaseUpdateInterceptor.class.getName());

    /** Column name whose annotated field receives the current batch number. */
    private static final String BATCH_NO_COLUMN = "batch_no";

    /** Supplies the value for {@code batch_no} columns; read lazily on every invocation. */
    private static final Function<Object, Object> BATCH_NO = o -> SchedulerService.batchNo;

    /**
     * Cache of injectable fields: parameter class -> (command type -> fields to fill).
     * The inner maps and lists are never mutated after being published here.
     */
    private static final Map<Class<?>, Map<SqlCommandType, List<SQLField>>> FIELD_CACHE =
            new ConcurrentHashMap<>();

    /**
     * Recursively flattens a statement parameter into the leaf objects whose fields may be filled.
     *
     * <p>{@link Collection}s and object arrays are expanded element by element, and {@link Map}s
     * contribute their values (e.g. MyBatis parameter maps). {@code null}s are skipped because
     * there is nothing to populate on them.
     *
     * @param target the parameter, or a nested part of it; may be {@code null}
     * @param list   accumulator that leaf objects are appended to
     * @return {@code list}, for call chaining
     */
    private List<Object> getTarget(Object target, List<Object> list) {
        if (target == null) {
            return list;
        }
        if (target instanceof Collection) {
            for (Object element : (Collection<?>) target) {
                getTarget(element, list);
            }
        } else if (target instanceof Object[]) {
            for (Object element : (Object[]) target) {
                getTarget(element, list);
            }
        } else if (target instanceof Map) {
            for (Object value : ((Map<?, ?>) target).values()) {
                getTarget(value, list);
            }
        } else {
            list.add(target);
        }
        return list;
    }

    /**
     * Populates registered fields on every object in the parameter of an INSERT/UPDATE, then
     * proceeds with the original invocation.
     *
     * @param invocation the intercepted {@code Executor.update(MappedStatement, Object)} call
     * @return whatever the underlying executor returns
     * @throws Throwable anything thrown by the underlying executor
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();

        if (sqlCommandType == SqlCommandType.UPDATE || sqlCommandType == SqlCommandType.INSERT) {
            Object parameter = invocation.getArgs()[1];
            // The same collection can be reachable under several keys of a parameter map,
            // so visit each object (by identity) only once.
            Set<Object> seen = Collections.newSetFromMap(new IdentityHashMap<>());
            for (Object targetObject : getTarget(parameter, new ArrayList<>())) {
                if (seen.add(targetObject)) {
                    fillFields(targetObject, sqlCommandType);
                }
            }
        }
        return invocation.proceed();
    }

    /**
     * Sets every field registered for {@code sqlCommandType} on {@code targetObject} whose current
     * value is {@code null}.
     *
     * <p>Failures are logged and ignored: populating a field is best-effort and must not prevent
     * the statement itself from executing.
     *
     * @param targetObject   a non-null leaf object taken from the statement parameter
     * @param sqlCommandType the command type of the intercepted statement
     */
    private void fillFields(Object targetObject, SqlCommandType sqlCommandType) {
        // getOrDefault: command types without registered fields (e.g. UPDATE) have no map entry.
        List<SQLField> fields = FIELD_CACHE
                .computeIfAbsent(targetObject.getClass(), this::getSQLField)
                .getOrDefault(sqlCommandType, Collections.emptyList());
        for (SQLField field : fields) {
            try {
                if (field.field.get(targetObject) == null) {
                    field.field.set(targetObject, field.value.apply(targetObject));
                }
            } catch (IllegalAccessException | RuntimeException e) {
                LOG.log(Level.WARNING, "Failed to populate field " + field.field
                        + " on " + targetObject.getClass().getName(), e);
            }
        }
    }

    /**
     * Collects the declared fields of {@code aClass} and of its superclasses.
     *
     * <p>The walk stops at {@code Object}, or at the first superclass that is not annotated with
     * {@link Entity} and is a {@link Map} or {@link Collection} implementation.
     *
     * @param aClass the class to start from
     * @param list   accumulator the fields are appended to
     * @return {@code list}, for call chaining
     */
    private List<Field> getAllFields(Class<?> aClass, List<Field> list) {
        list.addAll(Arrays.asList(aClass.getDeclaredFields()));
        Class<?> superClass = aClass.getSuperclass();
        boolean stop = superClass == null
                || superClass.equals(Object.class)
                || (!superClass.isAnnotationPresent(Entity.class)
                    && (Map.class.isAssignableFrom(superClass)
                        || Collection.class.isAssignableFrom(superClass)));
        return stop ? list : getAllFields(superClass, list);
    }

    /**
     * Builds the injectable-field table for one class.
     *
     * <p>Fields annotated {@code @Column(name = "batch_no")} are made accessible and registered for
     * {@code INSERT} with the {@link #BATCH_NO} supplier. Nothing is registered for {@code UPDATE}.
     *
     * @param aClass the parameter class to inspect
     * @return an unmodifiable list of fields to fill, per command type
     */
    private Map<SqlCommandType, List<SQLField>> getSQLField(Class<?> aClass) {
        List<SQLField> insertFields = new ArrayList<>();
        for (Field field : getAllFields(aClass, new ArrayList<>())) {
            Column column = field.getAnnotation(Column.class);
            if (column != null && BATCH_NO_COLUMN.equals(column.name())) {
                field.setAccessible(true);
                insertFields.add(new SQLField(field, BATCH_NO));
            }
        }
        Map<SqlCommandType, List<SQLField>> map = new EnumMap<>(SqlCommandType.class);
        map.put(SqlCommandType.INSERT, Collections.unmodifiableList(insertFields));
        return map;
    }

    /**
     * Wraps {@code target} in a MyBatis proxy so that the signatures declared in
     * {@link Intercepts} are routed to {@link #intercept(Invocation)}.
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** This interceptor takes no configuration properties. */
    @Override
    public void setProperties(Properties properties) {
        // intentionally empty
    }

    /** A field to populate, paired with the function producing its value. */
    private static final class SQLField {
        /** The (already accessible) field to set. */
        private final Field field;
        /** Produces the value to assign, given the object that owns the field. */
        private final Function<Object, Object> value;

        SQLField(Field field, Function<Object, Object> value) {
            this.field = field;
            this.value = value;
        }
    }
}
上一篇下一篇

猜你喜欢

热点阅读