Gorm的使用心得和一些常用扩展(二)

2019-07-19  本文已影响0人  海之方

上一篇文章,我分享了自己在新增和更新的场景下,自己使用gorm的一些心得和扩展。本文,我将分享一些在查询的方面的心得。

首先,我把查询按照涉及到的表的数量分为:

按照查询范围又可以分为:

在日常使用中,单表查询占据了多半的场景,把这部分的代码按照查询范围做一些封装,可以大大减少冗余的代码。

单表查询

于是,我仿照gorm API的风格,做了如下的封装:

ps:以下例子均以假定已定义user对象

查询一个

func (dw *DBExtension) GetOne(result interface{}, query interface{}, args ...interface{}) (found bool, err error) {
    var (
        tableNameAble TableNameAble
        ok            bool
    )

    if tableNameAble, ok = query.(TableNameAble); !ok {
        if tableNameAble, ok = result.(TableNameAble); !ok {
            return false, errors.New("neither the query nor result implement TableNameAble")
        }
    }

    err = dw.Table(tableNameAble.TableName()).Where(query, args...).First(result).Error

    if err == gorm.ErrRecordNotFound {
        dw.logger.LogInfoc("mysql", fmt.Sprintf("record not found for query %s, the query is %+v, args are %+v", tableNameAble.TableName(), query, args))
        return false, nil
    }

    if err != nil {
        dw.logger.LogErrorc("mysql", err, fmt.Sprintf("failed to query %s, the query is %+v, args are %+v", tableNameAble.TableName(), query, args))
        return false, err
    }

    return true, nil
}

这段值得说明的就是对查询不到数据时的处理,gorm是报了gorm.ErrRecordNotFound的error, 我是对这个错误做了特殊处理,用found这个boolean值表述这个特殊状态。

调用代码如下:

condition := User{Id:1}
result := User{}

if  found, err := dw.GetOne(&result, condition); !found {
    //not found
    if err != nil {
        // has error
        return err
    }
    
}

也可以这样写,更加灵活的指定的查询条件:

result := User{}

if  found, err := dw.GetOne(&result, "id = ?", 1); !found {
    //not found
    if err != nil {
        // has error
        return err
    }
    
}

两种写法执行的语句都是:

select * from test.user where id = 1

范围查询

针对四种范国查询,我做了如下封装:


func (dw *DBExtension) GetList(result interface{}, query interface{}, args ...interface{}) error {
    return dw.getListCore(result, "", 0, 0, query, args)
}

func (dw *DBExtension) GetOrderedList(result interface{}, order string, query interface{}, args ...interface{}) error {
    return dw.getListCore(result, order, 0, 0, query, args)
}

func (dw *DBExtension) GetFirstNRecords(result interface{}, order string, limit int, query interface{}, args ...interface{}) error {
    return dw.getListCore(result, order, limit, 0, query, args)
}

func (dw *DBExtension) GetPageRangeList(result interface{}, order string, limit, offset int, query interface{}, args ...interface{}) error {
    return dw.getListCore(result, order, limit, offset, query, args)
}

func (dw *DBExtension) getListCore(result interface{}, order string, limit, offset int, query interface{}, args []interface{}) error {
    var (
        tableNameAble TableNameAble
        ok            bool
    )

    if tableNameAble, ok = query.(TableNameAble); !ok {
        // type Result []*Item{}
        // result := &Result{}
        resultType := reflect.TypeOf(result)
        if resultType.Kind() != reflect.Ptr {
            return errors.New("result is not a pointer")
        }

        sliceType := resultType.Elem()
        if sliceType.Kind() != reflect.Slice {
            return errors.New("result doesn't point to a slice")
        }
        // *Item
        itemPtrType := sliceType.Elem()
        // Item
        itemType := itemPtrType.Elem()

        elemValue := reflect.New(itemType)
        elemValueType := reflect.TypeOf(elemValue)
        tableNameAbleType := reflect.TypeOf((*TableNameAble)(nil)).Elem()

        if elemValueType.Implements(tableNameAbleType) {
            return errors.New("neither the query nor result implement TableNameAble")
        }

        tableNameAble = elemValue.Interface().(TableNameAble)
    }

    db := dw.Table(tableNameAble.TableName()).Where(query, args...)
    if len(order) != 0 {
        db = db.Order(order)
    }

    if offset > 0 {
        db = db.Offset(offset)
    }

    if limit > 0 {
        db = db.Limit(limit)
    }

    if err := db.Find(result).Error; err != nil {
        dw.logger.LogErrorc("mysql", err, fmt.Sprintf("failed to query %s, query is %+v, args are %+v, order is %s, limit is %d", tableNameAble.TableName(), query, args, order, limit))
        return err
    }

    return nil
}

为了减少冗余的代码,通用的逻辑写在getListCore函数里,里面用到了一些golang反射的知识。

但只要记得golang的反射和其它语言的反射最大的不同,是golang的反射是基本值而不是类型的,一切就好理解了。

其中的一个小技巧是如何判断一个类型是否实现了某个接口,用到了指向nil的指针。

    elemValue := reflect.New(itemType)
    elemValueType := reflect.TypeOf(elemValue)
    tableNameAbleType := reflect.TypeOf((*TableNameAble)(nil)).Elem()

    if elemValueType.Implements(tableNameAbleType) {
        return errors.New("neither the query nor result implement TableNameAble")
    }

关于具体的使用,就不再一一举例子了,熟悉gorm api的同学可以一眼看出。

多表查询

关于多表查询,因为不同场景很难抽取出不同,也就没有再做封装,但是我的经验是优先多使用gorm的方法,而不是自己拼sql。你想要做的gorm都可以实现。

这里,我偷个懒,贴出自己在项目中写的最复杂的一段代码,供各位看官娱乐。

一个复杂的例子

这段代码是从埋点数据的中间表,为了用通用的代码实现不同展示场景下的查询,代码设计的比较灵活,其中涉及了关联多表的查询,按查询条件动态过滤和聚合,还有分页查询的逻辑。

func buildCommonStatisticQuery(tableName, startDate, endDate string) *gorm.DB {
    query := models.DB().Table(tableName)

    if startDate == endDate || endDate == "" {
        query = query.Where("date = ?", startDate)
    } else {
        query = query.Where("date >= ? and date <= ?", startDate, endDate)
    }

    return query
}

func buildElementsStatisticQuery(startDate, endDate,  elemId string,  elemType int32) *gorm.DB {
    query := buildCommonStatisticQuery("spotanalysis.element_statistics", startDate, endDate)

    if elemId != "" && elemType != 0 {
        query = query.Where("element_id = ? and element_type = ?", elemId, elemType)
    }

    return query
}

func CountElementsStatistics(count *int32, startDate, endDate, instId, appId, elemId string, elemType int32, groupFields []string ) error {
    query := buildElementsStatisticQuery(startDate, endDate,  elemId, elemType)

    query = whereInstAndApp(query, instId, appId)

    if len(groupFields) != 0 {
        query = query.Select(fmt.Sprintf("count(distinct(concat(%s)))", strings.Join(groupFields, ",")))
    } else {
        query = query.Select("count(id)")
    }

    query = query.Count(count)
    return query.Error
}


func GetElementsStatistics(result interface{}, startDate, endDate, instId, appId, elemId string, elemType int32, groupFields []string, orderBy string, ascOrder bool, limit, offset int32) error {
    query := buildElementsStatisticQuery(startDate, endDate, elemId, elemType)
    if len(groupFields) != 0 {
        groupBy := strings.Join(groupFields, "`,`")
        groupBy = "`" + groupBy + "`"
        query = query.Group(groupBy)
        query = havingInstAndApp(query, instId, appId)

        sumFields := strings.Join([]string{
            "SUM(`element_statistics`.`mp_count`) AS `mp_count`",
            "SUM(`element_statistics`.`h5_count`) AS `h5_count`",
            "SUM(`element_statistics`.`total_count`) AS `total_count`",
            "SUM(`element_statistics`.`collection_count`) AS `collection_count`",
            "SUM(`element_statistics`.`mp_share_count`) AS `mp_share_count`",
            "SUM(`element_statistics`.`h5_share_count`) AS `h5_share_count`",
            "SUM(`element_statistics`.`poster_share_count`) AS `poster_share_count`",
            "SUM(`element_statistics`.`total_share_count`) AS `total_share_count`",
        }, ",")

        query = query.Select(groupBy + "," + sumFields)
    } else {
        query = whereInstAndApp(query, instId, appId)
    }

    query = getPagedList(query, orderBy, ascOrder, limit, offset)

    return query.Find(result).Error
}

func getPagedList(query *gorm.DB, orderBy string, ascOrder bool, limit , offset int32) *gorm.DB {
    if orderBy != "" {
        if ascOrder {
            orderBy += " asc"
        } else {
            orderBy += " desc"
        }
        query = query.Order(orderBy)
    }

    if offset != 0 {
        query = query.Offset(offset)
    }
    if limit != 0 {
        query = query.Limit(limit)
    }
    return query
}

func whereInstAndApp(query *gorm.DB, instId string, appId string) *gorm.DB {
    query = query.Where("inst_id = ?", instId)
    if appId != "" {
        query = query.Where("app_id = ?", appId)
    }
    return query
}

func havingInstAndApp(query *gorm.DB, instId string, appId string) *gorm.DB {
    query = query.Having("inst_id = ?", instId)
    if appId != "" {
        query = query.Having("app_id = ?", appId)
    }
    return query
}

感谢各位看官耐心看完,如果本文对你有用,请点个赞~~~

如果能到代码仓库:Github:Ksloveyuan/gorm-ex 给个✩star✩, 楼主就更加感谢了!

上一篇下一篇

猜你喜欢

热点阅读