golang xorm mysql代码生成器

2021-01-21  本文已影响0人  EasyNetCN

原来的代码生成器都是用Java写的。一时心血来潮,把其中生成Golang数据层代码的部分用Golang重新实现了一下。

主体代码

package main

import (
    "flag"
    "fmt"
    "os"
    "strings"
    "text/template"
    "unsafe"

    _ "github.com/go-sql-driver/mysql"
    "xorm.io/xorm"
)

// SQL statements used to introspect the MySQL schema.
const (
    // selectCurrentDbSql returns the name of the connection's current
    // (default) database.
    selectCurrentDbSql = "SELECT DATABASE()"
    // allColumnInfoSql lists every column of every table in the schema bound
    // to the single placeholder, ordered by table and then by ordinal
    // position, so grouping the rows per table preserves column order.
    allColumnInfoSql   = "SELECT * FROM information_schema.columns WHERE table_schema =? ORDER BY table_schema ASC,table_name ASC,ordinal_position ASC"
)

func main() {
    db := flag.String("db", "", "database connection string")
    tplName := flag.String("tplname", "repository.tpl", "code template name")
    tpl := flag.String("tpl", "tpl/repository.tpl", "code template file")
    readonly := flag.Bool("readonly", false, "readonly")
    output := flag.String("output", "./repository", "output file path")

    flag.Parse()

    if db == nil || *db == "" {
        fmt.Println("database connection string can not be empty")

        return
    }

    engine, err := xorm.NewEngine("mysql", *db)

    if err != nil {
        fmt.Println("can not create database engine,err:", err)

        return
    }

    currentDb := ""

    if _, err := engine.SQL(selectCurrentDbSql).Get(&currentDb); err != nil {
        fmt.Println("can not get current database,err:", err)

        return
    }

    columns := make([]DataColumn, 0)

    if err := engine.SQL(allColumnInfoSql, currentDb).Find(&columns); err != nil {
        fmt.Println("can not get column information,err:", err)

        return
    }

    tableMap := make(map[string][]DataColumn)

    for _, column := range columns {
        tableName := column.TableName

        if _, ok := tableMap[tableName]; !ok {
            tableMap[tableName] = make([]DataColumn, 0)
        }

        tableMap[tableName] = append(tableMap[tableName], column)
    }

    funcMap := template.FuncMap{"upperCamelCase": UpperCamelCase, "lowerCamelCase": LowerCamelCase}

    t, err := template.New(*tplName).Funcs(funcMap).ParseFiles(*tpl)

    if err != nil {
        fmt.Println("parse file err:", err)
        return
    }

    os.RemoveAll(*output)

    for table, columns := range tableMap {
        if _, err := os.Stat(*output); os.IsNotExist(err) {
            os.Mkdir(*output, 0777)
            os.Chmod(*output, 0777)
        }

        f, err := os.OpenFile(*output+"/"+table+"_repository.go", os.O_CREATE|os.O_WRONLY, 0666)

        defer f.Close()

        if err != nil {
            fmt.Println("can not create output file,err:", err)

            return
        }

        if err := t.Execute(f, Config{TableName: table, Readonly: *readonly, Columns: columns}); err != nil {
            fmt.Println("There was an error:", err.Error())
        }
    }

}

// Config is the data passed to the code template when rendering one table.
type Config struct {
    // TableName is the table name as reported by information_schema.
    TableName string
    // Readonly mirrors the -readonly flag; the template checks it to decide
    // whether to emit the write methods.
    Readonly  bool
    // Columns holds the table's columns in ordinal (declaration) order.
    Columns   []DataColumn
}

// DataColumn is one row of information_schema.columns. Fields rely on xorm's
// default snake_case name mapping (e.g. TableName -> table_name) unless a
// tag names the column explicitly.
type DataColumn struct {
	TableSchema            string
	TableName              string
	ColumnName             string
	OrdinalPosition        int
	ColumnDefault          string
	IsNullable             string
	DataType               string
	CharacterMaximumLength string
	CharacterOctetLength   string
	NumericPrecision       string
	// NumbericScale keeps its historical (misspelled) field name for
	// compatibility; the tag maps it to the real NUMERIC_SCALE column, which
	// the snake_case name "numberic_scale" never matched.
	NumbericScale     string `xorm:"'numeric_scale'"`
	DatetimePrecision string
	ColumnType        string
	ColumnKey         string
	Extra             string
	ColumnComment     string
}

// GoLangType maps the column's MySQL DATA_TYPE (case-insensitively) to the Go
// type of the generated struct field. Nullable columns (IS_NULLABLE "YES")
// get a pointer type so NULL can be represented as nil. A data type without a
// mapping is returned lower-cased and unchanged; it is usually not a valid Go
// type, so the generated code fails to compile and reveals the gap.
func (c *DataColumn) GoLangType() string {
	var goType string

	switch strings.ToLower(c.DataType) {
	case "int", "tinyint", "smallint", "mediumint":
		goType = "int"
	case "bigint", "long":
		goType = "int64"
	case "varchar", "char", "text", "tinytext", "mediumtext", "longtext":
		goType = "string"
	case "decimal", "double":
		goType = "float64"
	case "float":
		goType = "float32"
	case "datetime", "date", "timestamp":
		goType = "time.Time"
	default:
		return strings.ToLower(c.DataType)
	}

	if strings.EqualFold(c.IsNullable, "yes") {
		return "*" + goType
	}

	return goType
}

// Tag builds the struct tag for the generated field of this column, e.g.
//
//	`xorm:"bigint 'id' autoincr pk notnull comment('...')" json:"id"`
//
// The xorm part carries the lower-cased MySQL data type, the lower-cased
// column name, autoincr / pk when applicable, null or notnull, a default(...)
// clause and the column comment. Defaults of string-like columns are
// single-quoted; for other columns the clause is omitted when there is no
// default. The json part is the lowerCamelCase column name, or "-" for
// del_status so that field is left out of JSON.
func (c *DataColumn) Tag() string {
	name := strings.ToLower(c.ColumnName)
	dataType := strings.ToLower(c.DataType)
	identity := strings.EqualFold(c.Extra, "auto_increment")
	// The key kind is compared against the upper-case "PRI", so the check
	// must be case-insensitive rather than lower-casing only one side.
	primary := strings.EqualFold(c.ColumnKey, "PRI")
	nullable := strings.EqualFold(c.IsNullable, "yes")

	quoteDefault := false
	switch dataType {
	case "varchar", "char", "text", "tinytext", "mediumtext", "longtext":
		quoteDefault = true
	}

	// A backtick would end the generated raw string literal, a double quote
	// the xorm tag value, and line breaks would split the tag; replace them so
	// the generated file always compiles.
	sanitize := strings.NewReplacer("`", "'", "\"", "'", "\r", " ", "\n", " ").Replace

	sb := new(strings.Builder)

	sb.WriteString("`xorm:\"")
	sb.WriteString(dataType)
	sb.WriteString(" '")
	sb.WriteString(name)
	sb.WriteString("'")

	if identity {
		sb.WriteString(" autoincr")
	}

	if primary {
		sb.WriteString(" pk")
	}

	if nullable {
		sb.WriteString(" null")
	} else {
		sb.WriteString(" notnull")
	}

	// A non-string column without a default would otherwise get an empty
	// "default()" clause. String columns keep default('') because an empty
	// string default cannot be told apart from a missing one here.
	if quoteDefault || c.ColumnDefault != "" {
		sb.WriteString(" default(")
		if quoteDefault {
			sb.WriteString("'")
		}
		sb.WriteString(sanitize(c.ColumnDefault))
		if quoteDefault {
			sb.WriteString("'")
		}
		sb.WriteString(")")
	}

	sb.WriteString(" comment('")
	sb.WriteString(sanitize(c.ColumnComment))
	sb.WriteString("')")

	sb.WriteString("\" json:\"")

	if name == "del_status" {
		sb.WriteString("-")
	} else {
		sb.WriteString(LowerCamelCase(c.ColumnName))
	}

	sb.WriteString("\"`")

	return sb.String()
}

// UpperCamelCase converts a snake_case identifier such as "user_info" to
// UpperCamelCase ("UserInfo"): every underscore-separated segment has its
// first character upper-cased and the rest kept as-is. Empty segments
// (leading, trailing or doubled underscores) are skipped, and an empty input
// yields "".
func UpperCamelCase(txt string) string {
	sb := new(strings.Builder)

	for _, seg := range strings.Split(txt, "_") {
		if seg == "" {
			continue
		}

		// n is the byte length of the first rune, so a multi-byte first
		// character is upper-cased as a whole instead of being split.
		n := len(seg)
		for i := range seg {
			if i > 0 {
				n = i
				break
			}
		}

		sb.WriteString(strings.ToUpper(seg[:n]))
		sb.WriteString(seg[n:])
	}

	return sb.String()
}

// LowerCamelCase converts a snake_case identifier such as "create_time" to
// lowerCamelCase ("createTime"). The first non-empty segment has its first
// character lower-cased and every following segment has its first character
// upper-cased; the remaining characters are kept as-is. Empty segments
// (leading, trailing or doubled underscores) are skipped, and an empty input
// yields "".
func LowerCamelCase(txt string) string {
	sb := new(strings.Builder)

	for _, seg := range strings.Split(txt, "_") {
		if seg == "" {
			continue
		}

		// n is the byte length of the first rune, so a multi-byte first
		// character is case-mapped as a whole instead of being split.
		n := len(seg)
		for i := range seg {
			if i > 0 {
				n = i
				break
			}
		}

		head := seg[:n]
		if sb.Len() == 0 {
			head = strings.ToLower(head)
		} else {
			head = strings.ToUpper(head)
		}

		sb.WriteString(head)
		sb.WriteString(seg[n:])
	}

	return sb.String()
}

// BytesToString reinterprets b as a string without copying. The result shares
// b's backing memory, so b must not be modified while the string is in use.
// The cast relies on a slice header starting with the same data-pointer and
// length fields as a string header.
func BytesToString(b []byte) string {
	asString := (*string)(unsafe.Pointer(&b))
	return *asString
}

模板文件:tpl/repository.tpl

{{- /*
  Repository template, rendered by the generator once per table.

  Data: .TableName (string), .Readonly (bool) and .Columns ([]DataColumn;
  each column offers .ColumnName, .GoLangType and .Tag). The generator
  registers the functions upperCamelCase and lowerCamelCase.

  The generated code references entity.Id, DelStatus and UpdateTime and
  filters on del_status, so it assumes every table has id, del_status and
  update_time columns (presumably NOT NULL, so the fields are not pointers).
  TODO(review): confirm all tables follow this convention.

  The import block is computed: "time" is only imported when it is used
  (the soft-delete methods emitted when not readonly, or any time.Time
  column), because Go rejects unused imports.
*/ -}}
package repository
{{- $needTime := not .Readonly}}
{{- range .Columns}}{{if eq .GoLangType "time.Time" "*time.Time"}}{{$needTime = true}}{{end}}{{end}}

import (
{{- if $needTime}}
	"time"
{{end}}
	"xorm.io/xorm"
)

type {{upperCamelCase .TableName}} struct{
 {{ range $column := .Columns}}
 {{upperCamelCase .ColumnName }} {{.GoLangType}} {{.Tag}} 
 {{ end }}
}

{{if not .Readonly}}
var(
{{lowerCamelCase .TableName}}Columns=[]string{
         {{ range $column := .Columns}}
 "{{.ColumnName }}",
 {{ end }}
    }
    
{{lowerCamelCase .TableName}}ColumnMap=map[string]string{
         {{ range $column := .Columns}}
 "{{.ColumnName }}":"{{.ColumnName }}",
 {{ end }}
    }
)
{{end}}
type {{upperCamelCase .TableName}}Repository interface {
    {{if not .Readonly}}
    Create(entity {{upperCamelCase .TableName}}) ({{upperCamelCase .TableName}}, error)
    {{end}}
    {{if not .Readonly}}
    CreateBySession(session *xorm.Session, entity {{upperCamelCase .TableName}}) ({{upperCamelCase .TableName}}, error)
    {{end}}
    {{if not .Readonly}}
    DeleteById(id int64) (int64, error)
    {{end}}
    {{if not .Readonly}}
    DeleteBySessionAndId(session *xorm.Session, id int64) (int64, error)
    {{end}}
    {{if not .Readonly}}
    Update(entity {{upperCamelCase .TableName}}, columns []string) ({{upperCamelCase .TableName}}, error)
    {{end}}
    {{if not .Readonly}}
    UpdateBySession(session *xorm.Session, entity {{upperCamelCase .TableName}}, columns []string) ({{upperCamelCase .TableName}}, error)
    {{end}}
    FindById(id int64) ({{upperCamelCase .TableName}}, error)
    {{if not .Readonly}}
    FindBySessionAndId(session *xorm.Session, id int64) ({{upperCamelCase .TableName}}, error)
    {{end}}
}

type {{lowerCamelCase .TableName}}Repository  struct {
    engine *xorm.Engine
}

func New{{upperCamelCase .TableName}}Repository(engine *xorm.Engine) {{upperCamelCase .TableName}}Repository {
    return &{{lowerCamelCase .TableName}}Repository{
        engine: engine,
    }
}

{{if not .Readonly}}
func (r *{{lowerCamelCase .TableName}}Repository) Create(entity {{upperCamelCase .TableName}}) ({{upperCamelCase .TableName}}, error) {
    _, err := r.engine.Insert(&entity)

    return entity, err
}
{{end}}

{{if not .Readonly}}
func (r *{{lowerCamelCase .TableName}}Repository) CreateBySession(session *xorm.Session, entity {{upperCamelCase .TableName}}) ({{upperCamelCase .TableName}}, error) {
    _, err := session.Insert(&entity)

    return entity, err
}
{{end}}

{{if not .Readonly}}
func (r *{{lowerCamelCase .TableName}}Repository) DeleteById(id int64) (int64, error) {
    return r.engine.ID(id).Cols("del_status", "update_time").Update(&{{upperCamelCase .TableName}}{DelStatus: 1, UpdateTime: time.Now()})
}
{{end}}

{{if not .Readonly}}
func (r *{{lowerCamelCase .TableName}}Repository) DeleteBySessionAndId(session *xorm.Session, id int64) (int64, error) {
    return session.ID(id).Cols("del_status", "update_time").Update(&{{upperCamelCase .TableName}}{DelStatus: 1, UpdateTime: time.Now()})
}
{{end}}

{{if not .Readonly}}
func (r *{{lowerCamelCase .TableName}}Repository) Update(entity {{upperCamelCase .TableName}},columns []string) ({{upperCamelCase .TableName}}, error) {
    _, err := r.engine.ID(entity.Id).Cols(columns...).Update(&entity)

    return entity, err
}
{{end}}

{{if not .Readonly}}
func (r *{{lowerCamelCase .TableName}}Repository) UpdateBySession(session *xorm.Session, entity {{upperCamelCase .TableName}},columns []string) ({{upperCamelCase .TableName}}, error) {
    _, err := session.ID(entity.Id).Cols(columns...).Update(&entity)

    return entity, err
}
{{end}}

func (r *{{lowerCamelCase .TableName}}Repository) FindById(id int64) ({{upperCamelCase .TableName}}, error) {
    entity := new({{upperCamelCase .TableName}})

    _, err := r.engine.ID(id).Where("del_status=0").Get(entity)

    return *entity, err
}

{{if not .Readonly}}
func (r *{{lowerCamelCase .TableName}}Repository) FindBySessionAndId(session *xorm.Session, id int64) ({{upperCamelCase .TableName}}, error) {
    entity := new({{upperCamelCase .TableName}})

    _, err := session.ID(id).Where("del_status=0").Get(entity)

    return *entity, err
}
{{end}}
上一篇 下一篇

猜你喜欢

热点阅读