SpringBoot 自动建表

2019-01-07  本文已影响46人  东方不喵

使用 MyBatis 或者 JdbcTemplate 的时候,框架并不会自动创建数据库表,需要额外花时间手动建表。为了省去这一步骤,可以编写一个简易的自动建表模板。
GitHub https://github.com/oldguys/MultipleDataSourceDemo

Step1: Maven引用依赖。使用jpa规范进行建表。

<!-- 使用java规范进行设置  -->
<dependency>
    <groupId>javax.persistence</groupId>
    <artifactId>persistence-api</artifactId>
    <version>1.0</version>
</dependency>

Step2: 编写抽象模板

package com.oldguy.example.modules.common.services;

import com.oldguy.example.modules.common.dao.entities.SqlTableObject;

import java.util.List;
import java.util.Map;

/**
 * @Description: Abstraction over a database dialect; one implementation per dialect.
 * @Author: ren
 * @CreateTime: 2018-10-2018/10/23 0023 13:53
 */
public interface TableFactory {

    /**
     * @return the SQL statement that lists the tables already present in the database
     */
    String showTableSQL();

    /**
     * @return the name of the dialect this factory targets
     */
    String getDialect();

    /**
     * @return mapping from a Java field type to the column type used for it in this dialect
     */
    Map<Class, String> getColumnType();

    /**
     * Converts table descriptions into SQL for this dialect.
     *
     * @param sqlTableObjects the table descriptions to convert
     * @return a map of generated SQL; NOTE(review): the MySQL implementation keys it by
     *         table name with the CREATE TABLE statement as value — confirm for other dialects
     */
    Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects);
}

Step3: 编写实现类,来生成不同方言的Schema

package com.oldguy.example.modules.common.services.impls;

import com.oldguy.example.modules.common.dao.entities.SqlTableObject;
import com.oldguy.example.modules.common.services.TableFactory;

import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @Description: MySQL dialect implementation of {@link TableFactory}; renders
 * {@code CREATE TABLE IF NOT EXISTS} statements from table descriptions.
 * @author ren
 * @date 2018/12/20
 */
public class MySQLTableFactory implements TableFactory {

    /** Column type name that needs an explicit length in the generated DDL. */
    private static final String VARCHAR = "VARCHAR";

    /** Length used for VARCHAR columns that do not declare one. */
    private static final int DEFAULT_VARCHAR_LENGTH = 255;

    /** Java field type -> MySQL column type. Shared by all instances. */
    private static final Map<Class, String> columnType;

    static {
        columnType = new HashMap<>();
        columnType.put(Integer.class, "INT");
        columnType.put(Long.class, "BIGINT");
        columnType.put(String.class, VARCHAR);
        columnType.put(Date.class, "DATETIME");
        columnType.put(Boolean.class, "TINYINT");
        columnType.put(Double.class, "DOUBLE");
    }

    /**
     * @return the MySQL statement that lists the existing tables
     */
    @Override
    public String showTableSQL() {
        return "show tables";
    }

    /**
     * @return the dialect name, {@code "MySQL"}
     */
    @Override
    public String getDialect() {
        return "MySQL";
    }

    /**
     * @return the shared Java-type to column-type map (the live map, not a copy)
     */
    @Override
    public Map<Class, String> getColumnType() {
        return columnType;
    }

    /**
     * Builds one CREATE TABLE statement per table description.
     *
     * @param sqlTableObjects the table descriptions to render
     * @return map from table name to its CREATE TABLE statement
     * @throws RuntimeException if two descriptions share the same table name
     */
    @Override
    public Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects) {

        Map<String, String> tableMap = new HashMap<>(sqlTableObjects.size());

        for (SqlTableObject obj : sqlTableObjects) {
            String tableName = obj.getTableName();
            // Reject duplicates before spending time on rendering the statement.
            if (tableMap.containsKey(tableName)) {
                throw new RuntimeException(tableName + " 表名重复。");
            }
            tableMap.put(tableName, buildCreateTableSql(obj));
        }

        return tableMap;
    }

    /**
     * Renders the CREATE TABLE statement for a single table description.
     */
    private static String buildCreateTableSql(SqlTableObject obj) {
        StringBuilder builder = new StringBuilder();
        builder.append("CREATE TABLE IF NOT EXISTS `").append(obj.getTableName()).append("` (").append("\n");

        int columnCount = obj.getColumns().size();
        for (int i = 0; i < columnCount; i++) {

            SqlTableObject.Column column = obj.getColumns().get(i);
            builder.append("`").append(column.getName()).append("` ");

            // Compare case-insensitively: the non-VARCHAR branch upper-cases the type, so a
            // lower-case "varchar" would otherwise be emitted as VARCHAR without a length.
            if (VARCHAR.equalsIgnoreCase(column.getType())) {
                builder.append(VARCHAR).append("(");
                if (column.getLength() == null) {
                    builder.append(DEFAULT_VARCHAR_LENGTH);
                } else {
                    builder.append(column.getLength());
                }
                builder.append(")");
            } else {
                builder.append(column.getType().toUpperCase());
            }

            if (column.isPrimaryKey()) {
                builder.append(" PRIMARY KEY");
                // AUTO_INCREMENT is only emitted for primary-key columns.
                if (column.isAutoIncrement()) {
                    builder.append(" AUTO_INCREMENT");
                }
            }

            if (!column.isNullable()) {
                builder.append(" NOT NULL");
            }

            if (column.isUnique()) {
                builder.append(" UNIQUE");
            }

            // Separate column definitions with commas; no comma after the last one.
            if (i < columnCount - 1) {
                builder.append(",");
            }

            builder.append("\n");
        }

        builder.append(") ENGINE=InnoDB DEFAULT CHARSET=utf8 ;").append("\n\n");
        return builder.toString();
    }
}

Step4: 编写扫描注册器

package com.oldguy.example.modules.common.services;



import com.oldguy.example.modules.common.annotation.AssociateEntity;
import com.oldguy.example.modules.common.annotation.Entity;
import com.oldguy.example.modules.common.dao.entities.SqlTableObject;
import com.oldguy.example.modules.common.services.impls.MySQLTableFactory;
import com.oldguy.example.modules.common.utils.ClassUtils;
import org.springframework.util.StringUtils;

import javax.persistence.Column;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import java.lang.reflect.Field;
import java.util.*;

/**
 * @Description: Table registrar. Scans packages for classes annotated with {@code Entity} or
 * {@code AssociateEntity}, describes each as a {@link SqlTableObject} and asks the configured
 * {@link TableFactory} to turn the descriptions into SQL.
 * @Author: ren
 * @CreateTime: 2018-10-2018/10/16 0016 10:33
 */
public class DbRegister {

    private TableFactory tableFactory;

    /**
     * Creates a registrar that uses {@link MySQLTableFactory}.
     */
    public DbRegister() {
        this.tableFactory = new MySQLTableFactory();
    }

    /**
     * @param tableFactory the dialect implementation used to generate SQL
     */
    public DbRegister(TableFactory tableFactory) {
        this.tableFactory = tableFactory;
    }


    /**
     * Scans the given packages and generates the SQL for every annotated entity class.
     *
     * @param packageNames packages to scan; {@code null} or empty yields an empty map
     * @return the map produced by {@link TableFactory#trainToDBSchema(List)}
     * @throws RuntimeException if no table factory is configured
     */
    public Map<String, String> registerClassToDB(String... packageNames) {

        if (packageNames == null || packageNames.length == 0) {
            return Collections.emptyMap();
        }
        // Checked up front: column mapping below already needs the factory.
        requireTableFactory();

        List<SqlTableObject> sqlTableObjects = new ArrayList<>();
        for (String packageName : packageNames) {
            for (Class<?> clazz : ClassUtils.getClasses(packageName)) {
                if (!clazz.isAnnotationPresent(Entity.class) && !clazz.isAnnotationPresent(AssociateEntity.class)) {
                    continue;
                }

                SqlTableObject obj = new SqlTableObject();
                obj.setTableName(resolveTableName(clazz));

                // Columns come from the class itself and all of its superclasses.
                List<Field> fields = new ArrayList<>();
                getAllField(clazz, fields);
                setTableColumns(obj, fields);
                sqlTableObjects.add(obj);
            }
        }

        // Convert the descriptions into SQL.
        return trainToDBSchema(sqlTableObjects);
    }

    /**
     * Determines the table name of an annotated class: the annotation's explicit name if it
     * has one, otherwise the annotation's prefix followed by the converted simple class name.
     * {@code Entity} takes precedence when both annotations are present.
     */
    private static String resolveTableName(Class<?> clazz) {
        String tableName = "";
        String preIndex = "";

        Entity entity = clazz.getAnnotation(Entity.class);
        if (entity != null) {
            tableName = entity.name();
            preIndex = entity.pre();
        } else {
            AssociateEntity associate = clazz.getAnnotation(AssociateEntity.class);
            if (associate != null) {
                tableName = associate.name();
                preIndex = associate.pre();
            }
        }

        return StringUtils.isEmpty(tableName) ? preIndex + formatTableName(clazz.getSimpleName()) : tableName;
    }

    /**
     * Fails with the established error message when no table factory is set.
     */
    private void requireTableFactory() {
        if (null == tableFactory) {
            throw new RuntimeException("TableFactory 不能为空!");
        }
    }

    /**
     * Delegates the conversion to the configured table factory.
     *
     * @param sqlTableObjects the table descriptions to convert
     * @return the map returned by the table factory
     * @throws RuntimeException if no table factory is configured
     */
    private Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects) {
        requireTableFactory();
        return tableFactory.trainToDBSchema(sqlTableObjects);
    }

    /**
     * Builds the column list of a table from reflected fields. Fields whose type is not
     * present in the factory's column-type map are skipped.
     *
     * @param obj    the table description that receives the columns
     * @param fields the candidate fields
     */
    private void setTableColumns(SqlTableObject obj, List<Field> fields) {

        Map<Class, String> typeMapping = tableFactory.getColumnType();
        List<SqlTableObject.Column> columnList = new ArrayList<>();

        for (Field field : fields) {
            if (!typeMapping.containsKey(field.getType())) {
                continue;
            }

            SqlTableObject.Column column = new SqlTableObject.Column();

            if (field.isAnnotationPresent(Id.class)) {
                column.setPrimaryKey(true);
                GeneratedValue generated = field.getAnnotation(GeneratedValue.class);
                if (generated != null) {
                    GenerationType strategy = generated.strategy();
                    // IDENTITY explicitly requests a database identity column, so it must
                    // be auto-incremented just like AUTO.
                    if (strategy == GenerationType.AUTO || strategy == GenerationType.IDENTITY) {
                        column.setAutoIncrement(true);
                    }
                }
            }

            Column annotation = field.getAnnotation(Column.class);
            if (annotation != null) {
                // Explicit settings from @Column win over the derived defaults.
                if (!StringUtils.isEmpty(annotation.name())) {
                    column.setName(annotation.name());
                } else {
                    column.setName(formatTableName(field.getName()));
                }

                if (!StringUtils.isEmpty(annotation.columnDefinition())) {
                    column.setType(annotation.columnDefinition());
                } else {
                    column.setType(typeMapping.get(field.getType()));
                }

                column.setLength(annotation.length());
                column.setUnique(annotation.unique());
                column.setNullable(annotation.nullable());
            } else {
                column.setName(formatTableName(field.getName()));
                column.setType(typeMapping.get(field.getType()));
            }
            columnList.add(column);
        }
        obj.setColumns(columnList);
    }


    /**
     * Collects the declared fields of a class and of every superclass below {@code Object}.
     *
     * @param clazz  the class to start from
     * @param fields the list the fields are appended to
     */
    private static void getAllField(Class clazz, List<Field> fields) {
        // getSuperclass() is null for interfaces and for Object itself, so the walk stops
        // on null instead of dereferencing it.
        Class<?> current = clazz;
        while (current != null && !Object.class.equals(current)) {
            fields.addAll(Arrays.asList(current.getDeclaredFields()));
            current = current.getSuperclass();
        }
    }

    /**
     * Converts a camelCase name to lower-case words joined by underscores.
     * <p>
     * The first character is lower-cased. Every later character that equals its own
     * upper-case form is lower-cased and prefixed with {@code '_'}; this includes upper-case
     * letters but also characters without case such as digits ({@code "user1"} becomes
     * {@code "user_1"}).
     *
     * @param name the name to convert
     * @return the converted name
     */
    public static String formatTableName(String name) {
        StringBuilder formatResult = new StringBuilder(name.length() + 8);

        for (int i = 0; i < name.length(); i++) {
            char current = name.charAt(i);
            if (i == 0) {
                formatResult.append(Character.toLowerCase(current));
            } else if (Character.toUpperCase(current) == current) {
                // Per-character conversion is locale independent and never changes length.
                formatResult.append('_').append(Character.toLowerCase(current));
            } else {
                formatResult.append(current);
            }
        }

        return formatResult.toString();
    }


    /**
     * @param tableFactory the dialect implementation to use from now on
     */
    public void setTableFactory(TableFactory tableFactory) {
        this.tableFactory = tableFactory;
    }

    /**
     * @return the dialect implementation currently in use
     */
    public TableFactory getTableFactory() {
        return tableFactory;
    }

}

Step5: 编写Configuration注册类

package com.oldguy.example.configs;


import com.oldguy.example.modules.common.services.DbRegister;
import com.oldguy.example.modules.common.utils.Log4jUtils;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.StringUtils;

import javax.annotation.Resource;
import javax.sql.DataSource;
import java.util.*;

/**
 * @Description: Creates missing database tables for the entity classes found in the
 * configured packages.
 * @author ren
 * @date 2018/12/20
 */
public class DbRegisterConfiguration {

    /**
     * Separators accepted between package names: semicolon, comma and whitespace.
     */
    private static final String PACKAGE_DELIMITERS = "[,;\\s]+";

    /**
     * Initializes the database: generates the SQL for every entity in the given packages,
     * lists the tables that already exist and executes the SQL of each missing table.
     *
     * @param jdbcTemplate       template bound to the target database
     * @param typeAliasesPackage one or more package names, separated by {@code ;}, {@code ,}
     *                           or whitespace; {@code null} or blank means nothing to do
     */
    public void initDB(JdbcTemplate jdbcTemplate, String typeAliasesPackage) {

        DbRegister dbRegister = new DbRegister();
        Map<String, String> tableMap = new HashMap<>();
        for (String path : splitPackagesPath(typeAliasesPackage)) {
            tableMap.putAll(dbRegister.registerClassToDB(path));
        }

        if (tableMap.isEmpty()) {
            return;
        }

        // Every value of every returned row is treated as an existing table name.
        Set<String> tableNameSet = new HashSet<>();
        List<Map<String, Object>> rows = jdbcTemplate.queryForList(dbRegister.getTableFactory().showTableSQL());
        for (Map<String, Object> row : rows) {
            for (Object value : row.values()) {
                if (value != null) {
                    tableNameSet.add(value.toString());
                }
            }
        }

        for (Map.Entry<String, String> entry : tableMap.entrySet()) {
            String key = entry.getKey();
            if (tableNameSet.contains(key)) {
                info("表[" + key + "] 已存在");
                continue;
            }

            info("未找到表[" + key + "],进行创建.");
            String sql = entry.getValue();
            if (sql.trim().length() > 0) {
                jdbcTemplate.execute(sql);
                info("\n\n" + sql);
            }
        }
    }

    /**
     * Writes an info-level message through the project logger of this class.
     */
    private void info(String message) {
        Log4jUtils.getInstance(getClass()).info(message);
    }

    /**
     * Splits a package list into its non-empty package names.
     *
     * @param typeAliasesPackage the raw package list, may be {@code null}
     * @return the package names in their original order; never {@code null}
     */
    private List<String> splitPackagesPath(String typeAliasesPackage) {
        List<String> paths = new ArrayList<>();
        if (typeAliasesPackage == null) {
            return paths;
        }
        for (String path : typeAliasesPackage.split(PACKAGE_DELIMITERS)) {
            // A leading separator produces an empty first token; drop it.
            if (!StringUtils.isEmpty(path)) {
                paths.add(path);
            }
        }
        return paths;
    }


}

以上模板基本构建完成。

Step6: 开始配置;进行数据库注册,可以注册多个数据源的数据库

  1. 注册类
package com.oldguy.example.configs;

import com.oldguy.example.modules.common.utils.Log4jUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.io.IOException;

/**
 * Configuration class that creates the missing tables of both data sources once its
 * dependencies have been injected.
 *
 * @author ren
 * @date 2018/12/20
 */
@Configuration
public class DemoConfiguration {

    /** Data source of the first database. */
    @Resource(name = "test1DataSource")
    private DataSource test1dataSource;

    /** Entity package(s) belonging to the first database. */
    @Value("${test1.mybatis.type-aliases-package}")
    private String test1TypeAliasesPackage;

    /** Data source of the second database. */
    @Resource(name = "test2DataSource")
    private DataSource test2dataSource;

    /** Entity package(s) belonging to the second database. */
    @Value("${test2.mybatis.type-aliases-package}")
    private String test2TypeAliasesPackage;

    /**
     * Runs the table initialization for the first and then the second data source.
     *
     * @throws IOException declared by the original signature
     */
    @PostConstruct
    public void initData() throws IOException {

        Log4jUtils.getInstance(getClass()).info("初始化 数据库 -------------------------------");
        DbRegisterConfiguration tableInitializer = new DbRegisterConfiguration();

        // Generate the table structure for test1.
        tableInitializer.initDB(new JdbcTemplate(test1dataSource), test1TypeAliasesPackage);

        // Generate the table structure for test2.
        tableInitializer.initDB(new JdbcTemplate(test2dataSource), test2TypeAliasesPackage);
    }
}

2.注册数据源

package com.oldguy.example.configs;

import com.alibaba.druid.pool.DruidDataSource;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ResourceLoader;


/**
 * Declares the two Druid data sources used by the demo as named beans.
 *
 * @author ren
 * @date 2018/12/20
 */
@Configuration
public class Test1DataSourceConfiguration extends AbstractMybatisConfiguration {

    /**
     * Bean {@code test1DataSource}; annotated to take its settings from the
     * {@code test1.datasource} property prefix.
     *
     * @return a new, otherwise unconfigured {@link DruidDataSource}
     */
    @Bean(name = "test1DataSource")
    @ConfigurationProperties(prefix = "test1.datasource")
    public DruidDataSource test1DataSource() {
        return new DruidDataSource();
    }

    /**
     * Bean {@code test2DataSource}; annotated to take its settings from the
     * {@code test2.datasource} property prefix.
     *
     * @return a new, otherwise unconfigured {@link DruidDataSource}
     */
    @Bean(name = "test2DataSource")
    @ConfigurationProperties(prefix = "test2.datasource")
    public DruidDataSource test2DataSource() {
        return new DruidDataSource();
    }

}

3.yaml配置文件

# Settings of the first database.
test1:
  # Matches the "test1.datasource" prefix declared on the test1DataSource bean.
  datasource:
    username: root
    password: root
    url: jdbc:mysql://127.0.0.1:3306/multiple_datasource1?useUnicode=true&characterEncoding=utf8&useSSL=false&allowMultiQueries=true
    driver-class-name: com.mysql.jdbc.Driver
    type: com.alibaba.druid.pool.DruidDataSource
  mybatis:
    # NOTE(review): mapper-locations and config-location are not read by the code shown here;
    # presumably consumed by AbstractMybatisConfiguration - confirm.
    mapper-locations: classpath:mappers/test1/*.xml
    # Injected into DemoConfiguration via @Value; several packages may be separated with ';'.
    type-aliases-package: com.oldguy.example.modules.test1.dao.entities;
    config-location: classpath:configs/myBatis-config.xml

# Settings of the second database; same layout as test1.
test2:
  # Matches the "test2.datasource" prefix declared on the test2DataSource bean.
  datasource:
    username: root
    password: root
    url: jdbc:mysql://127.0.0.1:3306/multiple_datasource2?useUnicode=true&characterEncoding=utf8&useSSL=false&allowMultiQueries=true
    driver-class-name: com.mysql.jdbc.Driver
    type: com.alibaba.druid.pool.DruidDataSource
  mybatis:
    mapper-locations: classpath:mappers/test2/*.xml
    # Injected into DemoConfiguration via @Value; several packages may be separated with ';'.
    type-aliases-package: com.oldguy.example.modules.test2.dao.entities;
    config-location: classpath:configs/myBatis-config.xml

这样就完成了自动建表的全部配置。

引用: 包扫描类

package com.oldguy.example.modules.common.utils;/**
 * Created by Administrator on 2018/10/16 0016.
 */

import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * Classpath scanning helpers: find the classes of a package in class directories and jars.
 *
 * @author ren
 * @date 2018/12/20
 */
public class ClassUtils {

    /** File-name suffix of compiled classes. */
    private static final String CLASS_SUFFIX = ".class";

    /**
     * Returns all classes of the given package and its sub-packages.
     *
     * @param pkg the package to scan
     * @return the classes found; empty if none
     */
    public static List<Class<?>> getAllClassByPackageName(Package pkg) {
        return getClasses(pkg.getName());
    }

    /**
     * Returns the classes that are assignable to the given interface and live in the
     * interface's package or one of its sub-packages. The interface itself is excluded.
     *
     * @param c the interface to look up implementations for
     * @return the matching classes, or {@code null} if {@code c} is not an interface
     */
    public static List<Class<?>> getAllClassByInterface(Class<?> c) {
        if (!c.isInterface()) {
            // Kept as null (not an empty list) so existing null checks in callers still work.
            return null;
        }

        List<Class<?>> returnClassList = new ArrayList<Class<?>>();
        for (Class<?> cls : getClasses(c.getPackage().getName())) {
            if (c.isAssignableFrom(cls) && !c.equals(cls)) {
                returnClassList.add(cls);
            }
        }
        return returnClassList;
    }

    /**
     * Lists the entries of the directory that corresponds to a package below a class
     * location; sub-directories are not descended into.
     *
     * @param classLocation root directory of the classes
     * @param packageName   dotted package name
     * @return the names of the directory entries, or {@code null} if the directory does not exist
     */
    public static String[] getPackageAllClassName(String classLocation, String packageName) {
        StringBuilder realClassLocation = new StringBuilder(classLocation);
        for (String segment : packageName.split("[.]")) {
            realClassLocation.append(File.separator).append(segment);
        }
        File packageDir = new File(realClassLocation.toString());
        if (packageDir.isDirectory()) {
            return packageDir.list();
        }
        return null;
    }

    /**
     * Finds all classes of a package and its sub-packages, using the resources that the
     * thread context class loader reports for the package. Class directories ({@code file:}
     * URLs) and jars ({@code jar:} URLs) are scanned; other protocols are ignored. Classes
     * that cannot be loaded are skipped after printing the stack trace.
     *
     * @param packageName dotted package name; the empty string denotes the default package
     * @return the classes found; empty if none
     */
    public static List<Class<?>> getClasses(String packageName) {

        List<Class<?>> classes = new ArrayList<>();
        // Always descend into sub-packages.
        boolean recursive = true;
        String packageDirName = packageName.replace('.', '/');
        try {
            Enumeration<URL> dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            while (dirs.hasMoreElements()) {
                URL url = dirs.nextElement();
                String protocol = url.getProtocol();
                if ("file".equals(protocol)) {
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                } else if ("jar".equals(protocol)) {
                    findAndAddClassesInJar(url, packageDirName, classes);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        return classes;
    }

    /**
     * Adds every class of the jar behind {@code url} whose entry lies below the package
     * directory.
     *
     * @param url            a {@code jar:} URL
     * @param packageDirName package name with {@code '/'} separators
     * @param classes        receives the classes found
     */
    private static void findAndAddClassesInJar(URL url, String packageDirName, List<Class<?>> classes) {
        // The trailing slash keeps sibling packages such as "com/foo/barbaz" from
        // matching a scan of "com/foo/bar".
        String entryPrefix = packageDirName.isEmpty() ? "" : packageDirName + "/";
        try {
            // NOTE(review): not closed on purpose - the JarFile of a JarURLConnection may be
            // a cached instance shared with other users of the same URL.
            JarFile jar = ((JarURLConnection) url.openConnection()).getJarFile();
            Enumeration<JarEntry> entries = jar.entries();
            while (entries.hasMoreElements()) {
                JarEntry entry = entries.nextElement();
                String name = entry.getName();
                if (name.startsWith("/")) {
                    name = name.substring(1);
                }
                if (entry.isDirectory() || !name.startsWith(entryPrefix) || !name.endsWith(CLASS_SUFFIX)) {
                    continue;
                }
                // The qualified class name is the entry path without the suffix.
                String className = name.substring(0, name.length() - CLASS_SUFFIX.length()).replace('/', '.');
                addClass(className, classes);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Adds the classes found in a package directory, optionally descending into
     * sub-directories.
     *
     * @param packageName dotted name of the package the directory represents
     * @param packagePath file-system path of the directory
     * @param recursive   whether sub-directories are scanned as sub-packages
     * @param classes     receives the classes found
     */
    private static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive, List<Class<?>> classes) {
        File dir = new File(packagePath);
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }
        // Keep sub-directories (when recursing) and compiled class files.
        File[] dirfiles = dir.listFiles(new FileFilter() {
            @Override
            public boolean accept(File file) {
                return (recursive && file.isDirectory()) || (file.getName().endsWith(CLASS_SUFFIX));
            }
        });
        // listFiles returns null when the directory cannot be read.
        if (dirfiles == null) {
            return;
        }
        // The default package has no name, so nothing is prepended for it.
        String prefix = packageName.isEmpty() ? "" : packageName + ".";
        for (File file : dirfiles) {
            if (file.isDirectory()) {
                findAndAddClassesInPackageByFile(prefix + file.getName(), file.getAbsolutePath(), recursive, classes);
            } else {
                String fileName = file.getName();
                String className = fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());
                addClass(prefix + className, classes);
            }
        }
    }

    /**
     * Loads a class by name and adds it to the list; a class that cannot be found or
     * linked is reported and skipped so one broken class does not abort the scan.
     */
    private static void addClass(String className, List<Class<?>> classes) {
        try {
            classes.add(Class.forName(className));
        } catch (ClassNotFoundException | LinkageError e) {
            e.printStackTrace();
        }
    }

}

到此完成了 多数据源 自动建表模板搭建。
代码可以参考 GitHub https://github.com/oldguys/MultipleDataSourceDemo

上一篇 下一篇

猜你喜欢

热点阅读