SpringBoot 自动建表
2019-01-07 本文已影响46人
东方不喵
使用MyBatis或者JDBCTemplate的时候,并不能自动创建数据库表,这样需要多花点时间进行数据表的构建。为了减省这一步骤,可以编写一个简易的 自动建表模板。
GitHub https://github.com/oldguys/MultipleDataSourceDemo
Step1: Maven引用依赖。使用jpa规范进行建表。
<!-- 使用java规范进行设置 -->
<dependency>
<groupId>javax.persistence</groupId>
<artifactId>persistence-api</artifactId>
<version>1.0</version>
</dependency>
Step2: 编写抽象模板
package com.oldguy.example.modules.common.services;
import com.oldguy.example.modules.common.dao.entities.SqlTableObject;
import java.util.List;
import java.util.Map;
/**
* Abstraction over one SQL dialect: each implementation supplies the
* dialect-specific pieces needed to generate table schemas.
*
* @Author: ren
* @CreateTime: 2018/10/23 13:53
*/
public interface TableFactory {
/**
* @return the SQL statement used to list the tables that already exist in the database
*/
String showTableSQL();
/**
* @return the name of the dialect handled by this factory
*/
String getDialect();
/**
* @return mapping from a Java field type to the dialect's column type name
*/
Map<Class, String> getColumnType();
/**
* Converts table descriptions into schema statements.
*
* @param sqlTableObjects table descriptions to convert
* @return generated statements; NOTE(review): the MySQL implementation keys the map by
*         table name and stores the CREATE TABLE text as value - confirm for other dialects
*/
Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects);
}
Step3: 编写实现类,来生成不同方言的Schema
package com.oldguy.example.modules.common.services.impls;
import com.oldguy.example.modules.common.dao.entities.SqlTableObject;
import com.oldguy.example.modules.common.services.TableFactory;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * MySQL dialect implementation of {@link TableFactory}; renders
 * {@code CREATE TABLE IF NOT EXISTS} statements.
 *
 * @author ren
 * @date 2018/12/20
 */
public class MySQLTableFactory implements TableFactory {

    /** Type name that MySQL only accepts together with an explicit length. */
    private static final String VARCHAR = "VARCHAR";

    /** Length applied to VARCHAR columns whose length is not set. */
    private static final int DEFAULT_VARCHAR_LENGTH = 255;

    /** Java field type to MySQL column type. */
    private static final Map<Class, String> columnType;

    static {
        columnType = new HashMap<>();
        columnType.put(Integer.class, "INT");
        columnType.put(Long.class, "BIGINT");
        columnType.put(String.class, VARCHAR);
        columnType.put(Date.class, "DATETIME");
        columnType.put(Boolean.class, "TINYINT");
        columnType.put(Double.class, "DOUBLE");
    }

    /**
     * @return the MySQL statement listing the tables of the current database
     */
    @Override
    public String showTableSQL() {
        return "show tables";
    }

    /**
     * @return the dialect name, {@code "MySQL"}
     */
    @Override
    public String getDialect() {
        return "MySQL";
    }

    /**
     * @return the shared Java-type to column-type mapping of this dialect
     */
    @Override
    public Map<Class, String> getColumnType() {
        return columnType;
    }

    /**
     * Renders one {@code CREATE TABLE IF NOT EXISTS} statement per table description.
     *
     * @param sqlTableObjects table descriptions to render
     * @return map from table name to its CREATE TABLE statement
     * @throws RuntimeException if two descriptions share the same table name
     */
    @Override
    public Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects) {
        Map<String, String> tableMap = new HashMap<>(sqlTableObjects.size());
        for (SqlTableObject obj : sqlTableObjects) {
            String tableName = obj.getTableName();
            if (tableMap.containsKey(tableName)) {
                throw new RuntimeException(tableName + " 表名重复。");
            }
            tableMap.put(tableName, buildCreateTable(obj));
        }
        return tableMap;
    }

    /** Renders the complete CREATE TABLE statement for a single table. */
    private static String buildCreateTable(SqlTableObject table) {
        StringBuilder builder = new StringBuilder();
        builder.append("CREATE TABLE IF NOT EXISTS `").append(table.getTableName()).append("` (").append("\n");
        List<SqlTableObject.Column> columns = table.getColumns();
        int last = columns.size() - 1;
        for (int i = 0; i <= last; i++) {
            appendColumn(builder, columns.get(i));
            if (i < last) {
                builder.append(",");
            }
            builder.append("\n");
        }
        builder.append(") ENGINE=InnoDB DEFAULT CHARSET=utf8 ;").append("\n\n");
        return builder.toString();
    }

    /** Appends one column definition (name, type and constraints) without a trailing separator. */
    private static void appendColumn(StringBuilder builder, SqlTableObject.Column column) {
        builder.append("`").append(column.getName()).append("` ");

        // Normalise once so that "varchar" and "VARCHAR" are treated alike; Locale.ROOT keeps the
        // result independent of the JVM default locale (e.g. the Turkish dotted/dotless i).
        String type = column.getType().toUpperCase(java.util.Locale.ROOT);
        if (VARCHAR.equals(type)) {
            // A bare VARCHAR is rejected by MySQL, so a length is always emitted.
            builder.append(VARCHAR).append("(");
            if (column.getLength() == null) {
                builder.append(DEFAULT_VARCHAR_LENGTH);
            } else {
                builder.append(column.getLength());
            }
            builder.append(")");
        } else {
            builder.append(type);
        }

        if (column.isPrimaryKey()) {
            builder.append(" PRIMARY KEY");
            if (column.isAutoIncrement()) {
                builder.append(" AUTO_INCREMENT");
            }
        }
        if (!column.isNullable()) {
            builder.append(" NOT NULL");
        }
        if (column.isUnique()) {
            builder.append(" UNIQUE");
        }
    }
}
Step4: 编写扫描注册器
package com.oldguy.example.modules.common.services;
import com.oldguy.example.modules.common.annotation.AssociateEntity;
import com.oldguy.example.modules.common.annotation.Entity;
import com.oldguy.example.modules.common.dao.entities.SqlTableObject;
import com.oldguy.example.modules.common.services.impls.MySQLTableFactory;
import com.oldguy.example.modules.common.utils.ClassUtils;
import org.springframework.util.StringUtils;
import javax.persistence.Column;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import java.lang.reflect.Field;
import java.util.*;
/**
 * Table registrar: scans packages for classes annotated with {@link Entity} or
 * {@link AssociateEntity}, derives a table description from their fields and asks the
 * configured {@link TableFactory} to turn the descriptions into schema statements.
 *
 * @Author: ren
 * @CreateTime: 2018/10/16 10:33
 */
public class DbRegister {

    private TableFactory tableFactory;

    /** Creates a registrar that generates MySQL schemas. */
    public DbRegister() {
        this.tableFactory = new MySQLTableFactory();
    }

    /**
     * @param tableFactory dialect used to generate the schema statements
     */
    public DbRegister(TableFactory tableFactory) {
        this.tableFactory = tableFactory;
    }

    /**
     * Builds the schema statements for every entity class found in the given packages.
     *
     * @param packageNames packages to scan
     * @return the statements produced by the table factory; an empty map when no package is given
     * @throws RuntimeException if no table factory is configured
     */
    public Map<String, String> registerClassToDB(String... packageNames) {
        if (packageNames == null || packageNames.length == 0) {
            return Collections.emptyMap();
        }
        // Fail with the intended message up front instead of a NullPointerException while
        // the columns are being resolved.
        requireTableFactory();

        List<SqlTableObject> sqlTableObjects = new ArrayList<>();
        for (String packageName : packageNames) {
            for (Class<?> clazz : ClassUtils.getClasses(packageName)) {
                SqlTableObject table = describe(clazz);
                if (table != null) {
                    sqlTableObjects.add(table);
                }
            }
        }
        // Convert to SQL schema.
        return trainToDBSchema(sqlTableObjects);
    }

    /** Returns the table description of an annotated class, or {@code null} if it is not an entity. */
    private SqlTableObject describe(Class<?> clazz) {
        String tableName;
        String preIndex;
        Entity entity = clazz.getAnnotation(Entity.class);
        if (entity != null) {
            tableName = entity.name();
            preIndex = entity.pre();
        } else {
            AssociateEntity associate = clazz.getAnnotation(AssociateEntity.class);
            if (associate == null) {
                return null;
            }
            tableName = associate.name();
            preIndex = associate.pre();
        }

        SqlTableObject table = new SqlTableObject();
        // An explicit name wins; otherwise prefix + snake_case of the simple class name.
        table.setTableName(StringUtils.isEmpty(tableName)
                ? preIndex + formatTableName(clazz.getSimpleName())
                : tableName);

        List<Field> fields = new ArrayList<>();
        getAllField(clazz, fields);
        setTableColumns(table, fields);
        return table;
    }

    /**
     * Delegates the conversion into schema statements to the table factory.
     *
     * @param sqlTableObjects table descriptions to convert
     */
    private Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects) {
        return requireTableFactory().trainToDBSchema(sqlTableObjects);
    }

    /** Returns the configured factory or throws when none is set. */
    private TableFactory requireTableFactory() {
        if (null == tableFactory) {
            throw new RuntimeException("TableFactory 不能为空!");
        }
        return tableFactory;
    }

    /**
     * Sets the columns of a table from the given fields. Fields whose Java type has no
     * mapping in the factory's column-type table are ignored.
     *
     * @param obj    table description to fill
     * @param fields candidate fields
     */
    private void setTableColumns(SqlTableObject obj, List<Field> fields) {
        List<SqlTableObject.Column> columnList = new ArrayList<>();
        for (Field field : fields) {
            if (tableFactory.getColumnType().containsKey(field.getType())) {
                columnList.add(toColumn(field));
            }
        }
        obj.setColumns(columnList);
    }

    /** Builds the column description of one field from its JPA annotations. */
    private SqlTableObject.Column toColumn(Field field) {
        SqlTableObject.Column column = new SqlTableObject.Column();

        if (field.isAnnotationPresent(Id.class)) {
            column.setPrimaryKey(true);
            GeneratedValue generated = field.getAnnotation(GeneratedValue.class);
            if (generated != null && isDatabaseGenerated(generated.strategy())) {
                column.setAutoIncrement(true);
            }
        }

        String mappedType = tableFactory.getColumnType().get(field.getType());
        Column annotation = field.getAnnotation(Column.class);
        if (annotation == null) {
            column.setName(formatTableName(field.getName()));
            column.setType(mappedType);
            return column;
        }

        column.setName(StringUtils.isEmpty(annotation.name())
                ? formatTableName(field.getName())
                : annotation.name());
        column.setType(StringUtils.isEmpty(annotation.columnDefinition())
                ? mappedType
                : annotation.columnDefinition());
        column.setLength(annotation.length());
        column.setUnique(annotation.unique());
        column.setNullable(annotation.nullable());
        return column;
    }

    /** IDENTITY is the JPA strategy that explicitly means an auto-increment column; AUTO is kept as before. */
    private static boolean isDatabaseGenerated(GenerationType strategy) {
        return strategy == GenerationType.AUTO || strategy == GenerationType.IDENTITY;
    }

    /**
     * Collects the instance fields of a class and of its superclasses (subclass fields first).
     * Static, transient and synthetic fields are skipped because they are not persistent state.
     *
     * @param clazz  class to inspect; interfaces and {@code Object} contribute nothing
     * @param fields list receiving the fields
     */
    private static void getAllField(Class<?> clazz, List<Field> fields) {
        // getSuperclass() is null for interfaces and Object, hence the explicit null check.
        for (Class<?> current = clazz; current != null && current != Object.class; current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                int modifiers = field.getModifiers();
                boolean persistent = !java.lang.reflect.Modifier.isStatic(modifiers)
                        && !java.lang.reflect.Modifier.isTransient(modifiers)
                        && !field.isSynthetic();
                if (persistent) {
                    fields.add(field);
                }
            }
        }
    }

    /**
     * Converts a camelCase name to snake_case, e.g. {@code "UserInfo"} becomes
     * {@code "user_info"}. Only upper-case letters start a new word, so digits and
     * existing underscores are copied unchanged ({@code "user_name"} stays as is).
     *
     * @param name camelCase class or field name
     * @return the snake_case form of the name
     */
    public static String formatTableName(String name) {
        StringBuilder formatResult = new StringBuilder(name.length() + 8);
        for (int i = 0; i < name.length(); i++) {
            char current = name.charAt(i);
            if (Character.isUpperCase(current)) {
                if (i > 0) {
                    formatResult.append('_');
                }
                // Character-level conversion is locale independent and never changes the length.
                formatResult.append(Character.toLowerCase(current));
            } else {
                formatResult.append(current);
            }
        }
        return formatResult.toString();
    }

    public void setTableFactory(TableFactory tableFactory) {
        this.tableFactory = tableFactory;
    }

    public TableFactory getTableFactory() {
        return tableFactory;
    }
}
Step5: 编写Configuration注册类
package com.oldguy.example.configs;
import com.oldguy.example.modules.common.services.DbRegister;
import com.oldguy.example.modules.common.utils.Log4jUtils;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.StringUtils;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.util.*;
/**
 * Creates the tables of the scanned entity classes that are still missing in a database.
 *
 * @author ren
 * @date 2018/12/20
 */
public class DbRegisterConfiguration {

    /**
     * Separators accepted between package names. The value comes from the MyBatis
     * {@code type-aliases-package} property, for which comma, semicolon and whitespace
     * are all valid delimiters, so the same set is honoured here.
     */
    private static final String PACKAGE_DELIMITERS = "[,;\\s]+";

    /**
     * Initializes the database: generates the schema of every entity found in the given
     * packages and executes the statements of the tables that do not exist yet.
     *
     * @param jdbcTemplate       template bound to the target data source
     * @param typeAliasesPackage one or more package names separated by {@code ,}, {@code ;} or whitespace
     */
    public void initDB(JdbcTemplate jdbcTemplate, String typeAliasesPackage) {
        DbRegister dbRegister = new DbRegister();
        Map<String, String> tableMap = new HashMap<>();
        for (String path : splitPackagesPath(typeAliasesPackage)) {
            tableMap.putAll(dbRegister.registerClassToDB(path));
        }
        if (tableMap.isEmpty()) {
            return;
        }

        Set<String> existingTables = findExistingTables(jdbcTemplate, dbRegister.getTableFactory().showTableSQL());
        for (Map.Entry<String, String> entry : tableMap.entrySet()) {
            String tableName = entry.getKey();
            if (existingTables.contains(tableName)) {
                Log4jUtils.getInstance(getClass()).info("表[" + tableName + "] 已存在");
                continue;
            }
            Log4jUtils.getInstance(getClass()).info("未找到表[" + tableName + "],进行创建.");
            String sql = entry.getValue();
            if (sql.trim().length() > 0) {
                jdbcTemplate.execute(sql);
                Log4jUtils.getInstance(getClass()).info("\n\n" + sql);
            }
        }
    }

    /** Runs the table-listing statement and collects every returned value as a table name. */
    private Set<String> findExistingTables(JdbcTemplate jdbcTemplate, String showTableSql) {
        Set<String> tableNames = new HashSet<>();
        for (Map<String, Object> row : jdbcTemplate.queryForList(showTableSql)) {
            for (Object value : row.values()) {
                // toString() instead of a cast: the driver decides the Java type of the value.
                if (value != null) {
                    tableNames.add(value.toString());
                }
            }
        }
        return tableNames;
    }

    /** Splits the configured value into package names; a null or blank value yields an empty list. */
    private List<String> splitPackagesPath(String typeAliasesPackage) {
        List<String> paths = new ArrayList<>();
        if (typeAliasesPackage == null) {
            return paths;
        }
        for (String path : typeAliasesPackage.split(PACKAGE_DELIMITERS)) {
            // A leading delimiter produces one empty token.
            if (!path.isEmpty()) {
                paths.add(path);
            }
        }
        return paths;
    }
}
以上模板基本构建完成。
Step6: 开始配置;进行数据库注册,可以注册多个数据源的数据库
1.注册类
package com.oldguy.example.configs;
import com.oldguy.example.modules.common.utils.Log4jUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.io.IOException;
/**
 * Startup configuration that creates the missing tables of both demo data sources
 * once the bean has been constructed.
 *
 * @author ren
 * @date 2018/12/20
 */
@Configuration
public class DemoConfiguration {

    @Resource(name = "test1DataSource")
    private DataSource test1dataSource;

    @Value("${test1.mybatis.type-aliases-package}")
    private String test1TypeAliasesPackage;

    @Resource(name = "test2DataSource")
    private DataSource test2dataSource;

    @Value("${test2.mybatis.type-aliases-package}")
    private String test2TypeAliasesPackage;

    /**
     * Generates the table structure for test1 first, then for test2.
     */
    @PostConstruct
    public void initData() throws IOException {
        Log4jUtils.getInstance(getClass()).info("初始化 数据库 -------------------------------");

        final DbRegisterConfiguration registrar = new DbRegisterConfiguration();
        createTables(registrar, test1dataSource, test1TypeAliasesPackage);
        createTables(registrar, test2dataSource, test2TypeAliasesPackage);
    }

    /** Runs the table initialization for one data source through a fresh JdbcTemplate. */
    private static void createTables(DbRegisterConfiguration registrar, DataSource dataSource, String entityPackages) {
        registrar.initDB(new JdbcTemplate(dataSource), entityPackages);
    }
}
2.注册数据源
package com.oldguy.example.configs;
import com.alibaba.druid.pool.DruidDataSource;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ResourceLoader;
/**
* Spring configuration declaring the two Druid data sources of the demo.
*
* @author ren
* @date 2018/12/20
*/
@Configuration
public class Test1DataSourceConfiguration extends AbstractMybatisConfiguration {
/**
* Declares the bean named "test1DataSource"; its properties are bound from the
* "test1.datasource" configuration prefix.
*
* @return a new, not yet configured Druid data source
*/
@Bean(name = "test1DataSource")
@ConfigurationProperties(prefix = "test1.datasource")
public DruidDataSource test1DataSource() {
return new DruidDataSource();
}
/**
* Declares the bean named "test2DataSource"; its properties are bound from the
* "test2.datasource" configuration prefix.
*
* @return a new, not yet configured Druid data source
*/
@Bean(name = "test2DataSource")
@ConfigurationProperties(prefix = "test2.datasource")
public DruidDataSource test2DataSource() {
return new DruidDataSource();
}
}
3.yaml配置文件
# Settings of the first data source; read through the "test1.datasource" prefix and
# the "test1.mybatis.type-aliases-package" property.
test1:
  datasource:
    username: root
    password: root
    url: jdbc:mysql://127.0.0.1:3306/multiple_datasource1?useUnicode=true&characterEncoding=utf8&useSSL=false&allowMultiQueries=true
    driver-class-name: com.mysql.jdbc.Driver
    type: com.alibaba.druid.pool.DruidDataSource
  mybatis:
    mapper-locations: classpath:mappers/test1/*.xml
    type-aliases-package: com.oldguy.example.modules.test1.dao.entities;
    config-location: classpath:configs/myBatis-config.xml
# Settings of the second data source; read through the "test2.datasource" prefix and
# the "test2.mybatis.type-aliases-package" property.
test2:
  datasource:
    username: root
    password: root
    url: jdbc:mysql://127.0.0.1:3306/multiple_datasource2?useUnicode=true&characterEncoding=utf8&useSSL=false&allowMultiQueries=true
    driver-class-name: com.mysql.jdbc.Driver
    type: com.alibaba.druid.pool.DruidDataSource
  mybatis:
    mapper-locations: classpath:mappers/test2/*.xml
    type-aliases-package: com.oldguy.example.modules.test2.dao.entities;
    config-location: classpath:configs/myBatis-config.xml
这样就完成了自动建表的配置。
引用: 包扫描类
package com.oldguy.example.modules.common.utils;/**
* Created by Administrator on 2018/10/16 0016.
*/
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
/**
 * Classpath scanning helpers: find the classes of a package either in class
 * directories or inside jar files.
 *
 * @author ren
 * @date 2018/12/20
 */
public class ClassUtils {

    private static final String CLASS_SUFFIX = ".class";

    /**
     * Returns all classes of a package, including those of its sub-packages.
     *
     * @param pkg package to scan
     * @return the classes that could be loaded
     */
    public static List<Class<?>> getAllClassByPackageName(Package pkg) {
        return getClasses(pkg.getName());
    }

    /**
     * Returns the classes that implement the given interface, searching the interface's
     * own package and its sub-packages. The interface itself is not part of the result.
     *
     * @param c interface to look up implementations for
     * @return the implementing classes, or {@code null} when {@code c} is not an interface
     */
    public static List<Class<?>> getAllClassByInterface(Class<?> c) {
        if (!c.isInterface()) {
            // Kept as null (not an empty list) so existing callers can still tell the cases apart.
            return null;
        }
        List<Class<?>> implementations = new ArrayList<Class<?>>();
        for (Class<?> candidate : getClasses(c.getPackage().getName())) {
            if (c.isAssignableFrom(candidate) && !c.equals(candidate)) {
                implementations.add(candidate);
            }
        }
        return implementations;
    }

    /**
     * Lists the entry names of one package directory below a class location (not recursive).
     *
     * @param classLocation root directory of the classes
     * @param packageName   dotted package name
     * @return the names found in the package directory, or {@code null} if it is not a directory
     */
    public static String[] getPackageAllClassName(String classLocation, String packageName) {
        StringBuilder location = new StringBuilder(String.valueOf(classLocation));
        for (String segment : packageName.split("[.]")) {
            location.append(File.separator).append(segment);
        }
        File packageDir = new File(location.toString());
        return packageDir.isDirectory() ? packageDir.list() : null;
    }

    /**
     * Loads every class of a package and of its sub-packages. Classes that cannot be
     * loaded are reported on standard error and skipped, so one broken class does not
     * abort the scan.
     *
     * @param packageName dotted package name
     * @return the loaded classes; empty when the package is not found
     */
    public static List<Class<?>> getClasses(String packageName) {
        List<Class<?>> classes = new ArrayList<>();
        // Sub-packages are always scanned.
        final boolean recursive = true;
        String packageDirName = packageName.replace('.', '/');
        try {
            Enumeration<URL> dirs = scanningClassLoader().getResources(packageDirName);
            while (dirs.hasMoreElements()) {
                URL url = dirs.nextElement();
                String protocol = url.getProtocol();
                if ("file".equals(protocol)) {
                    // Plain class directory: walk the file system.
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                } else if ("jar".equals(protocol)) {
                    try {
                        addClassesFromJar(url, packageDirName, classes);
                    } catch (IOException e) {
                        // Best effort: an unreadable jar must not stop the remaining locations.
                        e.printStackTrace();
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return classes;
    }

    /** Adds the classes of one jar whose entry path lies inside the requested package directory. */
    private static void addClassesFromJar(URL url, String packageDirName, List<Class<?>> classes) throws IOException {
        JarFile jar = ((JarURLConnection) url.openConnection()).getJarFile();
        // The trailing slash stops "com/foo/bar" from also matching "com/foo/barbaz".
        String prefix = packageDirName.isEmpty() ? "" : packageDirName + "/";
        Enumeration<JarEntry> entries = jar.entries();
        while (entries.hasMoreElements()) {
            JarEntry entry = entries.nextElement();
            String name = entry.getName();
            if (!name.isEmpty() && name.charAt(0) == '/') {
                name = name.substring(1);
            }
            if (entry.isDirectory() || !name.startsWith(prefix) || !name.endsWith(CLASS_SUFFIX)) {
                continue;
            }
            // The class name is derived from the entry itself; the caller's package name is
            // never modified, so later locations are scanned with the original value.
            String className = name.substring(0, name.length() - CLASS_SUFFIX.length()).replace('/', '.');
            loadInto(classes, className);
        }
    }

    /**
     * Adds the classes found in a package directory on the file system.
     *
     * @param packageName dotted name of the package the directory represents
     * @param packagePath file system path of the directory
     * @param recursive   whether sub-directories are scanned as sub-packages
     * @param classes     list receiving the loaded classes
     */
    private static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive, List<Class<?>> classes) {
        File dir = new File(packagePath);
        if (!dir.isDirectory()) {
            return;
        }
        // Keep sub-directories (when recursing) and compiled class files.
        File[] children = dir.listFiles(new FileFilter() {
            @Override
            public boolean accept(File file) {
                return (recursive && file.isDirectory()) || (file.getName().endsWith(CLASS_SUFFIX));
            }
        });
        if (children == null) {
            // listFiles() returns null when the directory cannot be read.
            return;
        }
        for (File child : children) {
            String childName = child.getName();
            if (child.isDirectory()) {
                findAndAddClassesInPackageByFile(qualify(packageName, childName), child.getAbsolutePath(), recursive, classes);
            } else {
                String simpleName = childName.substring(0, childName.length() - CLASS_SUFFIX.length());
                loadInto(classes, qualify(packageName, simpleName));
            }
        }
    }

    /** Joins a package and a simple name; the default package gets no leading dot. */
    private static String qualify(String packageName, String simpleName) {
        return packageName.isEmpty() ? simpleName : packageName + '.' + simpleName;
    }

    /** Loads one class with the loader that located the resources and skips it when loading fails. */
    private static void loadInto(List<Class<?>> classes, String className) {
        try {
            classes.add(Class.forName(className, true, scanningClassLoader()));
        } catch (ClassNotFoundException | LinkageError e) {
            // LinkageError covers NoClassDefFoundError and malformed class files.
            e.printStackTrace();
        }
    }

    /** Prefers the thread context class loader and falls back to the loader of this class. */
    private static ClassLoader scanningClassLoader() {
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        return contextLoader != null ? contextLoader : ClassUtils.class.getClassLoader();
    }
}
到此完成了 多数据源 自动建表模板搭建。
代码可以参考 GitHub https://github.com/oldguys/MultipleDataSourceDemo