Spring Boot

Spring根据实体生成插入和更新SQL

2020-04-08  本文已影响0人  EasyNetCN

这种方法有以下限制:

1.只适合单表

2.主键为id

3.数据表名称和数据表字段名称使用LOWER_UNDERSCORE规则命名(表名由实体类名去掉"Entity"后缀后转换得到)

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.stream.Collectors;

import org.springframework.beans.PropertyAccessorFactory;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import com.google.common.base.CaseFormat;

/**
 * Static helpers that derive named-parameter INSERT and UPDATE statements, and
 * matching {@link SqlParameterSource} instances, from a bean.
 *
 * <p>The table name comes from the bean's simple class name (see
 * {@link #getTableName(Class)}); column names come from the bean's readable
 * property names converted from lowerCamel to lower_underscore. The UPDATE
 * statement always targets the row through {@code WHERE id=:id}.
 */
public class JdbcUtility {
    private JdbcUtility() {

    }

    /**
     * Wraps the given bean in a {@link BeanPropertyExtSqlParameterSource}.
     *
     * @param object the bean whose properties supply parameter values
     * @return a parameter source backed by {@code object}
     */
    public static SqlParameterSource getSqlParameterSource(Object object) {
        return new BeanPropertyExtSqlParameterSource(object);
    }

    /**
     * Wraps the given bean in a {@link BeanPropertyExtSqlParameterSource},
     * forwarding the property names to ignore.
     *
     * @param object           the bean whose properties supply parameter values
     * @param ignoreProperties property names handed to the parameter source as ignored
     * @return a parameter source backed by {@code object}
     */
    public static SqlParameterSource getSqlParameterSource(Object object, @Nullable String... ignoreProperties) {
        return new BeanPropertyExtSqlParameterSource(object, ignoreProperties);
    }

    /**
     * Builds {@code INSERT INTO table (col,...) VALUES (:prop,...)} for the bean.
     *
     * <p>Every readable property except {@code class} and the ignored ones
     * contributes one column and one named parameter.
     *
     * @param object           the bean describing the row
     * @param ignoreProperties property names to leave out; may be {@code null},
     *                         and null or empty entries are skipped
     * @return the INSERT statement using named parameters
     */
    public static String getInsertSql(Object object, @Nullable String... ignoreProperties) {
        // Resolve the table name first so a null bean fails before any introspection.
        var tableName = getTableName(object.getClass());
        var names = collectPropertyNames(object, false, ignoreProperties);

        var columns = names.stream().map(JdbcUtility::toColumnName).collect(Collectors.joining(","));
        var parameters = names.stream().map(name -> ":" + name).collect(Collectors.joining(","));

        return "INSERT INTO " + tableName + " (" + columns + ") VALUES (" + parameters + ")";
    }

    /**
     * Builds {@code UPDATE table SET col=:prop,... WHERE id=:id} for the bean.
     *
     * <p>Every readable property except {@code class}, {@code id} and the
     * ignored ones contributes one assignment. When no property remains, the
     * SET clause is left empty and the returned text is not a valid statement.
     *
     * @param object           the bean describing the row
     * @param ignoreProperties property names to leave out; may be {@code null},
     *                         and null or empty entries are skipped
     * @return the UPDATE statement using named parameters
     */
    public static String getUpdateSql(Object object, @Nullable String... ignoreProperties) {
        // Resolve the table name first so a null bean fails before any introspection.
        var tableName = getTableName(object.getClass());
        var names = collectPropertyNames(object, true, ignoreProperties);

        var assignments = names.stream().map(name -> toColumnName(name) + "=:" + name)
                .collect(Collectors.joining(","));

        return "UPDATE " + tableName + " SET " + assignments + " WHERE id=:id";
    }

    /**
     * Derives the table name from a class: a trailing {@code Entity} suffix is
     * removed from the simple name, and the remainder is converted from
     * UpperCamel to lower_underscore.
     *
     * @param cls the entity class
     * @return the lower_underscore table name
     */
    public static String getTableName(Class<?> cls) {
        var className = cls.getSimpleName();

        return CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE,
                className.endsWith("Entity") ? className.substring(0, className.length() - "Entity".length())
                        : className);
    }

    /** Converts a lowerCamel property name to its lower_underscore column name. */
    private static String toColumnName(String propertyName) {
        return CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, propertyName);
    }

    /**
     * Lists, in descriptor order, the readable property names of the bean that
     * take part in a statement: {@code class} is always skipped, {@code id} is
     * skipped when {@code excludeId} is set, and ignored names are skipped.
     */
    private static ArrayList<String> collectPropertyNames(Object object, boolean excludeId,
            @Nullable String[] ignoreProperties) {
        var beanWrapper = PropertyAccessorFactory.forBeanPropertyAccess(object);

        var ignoreSet = new HashSet<String>();

        if (ignoreProperties != null) {
            for (var ignored : ignoreProperties) {
                // Null and empty entries carry no property name, so drop them.
                if (StringUtils.hasLength(ignored)) {
                    ignoreSet.add(ignored);
                }
            }
        }

        var names = new ArrayList<String>();

        for (var pd : beanWrapper.getPropertyDescriptors()) {
            var name = pd.getName();

            if ("class".equalsIgnoreCase(name) || (excludeId && "id".equalsIgnoreCase(name))) {
                continue;
            }

            if (beanWrapper.isReadableProperty(name) && !ignoreSet.contains(name)) {
                names.add(name);
            }
        }

        return names;
    }
}
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

import org.springframework.beans.BeanWrapper;
import org.springframework.beans.NotReadablePropertyException;
import org.springframework.beans.PropertyAccessorFactory;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.jdbc.core.namedparam.AbstractSqlParameterSource;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

/**
 * {@link AbstractSqlParameterSource} that reads parameter values from the
 * properties of a wrapped bean and can leave selected properties out of the
 * reported parameter names.
 *
 * <p>The ignore list only affects {@link #getReadablePropertyNames()} and
 * {@link #getParameterNames()}; {@link #hasValue(String)} and
 * {@link #getValue(String)} do not consult it.
 */
public class BeanPropertyExtSqlParameterSource extends AbstractSqlParameterSource {
    private final Set<String> ignoreProperties = new HashSet<>();

    private final BeanWrapper beanWrapper;

    /** Lazily computed by {@link #getReadablePropertyNames()}; {@code null} until first use. */
    @Nullable
    private String[] propertyNames;

    /**
     * Creates a parameter source over the given bean with no ignored properties.
     *
     * @param object the bean to wrap
     */
    public BeanPropertyExtSqlParameterSource(Object object) {
        this.beanWrapper = PropertyAccessorFactory.forBeanPropertyAccess(object);
    }

    /**
     * Creates a parameter source over the given bean, ignoring the named properties.
     *
     * @param object           the bean to wrap
     * @param ignoreProperties property names to exclude from the parameter names;
     *                         may be {@code null}, and null or empty entries are skipped
     */
    public BeanPropertyExtSqlParameterSource(Object object, @Nullable String... ignoreProperties) {
        this.beanWrapper = PropertyAccessorFactory.forBeanPropertyAccess(object);

        // The array is declared nullable, so guard before streaming over it.
        if (ignoreProperties != null) {
            var ignoreSet = Arrays.stream(ignoreProperties).filter(str -> StringUtils.hasLength(str))
                    .collect(Collectors.toSet());

            this.ignoreProperties.addAll(ignoreSet);
        }
    }

    /**
     * Reports whether the wrapped bean has a readable property of the given name.
     *
     * @param paramName the parameter (property) name
     * @return {@code true} if the property is readable on the wrapped bean
     */
    @Override
    public boolean hasValue(String paramName) {
        return this.beanWrapper.isReadableProperty(paramName);
    }

    /**
     * Reads the named property from the wrapped bean.
     *
     * @param paramName the parameter (property) name
     * @return the property value, possibly {@code null}
     * @throws IllegalArgumentException if the property is not readable; the
     *                                  original exception is attached as the cause
     */
    @Override
    @Nullable
    public Object getValue(String paramName) {
        try {
            return this.beanWrapper.getPropertyValue(paramName);
        } catch (NotReadablePropertyException ex) {
            // Keep the cause so the failing property path stays visible in stack traces.
            throw new IllegalArgumentException(ex.getMessage(), ex);
        }
    }

    /**
     * Returns the SQL type registered for the parameter, or, when none is
     * registered, the type derived from the property's Java type.
     *
     * @param paramName the parameter (property) name
     * @return the SQL type constant for the parameter
     */
    @Override
    public int getSqlType(String paramName) {
        var sqlType = super.getSqlType(paramName);

        if (sqlType != TYPE_UNKNOWN) {
            return sqlType;
        }

        var propType = this.beanWrapper.getPropertyType(paramName);

        return StatementCreatorUtils.javaTypeToSqlParameterType(propType);
    }

    /**
     * Returns the same names as {@link #getReadablePropertyNames()}.
     *
     * @return the readable, non-ignored property names
     */
    @Override
    @NonNull
    public String[] getParameterNames() {
        return getReadablePropertyNames();
    }

    /**
     * Lists the names of all readable properties of the wrapped bean that are
     * not in the ignore list. The result is computed on first call and the
     * same array instance is returned afterwards.
     *
     * @return the readable, non-ignored property names
     */
    public String[] getReadablePropertyNames() {
        if (this.propertyNames == null) {
            var names = new ArrayList<String>();
            var props = this.beanWrapper.getPropertyDescriptors();

            for (var pd : props) {
                if (this.beanWrapper.isReadableProperty(pd.getName()) && !ignoreProperties.contains(pd.getName())) {
                    names.add(pd.getName());
                }
            }

            this.propertyNames = StringUtils.toStringArray(names);
        }

        return this.propertyNames;
    }

}
上一篇 下一篇

猜你喜欢

热点阅读