go源码走读-gorm
目录
1.gorm使用实例
我们常使用懒加载和惰加载结合完成单例模式
import (
"sync"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
var (
db *gorm.DB
doOnce sync.Once
dsn string = "username:psd@(ip:port)/database?database?timeout=5000ms&readTimeout=5000ms&writeTimeout=5000ms&charset=utf8mb4&parseTime=true&loc=Local"
)
func GetDB() *gorm.DB {
var err error
doOnce.Do(func() {
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
panic(err)
}
// db.Table()
})
return db
}
本节接下来不真实连接数据库,而是用sql-mock来mock数据
2.核心DB类
即gorm.DB
首先需要理解会话这个概念,gorm.DB是该库定义的数据库类,所有执行的数据的操作都与这个类有关,以链式调用方式展开(如条件查询)
// Get first matched record
db.Where("name = ?", "jinzhu").First(&user)
每当链式调用后,就会新生成新的DB对象,该对象中存储了一些当前请求特有的状态信息,我们把这种对象叫做“会话”,即记录了执行过程上下文的对象
// gorm 中定义的数据库类
// 所有 orm 的思想
type DB struct {
// 配置
*Config
// 错误
Error error
// 影响的行数
RowsAffected int64
// 会话状态信息
Statement *Statement
// 克隆次数
clone int
}
- Statement:一次会话的状态信息,比如请求和响应信息
- clone: 会话被克隆的次数. 倘若 clone = 1,代表是始祖 DB 实例;倘若 clone > 1,代表是从始祖 DB 克隆出来的会话
- Error:一次会话执行过程中遇到的错误,一个信息里可能包含多个错误
func (db *DB) AddError(err error) error {
if err != nil {
// ...
if db.Error == nil {
db.Error = err
} else {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
}
return db.Error
}
我们看到使用fmt.Errorf和%w配合实现error wrapping(错误的拼接)
2.1 DB的克隆
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config, Error: db.Error}
// 倘若是首次对 db 进行 clone,则需要构造出一个新的 statement 实例
if db.clone == 1 {
// clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
}
// 倘若已经 db clone 过了,则还需要 clone 原先的 statement
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
return tx
}
return db
}
主要通过对clone字段来判断:
- clone=1,就克隆出一个新的会话
- clone>2,就从始祖DB上克隆
2.2Statement
// Statement statement
type Statement struct {
// 数据库实例
*DB
// ...
// 表名
Table string
// 操作的 po 模型
Model interface{}
// ...
// 处理结果反序列化到此处
Dest interface{}
// ...
// 各种条件语句
Clauses map[string]clause.Clause
// ...
// 是否启用 distinct 模式
Distinct bool
// select 语句
Selects []string // selected columns
// omit 语句
Omits []string // omit columns
// join
Joins []join
// ...
// 连接池,通常情况下是 database/sql 库下的 *DB 类型. 在 prepare 模式为 gorm.PreparedStmtDB
ConnPool ConnPool
// 操作表的概要信息
Schema *schema.Schema
// 上下文,请求生命周期控制管理
Context context.Context
// 在未查找到数据记录时,是否抛出 recordNotFound 错误
RaiseErrorOnNotFound bool
// ...
// 执行的 sql,调用 state.Build 方法后,会将 sql 各部分文本依次追加到其中. 具体可见 2.5 小节
SQL strings.Builder
// 存储的变量
Vars []interface{}
// ...
}
因为要记录上下文,所以字段成员比较多,慢慢来看
2.3 po 模型
orm的思想就是将数据库表映射成一个数据模型(类/结构体),我们将该模型称为po模型(persist object 持久化数据模型),下面就是一个数据表的po类
type Reward struct {
gorm.Model
Amount sql.NullInt64 `gorm:"column:amount"`
Type string `gorm:"not null"`
UserID int64 `gorm:"not null"`
}
func (r Reward) TableName() string {
return "reward"
}
如果po模型声明了TableName方法,则隐式实现了gorm.Tabler接口
type Tabler interface {
TableName() string
}
解析表时先尝试转为该接口,失败则直接用po模型的结构体名(会经过一定规则转化)当表名
那么问题来了,gorm.Statement.Model也是po模型,那这两个有啥区别?我们先暂时埋下一个疑问
2.4查询流程
如下测试代码
func TestQuery(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
gdb, err := gorm.Open(mysql.New(mysql.Config{
SkipInitializeWithVersion: true,
Conn: db,
}), &gorm.Config{})
require.NoError(t, err)
rows := sqlmock.NewRows([]string{"id"}).AddRow(2)
mock.ExpectQuery("SELECT *").WillReturnRows(rows)
type Name struct {
Id string
}
var name Name
res := gdb.First(&name)
t.Log(res.Error)
t.Log(res.RowsAffected)
t.Log(name.Id)
}
我们查看first源码
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB){
tx = db.Limit(1).Order(xxxx)
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
}
- 先是进行limit和order的链式调用(因为first只返回一条)
- 然后设置会话属性,最后调用callbacks执行查询方法
2.4.1 执行器processor
callbacks是db的内嵌字段config的成员
type Config struct{
callbacks *callbacks
}
type callbacks struct {
processors map[string]*processor
}
type processor struct {
db *DB
Clauses []string
fns []func(*DB)
callbacks []*callback
}
它的唯一成员就是gorm框架执行curd操作逻辑时用到的执行器processor,针对curd操作的处理函数会以list的形式聚合在对应类型的processor的fns字段中
// 对应存储了 crud 等各类操作对应的执行器 processor
// query -> query processor
// create -> create processor
// update -> update processor
// delete -> delete processor
也就是说调用callbacks.Query查询方法实际就是执行query processor的fns函数成员
func (cs *callbacks) Query() *processor {
return cs.processors["query"]
}
各类 processor 的初始化是通过 initializeCallbacks 方法完成,该方法是在gorm.Open中执行的
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
},
}
}
再来细看processor具体结构
type processor struct {
// 从属的 DB 实例
db *DB
// 拼接 sql 时的关键字顺序. 比如 query 类,固定为 SELECT,FROM,WHERE,GROUP BY, ORDER BY, LIMIT, FOR
Clauses []string
// 对应于 crud 类型的执行函数链
fns []func(*DB)
callbacks []*callback
}
对应的Execute方法
func (p *processor) Execute(db *DB) *DB {
// call scopes
var (
// ...
stmt = db.Statement
// ...
)
if len(stmt.BuildClauses) == 0 {
// 根据 crud 类型,对 buildClauses 进行复制,用于后续的 sql 拼接
stmt.BuildClauses = p.Clauses
// ...
}
// ...
// dest 和 model 相互赋值
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
// 解析 model,获取对应表的 schema 信息
if stmt.Model != nil {
// ...
}
// 处理 dest 信息,将其添加到 stmt 当中
if stmt.Dest != nil {
// ...
}
// 执行一系列的 callback 函数,其中最核心的 create/query/update/delete 操作都被包含在其中了. 还包括了一系列前、后处理函数,具体可见第 3 章
for _, f := range p.fns {
f(db)
}
//...
return db
}
- 获取会话
- 从会话中获取条件语句用于后续拼接sql语句
- 解析po模型到model中
- 处理dest信息,这里指我们传入fist的参数,最终会修改结果赋值
- 执行callback函数
我们注意这里的第三步解析po模型到model中
stmt.Parse(stmt.Model)
实际上就是再将po模型具体化,比如提取出表名,字段名,对应的数据库类型信息等保存到model中
2.4.2 条件 clause
一条执行 sql 中,各个部分都属于一个 clause,比如一条 SELECT * FROM reward WHERE id < 10 ORDER by id 的 SQL,其中就包含了 SELECT、FROM、WHERE 和 ORDER 四个 clause.
当使用方通过链式操作克隆 DB时,对应追加的状态信息就会生成一个新的 clause,追加到 statement 对应的 clauses 集合当中. 当请求实际执行时,会取出 clauses 集合,拼接生成完整的 sql 用于执行.
clause本身是一个抽象的interface
// Interface clause interface
type Interface interface {
// clause 名称
Name() string
// 生成对应的 sql 部分
Build(Builder)
// 和同类 clause 合并
MergeClause(*Clause)
}
不同的 clause 有不同的实现类,我们以 SELECT 为例进行展示:
type Select struct {
// 使用使用 distinct 模式
Distinct bool
// 是否 select 查询指定的列,如 select id,name
Columns []Column
Expression Expression
}
sql语句的拼接是通过调用statement.Build方法实现的,入参对应的是 crud 中某一类 processor 的 BuildClauses.
func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
if firstClauseWritten {
stmt.WriteByte(' ')
}
firstClauseWritten = true
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
b(c, stmt)
} else {
c.Build(stmt)
}
}
}
}
2.4.3 小结
综上我们小结first查询方法:
- 通过limit和order追加clause,追加条件过程如下 :
- 调用getInstance方法克隆出会话
- 调用addClause方法将条件追加到statement的clauses map中
- 设置statement.dest
- 获取query类型的执行器processor,调用execute方法执行其中的fns函数链
2.5 Query方法
既然已经知道了fns方法是最终调用方法,那么它是由谁注册的?
自然是一开始的驱动
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
我这里就是mysql驱动注册的,看open方法,我们以查询为例搜索注册的查询函数
func Open(dsn string) gorm.Dialector {
dsnConf, _ := mysql.ParseDSN(dsn)
return &Dialector{Config: &Config{DSN: dsn, DSNConfig: dsnConf}}
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error){
// ...完成 crud 类操作 callback 函数的注册
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
}
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
queryCallback.Clauses = config.QueryClauses
}
func (p *processor) Register(name string, fn func(*DB)) error {
return (&callback{processor: p}).Register(name, fn)
}
func (c *callback) Register(name string, fn func(*DB)) error {
c.name = name
c.handler = fn
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile()
}
顺藤摸瓜,我们找到了query函数
func Query(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if !db.DryRun && db.Error == nil {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
}
}
}
- 先根据clauses组装sql
- 完成sql查询类的执行,返回查询到的行数据rows
- 将结果反序列化到dest中
2.5.1 连接池connPool
connPool 字段,其含义是连接池,和数据库的交互操作都需要依赖它才得以执行. connPool 本身是个 interface,定义如下:
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
connPool 根据是否启用了 prepare 预处理模式,存在不同的实现类版本:
- 在普通模式下,connPool 的实现类为 database/sql 库下的 DB 类
- 在 prepare 模式下,connPool 实现类型为 gorm 中定义的 PreparedStmtDB 类
prepare是什么?直观点将就是缓存,mysql5.8之后都直接抛弃了,了解就好
// prepare 模式下的 connPool 实现类.
type PreparedStmtDB struct {
// 各 stmt 实例. 其中 key 为 sql 模板,stmt 是对封 database/sql 中 *Stmt 的封装
Stmts map[string]*Stmt
// ...
Mux *sync.RWMutex
// 内置的 ConnPool 字段通常为 database/sql 中的 *DB
ConnPool
}
Stmt 类是 gorm 框架对 database/sql 标准库下 Stmt 类的简单封装,两者区别并不大:
type Stmt struct {
// database/sql 标准库下的 statement
*sql.Stmt
// 是否处于事务
Transaction bool
// 标识当前 stmt 是否已初始化完成
prepared chan struct{}
prepareErr error
}
举一反三,剩下的更新/删除/插入也是如此原理
2.6 事务
db.Transaction
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// ...
} else {
// 开启事务
tx := db.Begin(opts...)
if tx.Error != nil {
return tx.Error
}
defer func() {
// 倘若发生错误或者 panic,则进行 rollback 回滚
if panicked || err != nil {
tx.Rollback()
}
}()
// 执行事务内的逻辑
if err = fc(tx); err == nil {
panicked = false
// 指定成功会进行 commit 操作
return tx.Commit().Error
}
}
panicked = false
return
}
- 调用Begin方法启动事务,克隆出一个带事务属性的会话tx
- 以tx为参数传入闭包函数执行,根据成功与否执行提交或回滚
2.7 了解Prepare
在 PreparedStmtDB.prepare 方法中,会通过加锁 double check 的方式,创建或复用 sql 模板对应的 stmt. 创建 stmt 的操作通过调用 conn.PrepareContext 方法完成.(通常此处的 conn 为 database/sql 库下的 sql.DB)
PreparedStmtDB.prepare 方法核心流程梳理如下:
• 加读锁,然后以 sql 模板为 key,尝试从 db.Stmts map 中获取 stmt 复用
• 倘若 stmt 不存在,则加写锁 double check
• 调用 conn.PrepareContext(...) 方法,创建新的 stmt,并存放到 map 中供后续复用
完整的代码和对应的注释展示如下:
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock()
// 以 sql 模板为 key,优先复用已有的 stmt
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
// 并发场景下,只允许有一个 goroutine 完成 stmt 的初始化操作
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
db.Mux.RUnlock()
// 加锁 double check,确认未完成 stmt 初始化则执行初始化操作
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
// 创建 stmt 实例,并添加到 stmts map 中
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
// 此时可以提前解锁是因为还通过 channel 保证了其他使用者会阻塞等待初始化操作完成
db.Mux.Unlock()
// 所有工作执行完之后会关闭 channel,唤醒其他阻塞等待使用 stmt 的 goroutine
defer close(cacheStmt.prepared)
// 调用 *sql.DB 的 prepareContext 方法,创建真正的 stmt
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()
return cacheStmt,nil
}
在 prepare 模式下,查询操作通过 PreparedStmtDB.QueryContext(...) 方法实现. 首先通过 PreparedStmtDB.prepare(...) 方法尝试复用 stmt,然后调用 stmt.QueryContext(...) 执行查询操作.同理执行流程也大差不多