【解析SQL模板-1】将Mybatis的SQL模板组合成可运行的SQL

2024-09-02  本文已影响0人  小胖学编程

背景

实现平台化的 MyBatis 能力：用户在页面上输入 MyBatis 的 SQL 模板并传入参数，系统将其解析成可直接运行的 SQL。

实现原理

引入依赖:

<dependency>
    <groupId>org.mybatis</groupId>
    <artifactId>mybatis</artifactId>
    <version>3.5.7</version>
</dependency>

MyBatis 的 SQL 生成器分两步：

  1. 解析 MyBatis 模板，生成带 `?` 占位符的预编译 SQL；
  2. 解析预编译 SQL，用参数值替换其中的 `?` 占位符。
/**
 * Turns a user-supplied MyBatis SQL template plus a parameter map into a complete, directly
 * runnable SQL string.
 *
 * <p>Two steps:
 * <ol>
 *   <li>wrap the template in a throw-away mapper XML and let MyBatis render the dynamic SQL,
 *       which yields {@code ?} placeholders for every {@code #{}} expression;</li>
 *   <li>replace each {@code ?} with the literal text of its resolved parameter value.</li>
 * </ol>
 *
 * <p>{@code ${}} expressions are substituted verbatim by MyBatis and are NOT escaped here.
 */
@Slf4j
public class MybatisGenerator {

    private static final String HEAD = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
            + "<!DOCTYPE mapper PUBLIC \"-//mybatis.org//DTD Mapper 3.0//EN\" \"http://mybatis"
            + ".org/dtd/mybatis-3-mapper.dtd\">"
            + "<mapper namespace=\"customGenerator\">"
            + "<select id=\"selectData\" parameterType=\"map\" resultType=\"map\">\n";

    private static final String FOOT = "\n</select></mapper>";

    /**
     * Parsed templates keyed by their raw SQL text. Every template gets its own
     * {@link Configuration}, so all of them can share the fixed statement id "selectData".
     */
    private static final LoadingCache<String, MappedStatement> mappedStatementCache = CacheBuilder.newBuilder()
            .refreshAfterWrite(1, TimeUnit.DAYS)
            .build(new CacheLoader<String, MappedStatement>() {
                @Override
                public MappedStatement load(@NotNull String key) {
                    Configuration configuration = new Configuration();
                    // Collapses runs of whitespace (newlines included) into single spaces.
                    // NOTE(review): because newlines disappear, a "-- comment" inside the template
                    // comments out everything after it in the generated SQL.
                    configuration.setShrinkWhitespacesInSql(true);
                    String sourceSQL = HEAD + key + FOOT;
                    XMLMapperBuilder xmlMapperBuilder = new XMLMapperBuilder(
                            IOUtils.toInputStream(sourceSQL, java.nio.charset.StandardCharsets.UTF_8),
                            configuration, null, null);
                    xmlMapperBuilder.parse();
                    return xmlMapperBuilder.getConfiguration().getMappedStatement("selectData");
                }
            });

    /**
     * Renders the template held by {@code apiConfig} with {@code conditions} and returns the
     * final SQL, with every {@code #{}} placeholder replaced by its literal value.
     *
     * @param apiConfig  holds the MyBatis SQL template (the body of a {@code <select>} element)
     * @param conditions values referenced by the template; may be {@code null}
     * @return the complete SQL text
     * @throws com.google.common.util.concurrent.UncheckedExecutionException if the template
     *         cannot be parsed by MyBatis
     */
    public static String generateDsl(SQLConfig apiConfig, Map<String, Object> conditions) {
        MappedStatement mappedStatement = mappedStatementCache.getUnchecked(apiConfig.getSqlTemplate());
        BoundSql boundSql = mappedStatement.getBoundSql(conditions);
        List<ParameterMapping> mappings = boundSql.getParameterMappings();
        if (mappings.isEmpty()) {
            return boundSql.getSql();
        }
        List<PreparedStatementParameter> parameters = mappings.stream()
                .map(ParameterMapping::getProperty)
                .map(property -> resolveParameter(mappedStatement, boundSql, conditions, property))
                .map(PreparedStatementParameter::fromObject)
                .collect(Collectors.toList());
        // Inline the resolved values into the '?' placeholders to obtain the full SQL.
        return PreparedStatementParser.parse(boundSql.getSql()).buildSql(parameters);
    }

    /**
     * Resolves the value of one {@code #{property}} the same way MyBatis'
     * DefaultParameterHandler does.
     */
    private static Object resolveParameter(MappedStatement mappedStatement, BoundSql boundSql,
            Map<String, Object> conditions, String property) {
        if (boundSql.hasAdditionalParameter(property)) {
            // <foreach> items, <bind> variables and _parameter are stored as additional parameters.
            return boundSql.getAdditionalParameter(property);
        }
        if (conditions == null) {
            return null;
        }
        // MetaObject understands nested ("user.name") and indexed ("ids[0]") property paths.
        return mappedStatement.getConfiguration().newMetaObject(conditions).getValue(property);
    }

    @Data
    public static class SQLConfig {
        // The MyBatis SQL template.
        private String sqlTemplate;
    }
}

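替换 `?` 占位符时，需要根据参数值的类型决定是否加单引号以及如何转义，否则字符串中的引号、反斜杠等字符会破坏生成的 SQL。这部分逻辑由 ValueFormatter 负责。注意：只有 `#{}` 对应的参数会经过这里的转义，`${}` 会被 MyBatis 原样拼接进 SQL。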


/**
 * Renders Java values as SQL literal text for inlining into a statement.
 *
 * <p>{@link #formatObject} produces the (escaped) text of a value, and {@link #needsQuoting}
 * tells whether that text must additionally be wrapped in single quotes. {@code null} is
 * represented by {@link #NULL_MARKER}.
 */
public final class ValueFormatter {
    /** Backslash-escapes characters that would otherwise terminate or corrupt a quoted literal/identifier. */
    private static final Escaper ESCAPER = Escapers.builder()
            .addEscape('\\', "\\\\")
            .addEscape('\n', "\\n")
            .addEscape('\t', "\\t")
            .addEscape('\b', "\\b")
            .addEscape('\f', "\\f")
            .addEscape('\r', "\\r")
            .addEscape('\u0000', "\\0")
            .addEscape('\'', "\\'")
            .addEscape('`', "\\`")
            .build();

    /** Marker text standing for SQL NULL. */
    public static final String NULL_MARKER = "\\N";
    // SimpleDateFormat is not thread-safe, so each thread gets its own instance.
    private static final ThreadLocal<SimpleDateFormat> DATE_FORMAT =
            ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd"));
    private static final ThreadLocal<SimpleDateFormat> DATE_TIME_FORMAT =
            ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));

    /** Renders bytes as a sequence of {@code \xHH} escapes; {@code null} stays {@code null}. */
    public static String formatBytes(byte[] bytes) {
        if (bytes == null) {
            return null;
        } else {
            char[] hexArray =
                    new char[] {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
            char[] hexChars = new char[bytes.length * 4];

            for (int j = 0; j < bytes.length; ++j) {
                int v = bytes[j] & 255;
                hexChars[j * 4] = '\\';
                hexChars[j * 4 + 1] = 'x';
                hexChars[j * 4 + 2] = hexArray[v / 16];
                hexChars[j * 4 + 3] = hexArray[v % 16];
            }

            return new String(hexChars);
        }
    }

    public static String formatInt(int myInt) {
        return Integer.toString(myInt);
    }

    public static String formatDouble(double myDouble) {
        return Double.toString(myDouble);
    }

    public static String formatChar(char myChar) {
        return Character.toString(myChar);
    }

    public static String formatLong(long myLong) {
        return Long.toString(myLong);
    }

    public static String formatFloat(float myFloat) {
        return Float.toString(myFloat);
    }

    /** Plain (non-scientific) decimal text, or {@link #NULL_MARKER} for {@code null}. */
    public static String formatBigDecimal(BigDecimal myBigDecimal) {
        return myBigDecimal != null ? myBigDecimal.toPlainString() : NULL_MARKER;
    }

    public static String formatShort(short myShort) {
        return Short.toString(myShort);
    }

    /** Escaped string content (without surrounding quotes), or {@link #NULL_MARKER} for {@code null}. */
    public static String formatString(String myString) {
        return escape(myString);
    }

    public static String formatNull() {
        return NULL_MARKER;
    }

    public static String formatByte(byte myByte) {
        return Byte.toString(myByte);
    }

    /** Booleans are rendered numerically: {@code "1"} / {@code "0"}. */
    public static String formatBoolean(boolean myBoolean) {
        return myBoolean ? "1" : "0";
    }

    public static String formatUUID(UUID x) {
        return x.toString();
    }

    public static String formatBigInteger(BigInteger x) {
        return x.toString();
    }

    /** Formats a date as {@code yyyy-MM-dd} in the JVM default time zone. */
    public static String formatDate(java.util.Date date) {
        return date == null ? NULL_MARKER : getDateFormat().format(date);
    }

    /** Formats a date as {@code yyyy-MM-dd HH:mm:ss} in the JVM default time zone. */
    public static String formatDateTime(java.util.Date date) {
        return date == null ? NULL_MARKER : getDateTimeFormat().format(date);
    }

    /**
     * Returns the SQL text of {@code x} (without surrounding quotes); {@code null} for {@code null}.
     * Whether the text must be quoted is answered by {@link #needsQuoting(Object)}.
     */
    public static String formatObject(Object x) {
        if (x == null) {
            return null;
        } else if (x instanceof Byte) {
            return formatInt(((Byte) x).intValue());
        } else if (x instanceof String) {
            return formatString((String) x);
        } else if (x instanceof BigDecimal) {
            return formatBigDecimal((BigDecimal) x);
        } else if (x instanceof Short) {
            return formatShort((Short) x);
        } else if (x instanceof Integer) {
            return formatInt((Integer) x);
        } else if (x instanceof Long) {
            return formatLong((Long) x);
        } else if (x instanceof Float) {
            return formatFloat((Float) x);
        } else if (x instanceof Double) {
            return formatDouble((Double) x);
        } else if (x instanceof byte[]) {
            return formatBytes((byte[]) x);
        } else if (x instanceof Boolean) {
            return formatBoolean((Boolean) x);
        } else if (x instanceof UUID) {
            return formatUUID((UUID) x);
        } else if (x instanceof BigInteger) {
            return formatBigInteger((BigInteger) x);
        } else if (x instanceof java.sql.Date) {
            return formatDate((java.sql.Date) x);
        } else if (x instanceof java.sql.Timestamp || x instanceof java.sql.Time) {
            // These JDBC types already print in SQL literal form (Timestamp keeps fractional seconds).
            return escape(x.toString());
        } else if (x instanceof java.util.Date) {
            // java.util.Date#toString() ("Mon Sep 02 ... 2024") is not a usable SQL literal.
            return formatDateTime((java.util.Date) x);
        } else if (needsQuoting(x)) {
            // Character, enums, java.time values, ...: the text ends up inside '...', so escape it.
            return escape(String.valueOf(x));
        } else {
            return String.valueOf(x);
        }
    }

    /**
     * Tells whether the text of {@code o} must be wrapped in single quotes. Numbers, booleans,
     * collections and non-byte arrays are emitted bare.
     */
    public static boolean needsQuoting(Object o) {
        if (o == null) {
            return false;
        } else if (o instanceof Number) {
            return false;
        } else if (o instanceof Boolean) {
            return false;
        } else if (o instanceof byte[]) {
            // formatBytes emits \xHH escapes, which are only meaningful inside a string literal
            // (ClickHouse-style escaping; verify for the target database).
            return true;
        } else if (o.getClass().isArray()) {
            return false;
        } else {
            return !(o instanceof Collection);
        }
    }

    private static SimpleDateFormat getDateFormat() {
        return DATE_FORMAT.get();
    }

    private static SimpleDateFormat getDateTimeFormat() {
        return DATE_TIME_FORMAT.get();
    }

    /** Backslash-escapes {@code s}; {@code null} becomes {@link #NULL_MARKER}. */
    public static String escape(String s) {
        return s == null ? NULL_MARKER : ESCAPER.escape(s);
    }

    /**
     * Wraps {@code s} in back quotes after escaping it.
     *
     * @throws IllegalArgumentException if {@code s} is {@code null}
     */
    public static String quoteIdentifier(String s) {
        if (s == null) {
            throw new IllegalArgumentException("Can't quote null as identifier");
        } else {
            StringBuilder sb = new StringBuilder(s.length() + 2);
            sb.append('`');
            sb.append(ESCAPER.escape(s));
            sb.append('`');
            return sb.toString();
        }
    }
}

定义预编译语句的参数（保存参数值的文本形式，以及替换时是否需要加单引号）：

/**
 * One bound value of a prepared statement: its final SQL text plus a flag saying whether the
 * text has to be wrapped in single quotes when it is inlined into the statement.
 */
public final class PreparedStatementParameter {
    // Internal sentinel text for SQL NULL.
    private static final String NULL_TEXT = "\\N";

    private static final PreparedStatementParameter NULL_PARAM = new PreparedStatementParameter(null, false);
    private static final PreparedStatementParameter TRUE_PARAM = new PreparedStatementParameter("1", false);
    private static final PreparedStatementParameter FALSE_PARAM = new PreparedStatementParameter("0", false);

    private final String stringValue;
    private final boolean quoteNeeded;

    /**
     * @param stringValue already formatted/escaped text; {@code null} means SQL NULL
     * @param quoteNeeded whether the text must be wrapped in single quotes when inlined
     */
    public PreparedStatementParameter(String stringValue, boolean quoteNeeded) {
        if (stringValue == null) {
            this.stringValue = NULL_TEXT;
        } else {
            this.stringValue = stringValue;
        }
        this.quoteNeeded = quoteNeeded;
    }

    /** Builds a parameter from an arbitrary value, letting ValueFormatter pick text and quoting. */
    public static PreparedStatementParameter fromObject(Object x) {
        if (x == null) {
            return NULL_PARAM;
        }
        String text = ValueFormatter.formatObject(x);
        boolean quote = ValueFormatter.needsQuoting(x);
        return new PreparedStatementParameter(text, quote);
    }

    public static PreparedStatementParameter nullParameter() {
        return NULL_PARAM;
    }

    public static PreparedStatementParameter boolParameter(boolean value) {
        if (value) {
            return TRUE_PARAM;
        }
        return FALSE_PARAM;
    }

    /** The text to splice into SQL: {@code null}, the bare value, or the value in single quotes. */
    String getRegularValue() {
        if (NULL_TEXT.equals(stringValue)) {
            return "null";
        }
        if (!quoteNeeded) {
            return stringValue;
        }
        return "'" + stringValue + "'";
    }

    String getBatchValue() {
        return stringValue;
    }

    @Override
    public String toString() {
        return stringValue;
    }
}

预编译SQL解析器：识别 SQL 中的 `?` 占位符（跳过字符串、反引号标识符和注释中的 `?`），并把参数值依次替换进去。

/**
 * Splits SQL text at its {@code ?} placeholders and re-assembles it with bound values inlined.
 *
 * <p>The scanner ignores {@code ?} characters inside single-quoted strings, back-quoted
 * identifiers, {@code --} line comments and block comments; a backslash skips the next
 * character. NOTE(review): double-quoted strings are NOT recognised, so a {@code ?} inside
 * {@code "..."} is treated as a placeholder.
 *
 * <p>For {@code INSERT INTO ... VALUES (...)} statements ("values mode") every value inside the
 * VALUES tuples, literal or {@code ?}, is recorded as a parameter, one list per tuple.
 */
public class PreparedStatementParser {

    static final String PARAM_MARKER = "?";
    static final String NULL_MARKER = "\\N";

    private static final Pattern VALUES = Pattern.compile(
            "(?i)INSERT\\s+INTO\\s+.+VALUES\\s*\\(",
            Pattern.MULTILINE | Pattern.DOTALL);

    // One inner list per VALUES tuple in values mode; a single list of "?" markers otherwise.
    private List<List<String>> parameters;
    // SQL text between consecutive parameters; always one element longer than the flattened parameters.
    private List<String> parts;
    private boolean valuesMode;

    private PreparedStatementParser() {
        parameters = new ArrayList<>();
        parts = new ArrayList<>();
        valuesMode = false;
    }

    /**
     * Parses {@code sql}, detecting an INSERT ... VALUES statement automatically.
     *
     * @throws IllegalArgumentException if {@code sql} is blank
     */
    public static PreparedStatementParser parse(String sql) {
        return parse(sql, -1);
    }

    /**
     * Parses {@code sql}.
     *
     * @param valuesEndPosition if positive, forces values mode and starts scanning at this index
     *                          (expected to be the opening parenthesis of the VALUES list);
     *                          otherwise the VALUES regex decides
     * @throws IllegalArgumentException if {@code sql} is blank
     */
    public static PreparedStatementParser parse(String sql, int valuesEndPosition) {
        if (StringUtils.isBlank(sql)) {
            throw new IllegalArgumentException("SQL may not be blank");
        }
        PreparedStatementParser parser = new PreparedStatementParser();
        parser.parseSQL(sql, valuesEndPosition);
        return parser;
    }

    List<List<String>> getParameters() {
        return Collections.unmodifiableList(parameters);
    }

    List<String> getParts() {
        return Collections.unmodifiableList(parts);
    }

    boolean isValuesMode() {
        return valuesMode;
    }

    private void reset() {
        parameters.clear();
        parts.clear();
        valuesMode = false;
    }

    /**
     * Single left-to-right scan filling {@link #parts} and {@link #parameters}.
     * [idxStart, idxEnd) tracks the extent of the current VALUES token in values mode.
     */
    private void parseSQL(String sql, int valuesEndPosition) {
        reset();
        List<String> currentParamList = new ArrayList<String>();
        boolean afterBackSlash = false;
        boolean inQuotes = false;
        boolean inBackQuotes = false;
        boolean inSingleLineComment = false;
        boolean inMultiLineComment = false;
        boolean whiteSpace = false;
        int endPosition = 0;
        if (valuesEndPosition > 0) {
            valuesMode = true;
            endPosition = valuesEndPosition;
        } else {
            Matcher matcher = VALUES.matcher(sql);
            if (matcher.find()) {
                valuesMode = true;
                // Index of the '(' that opens the VALUES list.
                endPosition = matcher.end() - 1;
            }
        }

        int currentParensLevel = 0;
        int quotedStart = 0;
        int partStart = 0;
        int sqlLength = sql.length();
        for (int i = valuesMode ? endPosition : 0, idxStart = i, idxEnd = i; i < sqlLength; i++) {
            char c = sql.charAt(i);
            if (inSingleLineComment) {
                if (c == '\n') {
                    inSingleLineComment = false;
                }
            } else if (inMultiLineComment) {
                if (c == '*' && sqlLength > i + 1 && sql.charAt(i + 1) == '/') {
                    inMultiLineComment = false;
                    i++;
                }
            } else if (afterBackSlash) {
                // The escaped character is consumed without interpretation.
                afterBackSlash = false;
            } else if (c == '\\') {
                afterBackSlash = true;
            } else if (c == '\'' && !inBackQuotes) {
                inQuotes = !inQuotes;
                if (inQuotes) {
                    quotedStart = i;
                } else {
                    // Closing quote: the whole quoted literal becomes the current VALUES token.
                    // (afterBackSlash is always false here, the branch above consumed it.)
                    idxStart = quotedStart;
                    idxEnd = i + 1;
                }
            } else if (c == '`' && !inQuotes) {
                inBackQuotes = !inBackQuotes;
            } else if (!inQuotes && !inBackQuotes) {
                if (c == '?') {
                    if (currentParensLevel > 0) {
                        idxStart = i;
                        idxEnd = i + 1;
                    }
                    if (!valuesMode) {
                        parts.add(sql.substring(partStart, i));
                        partStart = i + 1;
                        currentParamList.add(PARAM_MARKER);
                    }
                } else if (c == '-' && sqlLength > i + 1 && sql.charAt(i + 1) == '-') {
                    inSingleLineComment = true;
                    i++;
                } else if (c == '/' && sqlLength > i + 1 && sql.charAt(i + 1) == '*') {
                    inMultiLineComment = true;
                    i++;
                } else if (c == ',') {
                    if (valuesMode && idxEnd > idxStart) {
                        currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
                        parts.add(sql.substring(partStart, idxStart));
                        partStart = idxEnd;
                        idxEnd = i;
                        idxStart = idxEnd;
                    }
                    idxStart++;
                    idxEnd++;
                } else if (c == '(') {
                    currentParensLevel++;
                    idxStart++;
                    idxEnd++;
                } else if (c == ')') {
                    currentParensLevel--;
                    if (valuesMode && currentParensLevel == 0) {
                        // End of a VALUES tuple: flush its last token and the tuple itself.
                        if (idxEnd > idxStart) {
                            currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
                            parts.add(sql.substring(partStart, idxStart));
                            partStart = idxEnd;
                            idxEnd = i;
                            idxStart = idxEnd;
                        }
                        if (!currentParamList.isEmpty()) {
                            parameters.add(currentParamList);
                            currentParamList = new ArrayList<>(currentParamList.size());
                        }
                    }
                } else if (Character.isWhitespace(c)) {
                    whiteSpace = true;
                } else if (currentParensLevel > 0) {
                    // Part of an unquoted token inside parentheses: start a new token after
                    // whitespace, otherwise extend the current one.
                    if (whiteSpace) {
                        idxStart = i;
                        idxEnd = i + 1;
                    } else {
                        idxEnd++;
                    }
                    whiteSpace = false;
                }
            }
        }
        if (!valuesMode && !currentParamList.isEmpty()) {
            parameters.add(currentParamList);
        }
        String lastPart = sql.substring(partStart, sqlLength);
        parts.add(lastPart);
    }

    /** Normalises a literal VALUES token: true/false -> 1/0, NULL -> {@link #NULL_MARKER}. */
    private static String typeTransformParameterValue(String paramValue) {
        if (paramValue == null) {
            return null;
        }
        if (Boolean.TRUE.toString().equalsIgnoreCase(paramValue)) {
            return "1";
        }
        if (Boolean.FALSE.toString().equalsIgnoreCase(paramValue)) {
            return "0";
        }
        if ("NULL".equalsIgnoreCase(paramValue)) {
            return NULL_MARKER;
        }
        return paramValue;
    }

    /**
     * Re-assembles the SQL, substituting {@code binds} (in order) for every {@code ?}
     * placeholder; literal VALUES tokens are emitted unchanged. Extra binds are ignored.
     *
     * @param binds values for the {@code ?} placeholders, in order of appearance
     * @return the SQL with all placeholders replaced
     * @throws IllegalArgumentException if fewer binds are supplied than there are placeholders
     */
    public String buildSql(List<PreparedStatementParameter> binds) {
        if (this.parts.size() == 1) {
            return this.parts.get(0);
        }
        List<String> flatParameters = flattenParameters();
        StringBuilder sb = new StringBuilder(this.parts.get(0));
        int bindIndex = 0;
        for (int i = 1; i < this.parts.size(); i++) {
            String pValue = i - 1 < flatParameters.size() ? flatParameters.get(i - 1) : null;
            if (PARAM_MARKER.equals(pValue)) {
                // A placeholder generated from #{}: the bound value arrives already escaped/quoted.
                int available = binds == null ? 0 : binds.size();
                if (bindIndex >= available) {
                    throw new IllegalArgumentException("SQL contains more '?' placeholders than the "
                            + available + " bound parameter(s) supplied");
                }
                sb.append(binds.get(bindIndex++).getRegularValue());
            } else {
                sb.append(pValue);
            }
            sb.append(this.parts.get(i));
        }
        return sb.toString();
    }

    /** All parameters of all tuples in order, matching the gaps between {@link #parts}. */
    private List<String> flattenParameters() {
        List<String> flat = new ArrayList<>();
        for (List<String> tuple : this.parameters) {
            flat.addAll(tuple);
        }
        return flat;
    }
}

文章参考

Mybatis interceptor 获取clickhouse最终执行的sql

【Mybatis】单独使用mybatis的SQL模板解析

上一篇下一篇

猜你喜欢

热点阅读