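03 GORM Source Code Walkthrough
Introduction
A walkthrough of the GORM source code, based on v1.9.11.
Model Interaction
Earlier parts looked at how models are defined and parsed. This time we look at how models interact with the database.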
package main
import (
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
type Product struct {
gorm.Model
Code string
Price uint
}
func main() {
db, err := gorm.Open("sqlite3", "test.db")
if err != nil {
panic("failed to connect database")
}
defer db.Close()
// Migrate the schema
db.AutoMigrate(&Product{})
// Create
db.Create(&Product{Code: "L1212", Price: 1000})
// Read
var product Product
db.First(&product, 1) // find the product with id 1
db.First(&product, "code = ?", "L1212") // find the product with code L1212
// Update - set the product's price to 2000
db.Model(&product).Update("Price", 2000)
// Delete - delete the product
db.Delete(&product)
}
AutoMigrate
Once the model is defined, the first step is to migrate it with AutoMigrate:
db.AutoMigrate(&Product{})
Let's look at its source:
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
func (s *DB) AutoMigrate(values ...interface{}) *DB {
db := s.Unscoped()
for _, value := range values {
db = db.NewScope(value).autoMigrate().db
}
return db
}
Internally, it calls db.NewScope(value).autoMigrate() for each argument passed in.
So how exactly does the migration work?
func (scope *Scope) autoMigrate() *Scope {
tableName := scope.TableName()
quotedTableName := scope.QuotedTableName()
if !scope.Dialect().HasTable(tableName) {
scope.createTable()
} else {
for _, field := range scope.GetModelStruct().StructFields {
if !scope.Dialect().HasColumn(tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
}
}
scope.createJoinTable(field)
}
scope.autoIndex()
}
return scope
}
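The if statement in the middle shows two paths. If the table does not exist yet, it is simply created.
Otherwise every field of the model is checked, and if its column does not exist, the table is altered to add it: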
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
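Let's ignore for now how SQL statements are executed. Judging by its shape, the code is quite concise: Raw builds the statement and Exec runs it.
In addition, scope.createJoinTable(field) is called for every field in the model to handle join tables.
After the for loop has processed all fields, the indexes are updated with scope.autoIndex().
In short, auto-migration does the following: it creates the table, adds columns for newly added fields, sets up the tables for relationships, and updates the indexes. Existing columns are never changed or dropped.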
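To make the two branches concrete, here is a minimal in-memory sketch of the decision autoMigrate makes. It assumes a hypothetical `existing` map standing in for the dialect's HasTable/HasColumn checks and a simplified, hypothetical `field` type. It only builds SQL strings and skips join tables, primary keys, and indexes.

```go
package main

import (
	"fmt"
	"strings"
)

// field is a hypothetical, simplified stand-in for gorm's StructField.
type field struct {
	DBName   string
	SQLType  string // what Dialect().DataTypeOf would return (assumed)
	IsNormal bool   // false for relationship fields, which get no column
}

// migrationSQL mirrors the two branches of autoMigrate: build a CREATE TABLE
// when the table is missing, otherwise one ALTER TABLE ... ADD per missing
// column. Existing columns are never changed or dropped.
// existing maps a table name to its set of column names (assumption).
func migrationSQL(table string, fields []field, existing map[string]map[string]bool) []string {
	cols, ok := existing[table]
	if !ok {
		var defs []string
		for _, f := range fields {
			if f.IsNormal {
				defs = append(defs, fmt.Sprintf("%q %s", f.DBName, f.SQLType))
			}
		}
		return []string{fmt.Sprintf("CREATE TABLE %q (%s)", table, strings.Join(defs, ","))}
	}
	var stmts []string
	for _, f := range fields {
		if f.IsNormal && !cols[f.DBName] {
			stmts = append(stmts, fmt.Sprintf("ALTER TABLE %q ADD %q %s;", table, f.DBName, f.SQLType))
		}
	}
	return stmts
}

func main() {
	fields := []field{
		{DBName: "id", SQLType: "integer", IsNormal: true},
		{DBName: "code", SQLType: "varchar(255)", IsNormal: true},
		{DBName: "price", SQLType: "integer", IsNormal: true},
	}
	// The table already has id and code, so only price is added.
	existing := map[string]map[string]bool{"products": {"id": true, "code": true}}
	for _, stmt := range migrationSQL("products", fields, existing) {
		fmt.Println(stmt) // ALTER TABLE "products" ADD "price" integer;
	}
}
```

The "only add, never modify" behavior falls out of the loop: columns that already exist are simply skipped.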
createTable
The details of table creation were skipped above, so let's look at how a table is created.
func (scope *Scope) createTable() *Scope {
var tags []string
var primaryKeys []string
var primaryKeyInColumnType = false
for _, field := range scope.GetModelStruct().StructFields {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
// Check if the primary key constraint was specified as
// part of the column type. If so, we can only support
// one column as the primary key.
if strings.Contains(strings.ToLower(sqlTag), "primary key") {
primaryKeyInColumnType = true
}
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
}
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
scope.createJoinTable(field)
}
var primaryKeyStr string
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
scope.autoIndex()
return scope
}
This is how the CREATE TABLE SQL is built. The main work is done by this line:
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
The code before it mainly iterates over the model's fields, gets each field's sqlTag, and appends it to tags:
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
That is, the quoted column name, a space, and then the sqlTag.
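The string assembly in createTable can be reproduced in isolation. This is a sketch with a hypothetical `col` type and a hard-coded double-quote quoting function (GORM delegates quoting to the dialect); table options from getTableOptions are omitted. It keeps the original `"(%v %v)"` format, which is why a space appears before the closing parenthesis or the PRIMARY KEY clause, as in the log shown below.

```go
package main

import (
	"fmt"
	"strings"
)

// col is a hypothetical, simplified column description.
type col struct {
	DBName       string
	SQLType      string // result of Dialect().DataTypeOf (assumed)
	IsPrimaryKey bool
}

// quote wraps an identifier in double quotes (sqlite/postgres style).
func quote(s string) string { return `"` + s + `"` }

// createTableSQL follows the shape of createTable: collect `"name" type` tags,
// collect primary keys, and only append a PRIMARY KEY (...) clause when no
// column type already contains "primary key".
func createTableSQL(table string, cols []col) string {
	var tags, primaryKeys []string
	primaryKeyInColumnType := false
	for _, c := range cols {
		if strings.Contains(strings.ToLower(c.SQLType), "primary key") {
			primaryKeyInColumnType = true
		}
		tags = append(tags, quote(c.DBName)+" "+c.SQLType)
		if c.IsPrimaryKey {
			primaryKeys = append(primaryKeys, quote(c.DBName))
		}
	}
	var primaryKeyStr string
	if len(primaryKeys) > 0 && !primaryKeyInColumnType {
		primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
	}
	return fmt.Sprintf("CREATE TABLE %v (%v %v)", quote(table), strings.Join(tags, ","), primaryKeyStr)
}

func main() {
	fmt.Println(createTableSQL("products", []col{
		{DBName: "id", SQLType: "integer", IsPrimaryKey: true},
		{DBName: "code", SQLType: "varchar(255)"},
	}))
	// CREATE TABLE "products" ("id" integer,"code" varchar(255) , PRIMARY KEY ("id"))
}
```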
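Primary keys are also handled here, but this part is a bit of a trap, because
sqlTag := scope.Dialect().DataTypeOf(field)
depends entirely on how each database implements DataTypeOf.
Issue 2270 reports a table ending up with more than one primary key, using the following model definition with sqlite3: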
type Permission struct {
ID int64 `gorm:"AUTO_INCREMENT;column:id;primary_key"`
Name string `gorm:"column:name;type:varchar;unique;not null"`
Idx int64 `gorm:"AUTO_INCREMENT"`
}
Although this model definition specifies only one primary_key, Idx also became a primary key:
[2019-01-19 19:40:30] table "permission" has more than one primary key
[2019-01-19 19:40:30] [0.14ms] CREATE TABLE "permission" ("id" integer primary key autoincrement,"name" varchar NOT NULL UNIQUE,"idx" integer primary key autoincrement )
[0 rows affected or returned ]
There is only one reason: Idx uses the AUTO_INCREMENT option, and in the sqlite3 implementation of DataTypeOf:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
}
The AUTO_INCREMENT option is what puts primary key into the returned column type.
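The effect on the Permission model can be reproduced with a small simulation of the reflect.Int64 branch above. This is a sketch under assumptions: `tagField` is a hypothetical type, and `canAutoIncrement` is our approximation of fieldCanAutoIncrement (an explicit AUTO_INCREMENT tag wins, otherwise primary keys qualify). A bare `AUTO_INCREMENT` tag is modeled with the value "AUTO_INCREMENT", matching the TagSettingsSet call in the quoted code.

```go
package main

import (
	"fmt"
	"strings"
)

// tagField is a hypothetical stand-in for gorm's StructField: the parsed
// tag settings plus the primary-key flag.
type tagField struct {
	Name         string
	Tags         map[string]string
	IsPrimaryKey bool
}

// canAutoIncrement approximates fieldCanAutoIncrement (assumption).
func canAutoIncrement(f tagField) bool {
	if v, ok := f.Tags["AUTO_INCREMENT"]; ok {
		return strings.ToLower(v) != "false"
	}
	return f.IsPrimaryKey
}

// sqliteInt64Type mimics the reflect.Int64 branch of the sqlite3 DataTypeOf.
func sqliteInt64Type(f tagField) string {
	if canAutoIncrement(f) {
		return "integer primary key autoincrement"
	}
	return "bigint"
}

// countPrimaryKeyColumns counts how many generated column types declare a
// primary key on their own; more than one triggers the SQLite error above.
func countPrimaryKeyColumns(fields []tagField) int {
	n := 0
	for _, f := range fields {
		if strings.Contains(sqliteInt64Type(f), "primary key") {
			n++
		}
	}
	return n
}

func main() {
	auto := map[string]string{"AUTO_INCREMENT": "AUTO_INCREMENT"}
	fields := []tagField{
		{Name: "id", Tags: auto, IsPrimaryKey: true},
		{Name: "idx", Tags: auto}, // not a primary key, but tagged AUTO_INCREMENT
	}
	for _, f := range fields {
		fmt.Printf("%s -> %s\n", f.Name, sqliteInt64Type(f))
	}
	fmt.Println("columns declaring primary key:", countPrimaryKeyColumns(fields)) // 2
}
```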
I suspect this is a bug, because the code later checks separately whether each field is a primary key and builds primaryKeyStr from that:
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
var primaryKeyStr string
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}
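In my view, sqlType should not carry any primary key information. The primary key could instead be set afterwards through primaryKeyStr. (One caveat: SQLite only accepts AUTOINCREMENT on a column declared INTEGER PRIMARY KEY, which is presumably why the sqlite3 dialect puts both into the column type.)
That concludes the discussion of primary keys.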
Both migration and table creation call createJoinTable, but since I have not studied how relationships are implemented yet, I'll skip it for now.
callbacks
Create, query, update, and delete all depend on the callbacks field of the DB struct:
// DB contains information for current db connection
type DB struct {
...
// global db
parent *DB
callbacks *Callback
dialect Dialect
singularTable bool
...
}
Here is the code of the Create method:
// Create insert the value into database
func (s *DB) Create(value interface{}) *DB {
scope := s.NewScope(value)
return scope.callCallbacks(s.parent.callbacks.creates).db
}
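It creates a new scope and calls its callCallbacks method with s.parent.callbacks.creates.
parent is also of type *DB (the global db, according to the field comment), which acts as a kind of inheritance.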
Digging further into callCallbacks:
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
defer func() {
if err := recover(); err != nil {
if db, ok := scope.db.db.(sqlTx); ok {
db.Rollback()
}
panic(err)
}
}()
for _, f := range funcs {
(*f)(scope)
if scope.skipLeft {
break
}
}
return scope
}
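It uses the defer + recover pattern, which was covered earlier: if a callback panics and the underlying connection is a transaction (sqlTx), the transaction is rolled back before the panic is re-raised.
The argument to callCallbacks is a slice of function pointers, which are called one by one unless scope.skipLeft becomes true.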
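The control flow of callCallbacks can be sketched without a database. Here `miniScope` and `step` are hypothetical, and the "rollback" is just a flag, standing in for the real Rollback() call on the transaction.

```go
package main

import "fmt"

// miniScope is a hypothetical stand-in for gorm's Scope.
type miniScope struct {
	skipLeft   bool
	trace      []string
	rolledBack bool
}

// runCallbacks mirrors callCallbacks: call each function in order, stop as
// soon as one sets skipLeft, and if any of them panics, "roll back" (here:
// just set a flag) before re-panicking.
func runCallbacks(s *miniScope, funcs []*func(*miniScope)) *miniScope {
	defer func() {
		if err := recover(); err != nil {
			s.rolledBack = true // gorm calls Rollback() on the sqlTx at this point
			panic(err)
		}
	}()
	for _, f := range funcs {
		(*f)(s)
		if s.skipLeft {
			break
		}
	}
	return s
}

// step builds a named callback that records itself and optionally sets skipLeft.
func step(name string, skip bool) *func(*miniScope) {
	f := func(s *miniScope) {
		s.trace = append(s.trace, name)
		s.skipLeft = skip
	}
	return &f
}

func main() {
	s := runCallbacks(&miniScope{}, []*func(*miniScope){
		step("a", false), step("b", true), step("c", false),
	})
	fmt.Println(s.trace) // [a b]
}
```

Re-panicking after the rollback keeps the failure visible to the caller instead of silently swallowing it.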
Having seen how callbacks are invoked, let's look at what Callback actually is.
// Callback is a struct that contains all CRUD callbacks
// Field `creates` contains callbacks will be call when creating object
// Field `updates` contains callbacks will be call when updating object
// Field `deletes` contains callbacks will be call when deleting object
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
logger logger
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
processors []*CallbackProcessor
}
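Callback holds several slices of functions for creating, updating, deleting, and querying. The comments explain them clearly.
Note CallbackProcessor, which is used to generate all of the callbacks in order.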
// CallbackProcessor contains callback informations
type CallbackProcessor struct {
logger logger
name string // current callback's name
before string // register current callback before a callback
after string // register current callback after a callback
replace bool // replace callbacks with same name
remove bool // delete callbacks with same name
kind string // callback type: create, update, delete, query, row_query
processor *func(scope *Scope) // callback handler
parent *Callback
}
// Create could be used to register callbacks for creating object
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
// // business logic
// ...
//
// // set error if some thing wrong happened, will rollback the creating
// scope.Err(errors.New("error"))
// })
func (c *Callback) Create() *CallbackProcessor {
return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
}
// Update could be used to register callbacks for updating object, refer `Create` for usage
func (c *Callback) Update() *CallbackProcessor {
return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
}
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
func (c *Callback) Delete() *CallbackProcessor {
return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
}
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
// Refer `Create` for usage
func (c *Callback) Query() *CallbackProcessor {
return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
}
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callback) RowQuery() *CallbackProcessor {
return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
}
Callback has one method for each kind of CallbackProcessor it can create.
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
cp.after = callbackName
return cp
}
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
cp.before = callbackName
return cp
}
After and Before simply set the corresponding fields on the CallbackProcessor, which are later used to work out the order in which the callbacks run.
db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(scope *gorm.Scope) {
// business logic
// ...
// set an error if something goes wrong; this rolls back the creation
scope.Err(errors.New("error"))
})
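That is the example from the comment (with the scope parameter named so the body compiles). Now let's look at the Register method.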
// Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName))
cp.before = "gorm:row_query"
}
}
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
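It mainly sets cp's processor field, appends cp to cp.parent.processors, and then calls cp.parent.reorder() to reorder everything.
For a row_query callback registered without Before() or After(), it also defaults to Before("gorm:row_query").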
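The fluent builder behind db.Callback().Create().After(...).Register(...) can be sketched with trimmed-down, hypothetical types. The handler function and the before/after sorting are deliberately left out, so `regroup` only buckets names by kind in registration order.

```go
package main

import "fmt"

// processor and registry are hypothetical, trimmed-down versions of gorm's
// CallbackProcessor and Callback; the handler function is omitted.
type processor struct {
	name, before, after, kind string
	parent                    *registry
}

type registry struct {
	processors []*processor
	byKind     map[string][]string // kind -> names, registration order only
}

func (r *registry) Create() *processor { return &processor{kind: "create", parent: r} }
func (r *registry) Query() *processor  { return &processor{kind: "query", parent: r} }

// After and Before only record the ordering constraint for later.
func (p *processor) After(name string) *processor  { p.after = name; return p }
func (p *processor) Before(name string) *processor { p.before = name; return p }

// Register appends the processor to its parent and regroups by kind, like
// gorm's Register + reorder, minus the before/after sorting step.
func (p *processor) Register(name string) {
	p.name = name
	p.parent.processors = append(p.parent.processors, p)
	p.parent.regroup()
}

func (r *registry) regroup() {
	r.byKind = map[string][]string{}
	for _, p := range r.processors {
		if p.name != "" {
			r.byKind[p.kind] = append(r.byKind[p.kind], p.name)
		}
	}
}

func main() {
	r := &registry{}
	r.Create().Register("gorm:create")
	r.Query().Register("gorm:query")
	r.Create().After("gorm:create").Register("plugin:after_create")
	fmt.Println(r.byKind["create"]) // [gorm:create plugin:after_create]
	fmt.Println(r.byKind["query"])  // [gorm:query]
}
```

Each call to Create() returns a fresh processor pointing back at the shared parent, which is what lets the chained calls accumulate configuration before Register commits it.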
Where there is a registration method, there is of course a matching removal method:
// Remove a registered callback
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.remove = true
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
It sets the remove field to true and then reorders.
Replacing works in a similar way:
// Replace a registered callback with new callback
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
// scope.SetColumn("Created", now)
// scope.SetColumn("Updated", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.replace = true
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
Now let's look at how the reordering is done:
// reorder all registered processors, and reset CRUD callbacks
func (c *Callback) reorder() {
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
for _, processor := range c.processors {
if processor.name != "" {
switch processor.kind {
case "create":
creates = append(creates, processor)
case "update":
updates = append(updates, processor)
case "delete":
deletes = append(deletes, processor)
case "query":
queries = append(queries, processor)
case "row_query":
rowQueries = append(rowQueries, processor)
}
}
}
c.creates = sortProcessors(creates)
c.updates = sortProcessors(updates)
c.deletes = sortProcessors(deletes)
c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries)
}
The first half only groups the processors by kind. The real work is in sortProcessors:
// sortProcessors sort callback processors based on its before, after, remove, replace
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
var (
allNames, sortedNames []string
sortCallbackProcessor func(c *CallbackProcessor)
)
for _, cp := range cps {
// show warning message the callback name already exists
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()))
}
allNames = append(allNames, cp.name)
}
sortCallbackProcessor = func(c *CallbackProcessor) {
if getRIndex(sortedNames, c.name) == -1 { // if not sorted
if c.before != "" { // if defined before callback
if index := getRIndex(sortedNames, c.before); index != -1 {
// if before callback already sorted, append current callback just after it
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getRIndex(allNames, c.before); index != -1 {
// if before callback exists but haven't sorted, append current callback to last
sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index])
}
}
if c.after != "" { // if defined after callback
if index := getRIndex(sortedNames, c.after); index != -1 {
// if after callback already sorted, append current callback just before it
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getRIndex(allNames, c.after); index != -1 {
// if after callback exists but haven't sorted
cp := cps[index]
// set after callback's before callback to current callback
if cp.before == "" {
cp.before = c.name
}
sortCallbackProcessor(cp)
}
}
// if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, c.name) == -1 {
sortedNames = append(sortedNames, c.name)
}
}
}
for _, cp := range cps {
sortCallbackProcessor(cp)
}
var sortedFuncs []*func(scope *Scope)
for _, name := range sortedNames {
if index := getRIndex(allNames, name); !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor)
}
}
return sortedFuncs
}
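It first collects the names of all cps, printing a warning when a name is duplicated (unless the duplicate is a replace or remove). sortedNames will hold the names in their final order.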
// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
return i
}
}
return -1
}
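getRIndex returns the rightmost index of str in the slice, or -1 if it is absent. Because it searches from the right, when several processors share a name, the most recently registered one (for example a Replace or Remove) is the one that counts.
Now let's see what the sortCallbackProcessor function actually does.
It has two conditional parts. First: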
if c.before != "" { // if defined before callback
if index := getRIndex(sortedNames, c.before); index != -1 {
// if before callback already sorted, append current callback just after it
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getRIndex(allNames, c.before); index != -1 {
// if before callback exists but haven't sorted, append current callback to last
sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index])
}
}
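There are two cases. If the before callback has already been sorted, the current name is inserted at its index, which places it directly in front of the before callback. (The source comment says "just after it", but the code puts c.name ahead of c.before, which is what Before() is supposed to mean.)
If the before callback exists but has not been sorted yet, the current name is simply appended to the end of sortedNames, and then sortCallbackProcessor(cps[index]) is called recursively to sort the before callback itself, which in the simple case lands after the current one.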
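The nested append used in both branches is the usual Go idiom for inserting into a slice. A small sketch (the helper name `insertAt` is ours, not GORM's) shows that inserting at `index` lands in front of the element, while inserting at `index+1` lands behind it:

```go
package main

import "fmt"

// insertAt returns s with v inserted at index i (0 <= i <= len(s)), using the
// same nested-append idiom as sortProcessors. The inner append copies s[i:]
// into a fresh slice first, so the outer append cannot overwrite elements it
// still needs. Note that it may reuse s's backing array.
func insertAt(s []string, i int, v string) []string {
	return append(s[:i], append([]string{v}, s[i:]...)...)
}

func main() {
	sorted := []string{"gorm:begin", "gorm:create", "gorm:commit"}
	idx := 1 // index of "gorm:create"

	// Copy first, because insertAt may write into the original backing array.
	before := insertAt(append([]string(nil), sorted...), idx, "plugin:before")
	after := insertAt(append([]string(nil), sorted...), idx+1, "plugin:after")

	fmt.Println(before) // [gorm:begin plugin:before gorm:create gorm:commit]
	fmt.Println(after)  // [gorm:begin gorm:create plugin:after gorm:commit]
}
```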
Now the second part:
if c.after != "" { // if defined after callback
if index := getRIndex(sortedNames, c.after); index != -1 {
// if after callback already sorted, append current callback just before it
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getRIndex(allNames, c.after); index != -1 {
// if after callback exists but haven't sorted
cp := cps[index]
// set after callback's before callback to current callback
if cp.before == "" {
cp.before = c.name
}
sortCallbackProcessor(cp)
}
}
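The logic mirrors the first part. If the after callback has already been sorted, the current name is inserted at index+1, i.e. directly behind it; the source comment's "just before it" has the direction reversed.
If the after callback exists but has not been sorted yet, its before field is set to the current callback (only if it was empty), and then sortCallbackProcessor(cp) recurses into the after callback, which therefore gets placed ahead of the current one.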
// if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, c.name) == -1 {
sortedNames = append(sortedNames, c.name)
}
If the current callback still has not been placed, it is appended to the end. That is all sortCallbackProcessor does.
for _, cp := range cps {
sortCallbackProcessor(cp)
}
The sorting starts here. Once it finishes, sortedNames is complete:
var sortedFuncs []*func(scope *Scope)
for _, name := range sortedNames {
if index := getRIndex(allNames, name); !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor)
}
}
return sortedFuncs
Every callback whose most recently registered processor is not marked remove is appended to sortedFuncs in that order.
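To experiment with the ordering rules, sortProcessors can be transcribed into a standalone function. This is a sketch: `cbProc` is a hypothetical type keeping only the fields the algorithm reads, the logger and duplicate warning are dropped, and names are returned instead of handler pointers. The control flow otherwise follows the quoted v1.9.11 code, including the mutation of the after callback's before field.

```go
package main

import "fmt"

// cbProc keeps only the fields sortProcessors reads (hypothetical type).
type cbProc struct {
	name, before, after string
	remove              bool
}

func getRIndex(strs []string, str string) int {
	for i := len(strs) - 1; i >= 0; i-- {
		if strs[i] == str {
			return i
		}
	}
	return -1
}

// sortNames transcribes sortProcessors, returning sorted names.
func sortNames(cps []*cbProc) []string {
	var allNames, sortedNames []string
	for _, cp := range cps {
		allNames = append(allNames, cp.name)
	}
	var sortCP func(c *cbProc)
	sortCP = func(c *cbProc) {
		if getRIndex(sortedNames, c.name) != -1 {
			return // already sorted
		}
		if c.before != "" {
			if i := getRIndex(sortedNames, c.before); i != -1 {
				// place c at i, directly in front of its "before" target
				sortedNames = append(sortedNames[:i], append([]string{c.name}, sortedNames[i:]...)...)
			} else if i := getRIndex(allNames, c.before); i != -1 {
				sortedNames = append(sortedNames, c.name)
				sortCP(cps[i])
			}
		}
		if c.after != "" {
			if i := getRIndex(sortedNames, c.after); i != -1 {
				// place c at i+1, directly behind its "after" target
				sortedNames = append(sortedNames[:i+1], append([]string{c.name}, sortedNames[i+1:]...)...)
			} else if i := getRIndex(allNames, c.after); i != -1 {
				if cps[i].before == "" {
					cps[i].before = c.name
				}
				sortCP(cps[i])
			}
		}
		if getRIndex(sortedNames, c.name) == -1 {
			sortedNames = append(sortedNames, c.name)
		}
	}
	for _, cp := range cps {
		sortCP(cp)
	}
	var out []string
	for _, name := range sortedNames {
		// the rightmost processor with this name decides whether it is removed
		if i := getRIndex(allNames, name); !cps[i].remove {
			out = append(out, name)
		}
	}
	return out
}

func main() {
	cps := []*cbProc{
		{name: "gorm:begin_transaction"},
		{name: "gorm:create"},
		{name: "gorm:commit_or_rollback_transaction"},
		{name: "plugin:after_create", after: "gorm:create"},
		{name: "plugin:first", before: "gorm:begin_transaction"},
		{name: "gorm:commit_or_rollback_transaction", remove: true},
	}
	fmt.Println(sortNames(cps))
	// [plugin:first gorm:begin_transaction gorm:create plugin:after_create]
}
```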
Finally, there is a Get method for fetching a registered callback:
// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind {
if p.remove {
callback = nil
} else {
callback = *p.processor
}
}
}
return
}
By now it should be clear how callbacks are registered and sorted, and how a single callback can be fetched by name.
The Actual Registration Flow
So far we have only covered the definitions. Now let's see where the callbacks are actually registered.
When the DB is initialized, i.e. in the Open method, the following statement is executed:
db = &DB{
db: dbSQL,
logger: defaultLogger,
callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL),
}
DefaultCallback is defined as follows:
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
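At first this was a little alarming, since it is just an empty definition; it must be initialized somewhere. A glance at the directory listing makes it clear.
callback_create.go registers the create callbacks in an init function, which Go runs automatically when the package is initialized: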
// Define callbacks for creating
func init() {
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
DefaultCallback.Create().Register("gorm:create", createCallback)
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
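With the documentation at hand, let's see how BeforeSave and BeforeCreate are implemented.
When you define a model, you can implement methods such as BeforeSave and BeforeCreate on it, and these methods are called at the appropriate time.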
func (u *User) BeforeSave() (err error) {
if !u.IsValid() {
err = errors.New("can't save invalid data")
}
return
}
func (u *User) AfterCreate(scope *gorm.Scope) (err error) {
if u.ID == 1 {
scope.DB().Model(u).Update("role", "admin")
}
return
}
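The above is the example from the official documentation. Earlier we saw in a comment how to register a callback by hand, e.g. DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback), but how do the methods defined on the model get called?
Let's look at the beforeCreateCallback function: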
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
}
if !scope.HasError() {
scope.CallMethod("BeforeCreate")
}
}
So it is done through scope.CallMethod: pass it a method name and it calls that method (on each element, if the value is a slice).
// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
func (scope *Scope) CallMethod(methodName string) {
if scope.Value == nil {
return
}
if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
for i := 0; i < indirectScopeValue.Len(); i++ {
scope.callMethod(methodName, indirectScopeValue.Index(i))
}
} else {
scope.callMethod(methodName, indirectScopeValue)
}
}
Going one level further, here is the code of callMethod:
func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
// Only get address from non-pointer
if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr {
reflectValue = reflectValue.Addr()
}
if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
switch method := methodValue.Interface().(type) {
case func():
method()
case func(*Scope):
method(scope)
case func(*DB):
newDB := scope.NewDB()
method(newDB)
scope.Err(newDB.Error)
case func() error:
scope.Err(method())
case func(*Scope) error:
scope.Err(method(scope))
case func(*DB) error:
newDB := scope.NewDB()
scope.Err(method(newDB))
scope.Err(newDB.Error)
default:
scope.Err(fmt.Errorf("unsupported function %v", methodName))
}
}
}
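All this flexibility is built on reflection. The key line is methodValue := reflectValue.MethodByName(methodName). Note that an addressable non-pointer value is first turned into a pointer with Addr(), so that methods with pointer receivers, such as func (u *User) BeforeSave(), can be found.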
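The lookup part of CallMethod/callMethod can be sketched with the standard reflect package alone. `User` here is a hypothetical model, reflect.Indirect (one level of dereferencing) stands in for scope.IndirectValue, and only the func() signature is dispatched. The second call in main shows why taking the address matters: a struct passed by value is not addressable, so its pointer-receiver method is never found.

```go
package main

import (
	"fmt"
	"reflect"
)

// User is a hypothetical model; BeforeSave has a pointer receiver, so it is
// only in the method set of *User, not of User.
type User struct {
	Name  string
	Saved bool
}

func (u *User) BeforeSave() { u.Saved = true }

// callMethodByName follows the shape of CallMethod: dereference the value
// once and, if it is a slice, call the method on each element.
func callMethodByName(value interface{}, name string) {
	v := reflect.Indirect(reflect.ValueOf(value))
	if v.Kind() == reflect.Slice {
		for i := 0; i < v.Len(); i++ {
			callOne(v.Index(i), name)
		}
		return
	}
	callOne(v, name)
}

// callOne follows callMethod: take the address of an addressable non-pointer
// value so pointer-receiver methods are visible, then look the method up by
// name. Only the func() signature is handled here.
func callOne(v reflect.Value, name string) {
	if v.CanAddr() && v.Kind() != reflect.Ptr {
		v = v.Addr()
	}
	if m := v.MethodByName(name); m.IsValid() {
		if f, ok := m.Interface().(func()); ok {
			f()
		}
	}
}

func main() {
	users := []User{{Name: "a"}, {Name: "b"}}
	callMethodByName(&users, "BeforeSave")
	fmt.Println(users[0].Saved, users[1].Saved) // true true

	u := User{Name: "c"}
	callMethodByName(u, "BeforeSave") // passed by value: not addressable, method not found
	fmt.Println(u.Saved)              // false
}
```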
The switch shows that the method may have several different signatures:
switch method := methodValue.Interface().(type) {
case func():
method()
case func(*Scope):
method(scope)
case func(*DB):
newDB := scope.NewDB()
method(newDB)
scope.Err(newDB.Error)
case func() error:
scope.Err(method())
case func(*Scope) error:
scope.Err(method(scope))
case func(*DB) error:
newDB := scope.NewDB()
scope.Err(method(newDB))
scope.Err(newDB.Error)
default:
scope.Err(fmt.Errorf("unsupported function %v", methodName))
}
So all of this can really be read as one large example of how to use reflect.
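The signature dispatch can be sketched in the same way. The types here are hypothetical: `ctx` stands in for *gorm.Scope, and its `setErr` ignores nil errors, roughly like scope.Err (which actually records errors on the DB). `Order` is a model with hooks of different shapes, and only a subset of the signatures from the real switch is covered.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// ctx is a hypothetical stand-in for *gorm.Scope.
type ctx struct{ err error }

// setErr records non-nil errors only.
func (c *ctx) setErr(err error) {
	if err != nil {
		c.err = err
	}
}

// Order is a hypothetical model with hooks of different signatures.
type Order struct{ Total int }

func (o *Order) BeforeCreate(c *ctx) error {
	if o.Total < 0 {
		return errors.New("negative total")
	}
	return nil
}

func (o *Order) AfterCreate() {}

// Validate has a signature the dispatcher does not support.
func (o *Order) Validate(n int) bool { return n > 0 }

// dispatch mirrors the type switch in callMethod for a subset of signatures.
func dispatch(c *ctx, target interface{}, name string) {
	m := reflect.ValueOf(target).MethodByName(name)
	if !m.IsValid() {
		return
	}
	switch method := m.Interface().(type) {
	case func():
		method()
	case func(*ctx):
		method(c)
	case func() error:
		c.setErr(method())
	case func(*ctx) error:
		c.setErr(method(c))
	default:
		c.setErr(fmt.Errorf("unsupported function %v", name))
	}
}

func main() {
	c := &ctx{}
	dispatch(c, &Order{Total: -1}, "BeforeCreate")
	fmt.Println(c.err) // negative total

	c = &ctx{}
	dispatch(c, &Order{}, "Validate")
	fmt.Println(c.err) // unsupported function Validate
}
```

Switching on the dynamic type of m.Interface() turns the reflected method back into an ordinary typed function, so each case can call it directly without reflect.Value.Call.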
createCallback
I'll skip the other hooks and look in detail at what happens when a single record is inserted:
// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
if !scope.HasError() {
defer scope.trace(scope.db.nowFunc())
var (
columns, placeholders []string
blankColumnsWithDefaultValue []string
)
for _, field := range scope.Fields() {
if scope.changeableField(field) {
if field.IsNormal && !field.IsIgnored {
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else if !field.IsPrimaryKey || !field.IsBlank {
columns = append(columns, scope.Quote(field.DBName))
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
}
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _, foreignKey := range field.Relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
columns = append(columns, scope.Quote(foreignField.DBName))
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
}
}
}
}
}
var (
returningColumn = "*"
quotedTableName = scope.QuotedTableName()
primaryField = scope.PrimaryField()
extraOption string
insertModifier string
)
if str, ok := scope.Get("gorm:insert_option"); ok {
extraOption = fmt.Sprint(str)
}
if str, ok := scope.Get("gorm:insert_modifier"); ok {
insertModifier = strings.ToUpper(fmt.Sprint(str))
if insertModifier == "INTO" {
insertModifier = ""
}
}
if primaryField != nil {
returningColumn = scope.Quote(primaryField.DBName)
}
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v %v%v%v",
addExtraSpaceIfExist(insertModifier),
quotedTableName,
scope.Dialect().DefaultValueStr(),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v (%v) VALUES (%v)%v%v",
addExtraSpaceIfExist(insertModifier),
scope.QuotedTableName(),
strings.Join(columns, ","),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}
// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
} else {
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
}
}
}
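First, the inner for loop iterates over all fields and fills the three slices declared at the top: columns, placeholders, and blankColumnsWithDefaultValue. Blank fields that have a default value are left out of the statement, and so is a blank primary key, so that the database can fill them in.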
for _, field := range scope.Fields() {
if scope.changeableField(field) {
if field.IsNormal && !field.IsIgnored {
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else if !field.IsPrimaryKey || !field.IsBlank {
columns = append(columns, scope.Quote(field.DBName))
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
}
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _, foreignKey := range field.Relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
columns = append(columns, scope.Quote(foreignField.DBName))
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
}
}
}
}
}
Next, some information is gathered and set up:
var (
returningColumn = "*"
quotedTableName = scope.QuotedTableName()
primaryField = scope.PrimaryField()
extraOption string
insertModifier string
)
Once all of this is gathered, the INSERT statement is built:
if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v %v%v%v",
addExtraSpaceIfExist(insertModifier),
quotedTableName,
scope.Dialect().DefaultValueStr(),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v (%v) VALUES (%v)%v%v",
addExtraSpaceIfExist(insertModifier),
scope.QuotedTableName(),
strings.Join(columns, ","),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}
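The column selection and the two statement shapes can be sketched as a pure function. This is a simplification under assumptions: `insField` is a hypothetical type, `?` is used as the placeholder (dialects may use other bind-variable styles), `DEFAULT VALUES` is assumed for the dialect's DefaultValueStr, `returningSuffix` stands in for LastInsertIDReturningSuffix, and belongs_to foreign keys, insert modifiers, and insert options are omitted.

```go
package main

import (
	"fmt"
	"strings"
)

// insField is a hypothetical, simplified field description.
type insField struct {
	DBName          string
	Value           interface{}
	IsBlank         bool
	HasDefaultValue bool
	IsPrimaryKey    bool
}

// buildInsert sketches createCallback's column selection and statement shape:
// blank fields with a default are skipped (the database fills them in), a
// blank primary key is skipped (the database generates it), and every other
// field becomes a column plus a "?" placeholder.
func buildInsert(table string, fields []insField, returningSuffix string) (string, []interface{}) {
	var columns, placeholders []string
	var vars []interface{}
	for _, f := range fields {
		if f.IsBlank && f.HasDefaultValue {
			continue
		}
		if !f.IsPrimaryKey || !f.IsBlank {
			columns = append(columns, `"`+f.DBName+`"`)
			placeholders = append(placeholders, "?")
			vars = append(vars, f.Value)
		}
	}
	suffix := ""
	if returningSuffix != "" {
		suffix = " " + returningSuffix
	}
	quoted := `"` + table + `"`
	if len(columns) == 0 {
		return fmt.Sprintf("INSERT INTO %s DEFAULT VALUES%s", quoted, suffix), nil
	}
	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)%s", quoted,
		strings.Join(columns, ","), strings.Join(placeholders, ","), suffix), vars
}

func main() {
	fields := []insField{
		{DBName: "id", Value: uint(0), IsBlank: true, IsPrimaryKey: true},
		{DBName: "code", Value: "L1212"},
		{DBName: "price", Value: uint(1000)},
		{DBName: "status", Value: "", IsBlank: true, HasDefaultValue: true},
	}
	query, vars := buildInsert("products", fields, "")
	fmt.Println(query) // INSERT INTO "products" ("code","price") VALUES (?,?)
	fmt.Println(vars)  // [L1212 1000]
}
```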
Finally, the SQL statement is executed:
// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
} else {
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
}
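The first condition depends on lastInsertIDReturningSuffix, and only PostgreSQL returns a non-empty string for it. When the suffix is empty (or there is no primary field), the statement runs through Exec and the new ID is read from result.LastInsertId(). Otherwise it runs through QueryRow(...).Scan, so that the RETURNING clause fills in the primary key directly.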
var userid int
err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age)
VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid)
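PostgreSQL does not support the LastInsertId() method; to get the ID you have to query it as shown above.
See PostgreSQL Queries.
That is why the execution path differs.
With that, we have finished the createCallback callback and seen how data is inserted.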
Summary
In this part we looked at how tables are created and migrated, how hook functions are registered and sorted, and when they are called.