Golang自定义基于gin框架的Session中间件
2020-11-05 本文已影响0人
FredricZhu
工程结构如下
image.png
原理:利用Cookie保存sessionID;中间件在每次请求时读取该sessionID,并据此从内存或Redis中取出该用户对应的Session。
main.go测试代码
package main
import (
"fmt"
"log"
"net/http"
"github.com/gin-gonic/gin"
"github.com/zhuge20100104/gin_session/gsession"
)
// main wires the session middleware into a gin engine and exposes GET /incr,
// which keeps a per-session counter under the "count" key.
func main() {
	r := gin.Default()

	mgrObj, err := gsession.CreateSessionMgr(gsession.Redis, "localhost:6379")
	if err != nil {
		// log.Fatalf exits the process, so nothing after it would run.
		log.Fatalf("Create manager obj failed, err: %v\n", err)
	}

	sm := gsession.SessionMiddleware(mgrObj, gsession.Options{
		Path:     "/",
		Domain:   "127.0.0.1",
		MaxAge:   120,
		Secure:   false,
		HttpOnly: true,
	})
	r.Use(sm)

	r.GET("/incr", func(c *gin.Context) {
		session := c.MustGet("session").(gsession.Session)
		fmt.Printf("%#v\n", session)

		// count stays 0 when the key is missing or holds an unexpected type;
		// otherwise it becomes the stored value plus one.
		var count int
		v, err := session.Get("count")
		if err != nil {
			log.Printf("get count from session failed, err: %v\n", err)
		} else if n, ok := v.(int); ok {
			count = n + 1
		} else {
			// Checked assertion: a non-int value must not panic the handler.
			log.Printf("count in session has unexpected type %T, resetting to 0\n", v)
		}

		session.Set("count", count)
		session.Save()
		c.String(http.StatusOK, "count:%v", count)
	})

	// Run returns an error if the server cannot be started; surface it
	// instead of exiting silently.
	if err := r.Run(); err != nil {
		log.Fatalf("run server failed, err: %v\n", err)
	}
}
session.go
package gsession
import (
"fmt"
"log"
"github.com/gin-gonic/gin"
)
// SessionMgrType names a session storage backend accepted by CreateSessionMgr.
type SessionMgrType string
const (
// SessionCookieName is the name of the cookie that carries the session ID.
SessionCookieName = "session_id"
// SessionContextName is the key under which the Session object is stored in the gin Context.
SessionContextName = "session"
// Memory selects the in-memory session manager.
Memory SessionMgrType = "memory"
// Redis selects the Redis-backed session manager.
Redis SessionMgrType = "redis"
)
// Session is the interface implemented by a single session's data store.
type Session interface {
// ID returns the ID of the session.
ID() string
// Load loads the data stored in Redis into the session data.
Load() error
// Get returns the value stored under the given key.
Get(string) (interface{}, error)
// Set stores a value under the given key.
Set(string, interface{})
// Del removes the value stored under the given key.
Del(string)
// Save flushes the session data to Redis.
Save()
// SetExpired sets the expiry of the data stored in Redis; it has no effect for the in-memory version.
SetExpired(int)
}
// SessionMgr is the interface implemented by session managers.
type SessionMgr interface {
// Init initialises the manager; for the Redis manager this sets up the Redis connection.
Init(addr string, options ...string) error
// GetSession returns the already-initialised Session for the given session ID.
GetSession(string) (Session, error)
// CreateSession creates a new Session.
CreateSession() Session
// Clear removes the Session with the given session ID from the manager.
Clear(string)
}
// Options holds the attributes applied to the session cookie.
type Options struct {
Path string
Domain string
// MaxAge is the lifetime of the session ID stored in the cookie.
// MaxAge=0 means no 'Max-Age' attribute specified.
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
}
// CreateSessionMgr builds the session manager selected by name and initialises
// it with addr and options.
//
// It returns an error when name is not a supported SessionMgrType or when the
// manager's Init fails. On error the returned manager is nil, so callers never
// receive a half-initialised manager.
func CreateSessionMgr(name SessionMgrType, addr string, options ...string) (sm SessionMgr, err error) {
	switch name {
	case Memory:
		sm = NewMemSessionMgr()
	case Redis:
		sm = NewRedisSessionMgr()
	default:
		// No trailing newline: error strings are composed into other messages.
		return nil, fmt.Errorf("unsupported %v", name)
	}

	if err = sm.Init(addr, options...); err != nil {
		return nil, err
	}
	return sm, nil
}
// SessionMiddleware returns a gin handler that attaches a Session to each
// request.
//
// The session ID is read from the SessionCookieName cookie. When the cookie is
// missing, or the manager cannot return a session for that ID, a new session
// is created. The session is given options.MaxAge as its expiry, stored in the
// gin context under SessionContextName, and its ID is written back as a cookie
// using the remaining fields of options.
func SessionMiddleware(sm SessionMgr, options Options) gin.HandlerFunc {
	return func(c *gin.Context) {
		var (
			session Session
			// needNew records that no usable session could be looked up.
			needNew bool
		)

		sessionID, err := c.Cookie(SessionCookieName)
		if err == nil {
			log.Printf("SessionId: %v\n", sessionID)
			if session, err = sm.GetSession(sessionID); err != nil {
				log.Printf("Get session by %s failed, err: %v\n", sessionID, err)
				needNew = true
			}
		} else {
			log.Printf("get session_id from cookie failed, err:%v\n", err)
			needNew = true
		}

		if needNew {
			session = sm.CreateSession()
			sessionID = session.ID()
		}

		session.SetExpired(options.MaxAge)
		c.Set(SessionContextName, session)
		c.SetCookie(
			SessionCookieName,
			sessionID,
			options.MaxAge,
			options.Path,
			options.Domain,
			options.Secure,
			options.HttpOnly,
		)

		// The deferred Clear runs once the remaining handlers have finished.
		// NOTE(review): Clear removes the session from the manager; for the
		// in-memory manager this looks like it discards the session data
		// after every request — confirm this is intended.
		defer sm.Clear(sessionID)
		c.Next()
	}
}
memory.go
package gsession
import (
"fmt"
"sync"
uuid "github.com/satori/go.uuid"
)
// memSession is the in-memory Session implementation.
type memSession struct {
	// id is the session's unique identifier.
	id string
	// data holds the session's key/value pairs.
	data map[string]interface{}
	// expired is the value passed to SetExpired; nothing in this type reads it.
	expired int
	// rwLock guards data and expired so the session can be used concurrently.
	rwLock sync.RWMutex
}

// NewMemSession returns an empty in-memory session with the given id.
func NewMemSession(id string) *memSession {
	return &memSession{
		id:   id,
		data: make(map[string]interface{}, 8),
	}
}

// ID returns the session's identifier.
func (m *memSession) ID() string {
	return m.id
}

// Load is a no-op for the in-memory session and always returns nil.
func (m *memSession) Load() error {
	return nil
}

// Get returns the value stored under key, or an error when key is absent.
func (m *memSession) Get(key string) (interface{}, error) {
	m.rwLock.RLock()
	defer m.rwLock.RUnlock()

	value, ok := m.data[key]
	if !ok {
		return nil, fmt.Errorf("Invalid key")
	}
	return value, nil
}

// Set stores value under key, replacing any previous value.
func (m *memSession) Set(key string, value interface{}) {
	m.rwLock.Lock()
	defer m.rwLock.Unlock()
	m.data[key] = value
}

// Del removes key from the session; deleting an absent key is a no-op.
func (m *memSession) Del(key string) {
	m.rwLock.Lock()
	defer m.rwLock.Unlock()
	delete(m.data, key)
}

// Save is a no-op for the in-memory session.
func (m *memSession) Save() {}

// SetExpired records the expiry value. The write is done under the lock, like
// every other mutation of the session, so concurrent calls do not race.
func (m *memSession) SetExpired(expired int) {
	m.rwLock.Lock()
	defer m.rwLock.Unlock()
	m.expired = expired
}
// MemSessionMgr is the in-memory session manager.
type MemSessionMgr struct {
	// session maps session IDs to their sessions.
	session map[string]Session
	// rwLock guards session against concurrent access.
	rwLock sync.RWMutex
}

// NewMemSessionMgr returns an empty in-memory session manager.
func NewMemSessionMgr() *MemSessionMgr {
	return &MemSessionMgr{
		session: make(map[string]Session, 1024),
	}
}

// Init is a no-op for the in-memory manager; addr and options are ignored and
// nil is always returned.
func (m *MemSessionMgr) Init(addr string, options ...string) error {
	return nil
}

// GetSession returns the session registered under sessionID, or an error when
// no such session exists.
func (m *MemSessionMgr) GetSession(sessionID string) (Session, error) {
	m.rwLock.RLock()
	defer m.rwLock.RUnlock()

	sd, ok := m.session[sessionID]
	if !ok {
		return nil, fmt.Errorf("Invalid session id")
	}
	return sd, nil
}

// CreateSession creates a session with a fresh UUID and registers it with the
// manager.
func (m *MemSessionMgr) CreateSession() Session {
	sd := NewMemSession(uuid.NewV4().String())

	// The map is shared between requests, so the insert needs the write lock
	// just like Clear does.
	m.rwLock.Lock()
	defer m.rwLock.Unlock()
	m.session[sd.ID()] = sd
	return sd
}

// Clear removes the session registered under sessionID, if any.
func (m *MemSessionMgr) Clear(sessionID string) {
	m.rwLock.Lock()
	defer m.rwLock.Unlock()
	delete(m.session, sessionID)
}
redis.go
package gsession
import (
"bytes"
"encoding/gob"
"fmt"
"log"
"strconv"
"sync"
"time"
"github.com/go-redis/redis"
uuid "github.com/satori/go.uuid"
)
// redisSession is the Redis-backed Session implementation.
type redisSession struct {
	// id is the session ID; it is also the Redis key the data is stored under.
	id string
	// data holds the session's key/value pairs.
	data map[string]interface{}
	// modifyFlag reports whether data has changed since the last successful Save.
	modifyFlag bool
	// expired is the expiry, in seconds, applied when the data is saved.
	expired int
	rwLock  sync.RWMutex
	client  *redis.Client
}

// NewRedisSession returns an empty Redis-backed session with the given id that
// uses client for storage.
func NewRedisSession(id string, client *redis.Client) Session {
	return &redisSession{
		id:     id,
		data:   make(map[string]interface{}, 8),
		client: client,
	}
}

// ID returns the session's identifier.
func (r *redisSession) ID() string {
	return r.id
}

// Load fetches the bytes stored in Redis under the session ID and gob-decodes
// them into the session data. It returns the Redis or decode error on failure.
func (r *redisSession) Load() error {
	raw, err := r.client.Get(r.id).Bytes()
	if err != nil {
		log.Printf("get session data from redis by %s failed, err: %v\n", r.id, err)
		return err
	}

	if err = gob.NewDecoder(bytes.NewBuffer(raw)).Decode(&r.data); err != nil {
		log.Printf("gob decode session data failed, err: %v\n", err)
		return err
	}
	return nil
}

// Get returns the value stored under key, or an error when key is absent.
func (r *redisSession) Get(key string) (interface{}, error) {
	r.rwLock.RLock()
	defer r.rwLock.RUnlock()

	value, ok := r.data[key]
	if !ok {
		return nil, fmt.Errorf("invalid key")
	}
	return value, nil
}

// Set stores value under key and marks the session as modified.
func (r *redisSession) Set(key string, value interface{}) {
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	r.data[key] = value
	r.modifyFlag = true
}

// Del removes key from the session and marks the session as modified.
func (r *redisSession) Del(key string) {
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	delete(r.data, key)
	r.modifyFlag = true
}

// SetExpired sets the expiry, in seconds, used by the next Save.
func (r *redisSession) SetExpired(expired int) {
	r.expired = expired
}

// Save gob-encodes the session data and writes it to Redis under the session
// ID with the configured expiry. It does nothing when the data has not been
// modified since the last successful save.
//
// Failures are logged rather than terminating the process, and the session
// stays marked as modified so a later Save can retry.
func (r *redisSession) Save() {
	r.rwLock.Lock()
	defer r.rwLock.Unlock()

	if !r.modifyFlag {
		return
	}

	buf := new(bytes.Buffer)
	if err := gob.NewEncoder(buf).Encode(r.data); err != nil {
		// A single un-encodable session must not bring the whole server down.
		log.Printf("gob encode r.data failed, err: %v\n", err)
		return
	}

	ttl := time.Second * time.Duration(r.expired)
	if err := r.client.Set(r.id, buf.Bytes(), ttl).Err(); err != nil {
		log.Printf("save session %s to redis failed, err: %v\n", r.id, err)
		return
	}

	log.Printf("set data %v to redis.\n", buf.Bytes())
	// Only a confirmed write clears the dirty flag.
	r.modifyFlag = false
}
// redisSessionMgr is the Redis-backed session manager.
type redisSessionMgr struct {
	// session maps session IDs to the sessions handed out by this manager.
	session map[string]Session
	// rwLock guards session against concurrent access.
	rwLock sync.RWMutex
	// client is the Redis connection created by Init.
	client *redis.Client
}

// NewRedisSessionMgr returns a Redis session manager; Init must be called
// before it is used, because the Redis client is created there.
func NewRedisSessionMgr() *redisSessionMgr {
	return &redisSessionMgr{
		session: make(map[string]Session, 1024),
	}
}

// Init connects to the Redis server at addr and verifies the connection with
// a ping.
//
// options is interpreted by length: one element is the password; two elements
// are the password followed by the DB number. Any other length is ignored.
// An error is returned when the DB number is not an integer or the ping fails.
func (r *redisSessionMgr) Init(addr string, options ...string) error {
	var (
		password string
		db       int
	)

	switch len(options) {
	case 1:
		password = options[0]
	case 2:
		password = options[0]
		n, err := strconv.Atoi(options[1])
		if err != nil {
			// Report the bad parameter to the caller instead of exiting the
			// process from inside library code.
			return fmt.Errorf("invalid redis DB param %q: %w", options[1], err)
		}
		db = n
	}

	r.client = redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: password,
		DB:       db,
	})

	if _, err := r.client.Ping().Result(); err != nil {
		return err
	}
	return nil
}

// GetSession builds a session for sessionID, loads its data from Redis and
// registers it with the manager. It returns a nil session and the load error
// when the data cannot be loaded.
func (r *redisSessionMgr) GetSession(sessionID string) (Session, error) {
	sd := NewRedisSession(sessionID, r.client)
	if err := sd.Load(); err != nil {
		return nil, err
	}

	// Inserting into the map is a write, so it needs the exclusive lock; a
	// read lock would let concurrent requests write the map at the same time.
	r.rwLock.Lock()
	r.session[sessionID] = sd
	r.rwLock.Unlock()
	return sd, nil
}

// CreateSession creates a session with a fresh UUID and registers it with the
// manager.
func (r *redisSessionMgr) CreateSession() Session {
	sd := NewRedisSession(uuid.NewV4().String(), r.client)

	// Same shared map as GetSession and Clear, so take the write lock.
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	r.session[sd.ID()] = sd
	return sd
}

// Clear removes the session registered under sessionID from the manager's
// map; it does not touch the data stored in Redis.
func (r *redisSessionMgr) Clear(sessionID string) {
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	delete(r.session, sessionID)
}
程序输出如下:
image.png