Mybatis模糊查询限定词注入

2022-06-06  本文已影响0人  越狱的灵感

问题

前端大佬测试反馈,使用MyBatis中的模糊查询时,当查询关键字中包括有_、%时,查询关键字失效,会返回所有结果,如下图


1.png 2.png

原因

1、当like中包含_时,查询仍为全部,即 like '%_%' 查询出来的结果与 like '%%' 一致,并不能查询出实际字段中包含有_特殊字符的结果条目
2、like中包括%时,与1中相同

处理

mybatis 新增拦截器

/**
 * MyBatis plugin that neutralises LIKE wildcards supplied as query parameters.
 *
 * <p>It intercepts both {@code Executor.query} overloads. When the mapped statement's
 * SQL source exposes a {@code rootSqlNode}, the SQL is rendered, every
 * {@code LIKE #{param}} placeholder is located, and a parameter value that consists
 * <em>only</em> of {@code %} and/or {@code _} is replaced in the parameter object by its
 * backslash-escaped form, so it is matched literally instead of matching every row.
 * Statements whose SQL source exposes no {@code rootSqlNode} are executed unchanged.
 *
 * <p>Every intercepted query is also logged at INFO level with its elapsed time.
 *
 * <p>Note: the escape is written back into the caller's parameter object.
 */
@Component
@Intercepts(
        {@Signature(type = Executor.class, method = "query", args =
                {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
                @Signature(type = Executor.class, method = "query", args =
                        {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})})
public class MyBatisQueryEscapeInterceptor implements Interceptor {
    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();
    // Shared instance so reflection metadata is not rebuilt for every intercepted query.
    private static final DefaultReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();
    private static final String ROOT_SQL_NODE = "sqlSource.rootSqlNode";

    /** Matches {@code LIKE #{...}} regardless of case and of how much whitespace follows LIKE. */
    private static final Pattern LIKE_PARAMETER_PATTERN =
            Pattern.compile("LIKE\\s+#\\{[^\\}]+\\}", Pattern.CASE_INSENSITIVE);
    /** Matches a value made up exclusively of the LIKE wildcards {@code %} and {@code _}. */
    private static final Pattern ONLY_WILDCARDS_PATTERN = Pattern.compile("^[%_]+$");
    /** Matches any run of whitespace (spaces, tabs, line breaks). */
    private static final Pattern WHITESPACE_RUN_PATTERN = Pattern.compile("\\s+");

    protected static final Logger LOGGER = LoggerFactory.getLogger(MyBatisQueryEscapeInterceptor.class);

    /**
     * Escapes wildcard-only LIKE parameters, runs the query and logs the SQL with its duration.
     *
     * @param invocation the intercepted {@code Executor.query} call; argument 0 is the
     *                   {@code MappedStatement}, argument 1 the parameter object
     * @return whatever the intercepted query returns
     * @throws Throwable anything thrown by the intercepted query or by the escaping step
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        long startTime = System.currentTimeMillis();
        Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];
        Object parameter = args[1];
        BoundSql boundSql = statement.getBoundSql(parameter);
        MetaObject metaMappedStatement = MetaObject.forObject(statement, DEFAULT_OBJECT_FACTORY,
                DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
        if (metaMappedStatement.hasGetter(ROOT_SQL_NODE)) {
            SqlNode sqlNode = (SqlNode) metaMappedStatement.getValue(ROOT_SQL_NODE);
            // Called for its side effect only: it rewrites wildcard-only values inside the parameter object.
            getBoundSql(statement.getConfiguration(), boundSql.getParameterObject(), sqlNode);
        }
        try {
            return invocation.proceed();
        } finally {
            // Runs on success and on failure, so failed queries are timed as well.
            long sqlCost = System.currentTimeMillis() - startTime;
            LOGGER.info("MYSQL_SQL: [ {} ] 执行耗时:[{}ms]", beautifySql(boundSql.getSql()), sqlCost);
        }
    }

    /**
     * Wraps the target with this interceptor.
     *
     * @param target the object MyBatis offers for interception
     * @return the result of {@code Plugin.wrap(target, this)}
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * No configurable properties; intentionally empty.
     *
     * @param properties ignored
     */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Renders the given SQL node against the parameter object, escapes wildcard-only LIKE
     * parameters (see {@link #modifyLikeSql(String, Object)}) and builds a {@code BoundSql}.
     *
     * @param configuration   the MyBatis configuration of the statement
     * @param parameterObject the query parameter object; may be {@code null}
     * @param sqlNode         the root SQL node of the statement
     * @return the bound SQL, with all bindings of the rendering context copied in as
     *         additional parameters
     */
    public static BoundSql getBoundSql(Configuration configuration, Object parameterObject, SqlNode sqlNode) {
        DynamicContext context = new DynamicContext(configuration, parameterObject);
        sqlNode.apply(context);
        String contextSql = context.getSql();
        SqlSourceBuilder sqlSourceParser = new SqlSourceBuilder(configuration);
        Class<?> parameterType = parameterObject == null ? Object.class : parameterObject.getClass();
        String sql = modifyLikeSql(contextSql, parameterObject);
        SqlSource sqlSource = sqlSourceParser.parse(sql, parameterType, context.getBindings());
        BoundSql boundSql = sqlSource.getBoundSql(parameterObject);
        for (Map.Entry<String, Object> entry : context.getBindings().entrySet()) {
            boundSql.setAdditionalParameter(entry.getKey(), entry.getValue());
        }
        return boundSql;
    }

    /**
     * Escapes LIKE parameters whose value consists only of {@code %} and/or {@code _}.
     *
     * <p>The SQL text itself is never modified; the escaped value is written back into
     * {@code parameterObject}. Placeholders that cannot be read from (or written to) the
     * parameter object — for example when the parameter is {@code null} or a plain
     * {@code String} — are skipped instead of failing the query.
     *
     * @param sql             the rendered SQL still containing {@code #{...}} placeholders
     * @param parameterObject the query parameter object; may be {@code null}
     * @return {@code sql}, unchanged
     */
    public static String modifyLikeSql(String sql, Object parameterObject) {
        if (sql == null || parameterObject == null) {
            return sql;
        }
        // The case-insensitive pattern doubles as the "is this a LIKE query" check,
        // so no locale-sensitive toLowerCase() pre-check is needed.
        List<String> likeKeys = new ArrayList<>();
        Matcher matcher = LIKE_PARAMETER_PATTERN.matcher(sql);
        while (matcher.find()) {
            String key = getParameterKey(matcher.group());
            if (!key.isEmpty() && !likeKeys.contains(key)) {
                likeKeys.add(key);
            }
        }
        if (likeKeys.isEmpty()) {
            return sql;
        }
        MetaObject metaObject = MetaObject.forObject(parameterObject, DEFAULT_OBJECT_FACTORY,
                DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
        for (String key : likeKeys) {
            // Skip placeholders that do not map to an accessible property of the parameter object.
            if (!metaObject.hasGetter(key) || !metaObject.hasSetter(key)) {
                continue;
            }
            Object val = metaObject.getValue(key);
            // Only values made up entirely of % or _ are escaped.
            if (null != val && specialChar(val.toString())) {
                metaObject.setValue(key, escapeChar(val.toString()));
            }
        }
        return sql;
    }

    /**
     * Tells whether the value consists only of {@code %} and/or {@code _}.
     *
     * @param content the parameter value as text
     * @return {@code true} if non-empty and every character is {@code %} or {@code _}
     */
    private static boolean specialChar(String content) {
        if (StringUtils.isEmpty(content)) {
            return false;
        }
        return ONLY_WILDCARDS_PATTERN.matcher(content).matches();
    }

    /**
     * Prefixes every backslash, {@code _} and {@code %} with a backslash.
     *
     * @param val the value to escape
     * @return the escaped value; blank or {@code null} input is returned as is
     */
    private static String escapeChar(String val) {
        if (StringUtils.isNotBlank(val)) {
            // Backslashes first, otherwise the ones added for _ and % would be doubled.
            val = val.replace("\\", "\\\\");
            val = val.replace("_", "\\_");
            val = val.replace("%", "\\%");
        }
        return val;
    }

    /**
     * Extracts the property name from a {@code LIKE #{name,...}} fragment.
     *
     * @param input the matched fragment
     * @return the text between the braces up to the first comma, trimmed; empty when the
     *         fragment has nothing after {@code #}
     */
    private static String getParameterKey(String input) {
        String key = "";
        String[] temp = input.split("#");
        if (temp.length > 1) {
            key = temp[1];
            key = key.replace("{", "").replace("}", "").split(",")[0];
        }
        return key.trim();
    }

    /**
     * Collapses the SQL onto a single line for logging.
     *
     * @param sql the SQL text
     * @return the trimmed SQL with every whitespace run replaced by one space
     */
    private static String beautifySql(String sql) {
        return WHITESPACE_RUN_PATTERN.matcher(sql.trim()).replaceAll(" ");
    }
}

结果

3.png 4.png
上一篇下一篇

猜你喜欢

热点阅读