用好工作池 WaitGroup

2018-12-27  本文已影响19人  酷走天涯

本节学习

WaitGroup的用法

WaitGroup 用于实现工作池,因此要理解工作池,我们首先需要学习 WaitGroup。

WaitGroup 用于等待一批 Go 协程执行结束。程序控制会一直阻塞,直到这些协程全部执行完毕。假设我们有 3 个并发执行的 Go 协程(由 Go 主协程生成)。Go 主协程需要等待这 3 个协程执行结束后,才会终止。这就可以用 WaitGroup 来实现

package main

import (
    "fmt"
    "time"
    "sync"
)

func login(wg *sync.WaitGroup){
    time.Sleep(time.Second)
    fmt.Println("登录完成")
    wg.Done()
}
func getUserInfo(wg *sync.WaitGroup){
    time.Sleep(time.Second)
    fmt.Println("获取用户信息")
    wg.Done()  //4
}

func main() {
   var wg sync.WaitGroup // 1
   wg.Add(1) //2
   go login(&wg)
   wg.Add(1)
   go getUserInfo(&wg) //3

   wg.Wait()
   fmt.Println("执行完毕")
}

上面写了两次wg.Add(1),当然你也可以写一次wg.Add(2)

image.png

下面是waitGroup 的使用说明
1.WaitGroup 是一个等待协程完成的结构体
2.主协程通过调用Add 方法设置等待协程的数量
3.每个子协程完成的时候,需要调用Done 方法,那么等待的协程数量会自动减一
4.wait方法主要告诉协程,开启等待模式,知道所有的协程完成工作

注意事项
go login(&wg) 我们一定要传递指针类型的变量,因为sync.WaitGroup 是结构体,是值类型,在传递的过程中会赋值,如果不用指针,创建的时候,就不是原来的结构体了

工作池

工作池就是一组等待任务分配的协程。一旦完成了所分配的任务,这些线程可继续等待任务的分配。

我们会使用缓冲信道来实现工作池。我们工作池的任务是计算所输入数字的每一位的和。例如,如果输入 234,结果会是 9(即 2 + 3 + 4)。向工作池输入的是一列伪随机数。

我们工作池的核心功能如下:

创建一个 Go 协程池,监听一个等待作业分配的输入型缓冲信道。
将作业添加到该输入型缓冲信道中。
作业完成后,再将结果写入一个输出型缓冲信道。
从输出型缓冲信道读取并打印结果。

package main

import (
    "fmt"
    "math/rand"
    "sync"
    "time"
)

type Job struct {
    id       int
    randomno int
}
type Result struct {
    job         Job
    sumofdigits int
}

var jobs = make(chan Job, 10)
var results = make(chan Result, 10)


// 1.创建工作任务
func allocate(noOfJobs int) {
    for i := 0; i < noOfJobs; i++ {
        randomno := rand.Intn(999)
        job := Job{i, randomno}
        jobs <- job
    }
    // 关闭工作信道
    close(jobs)
}

// 2.计算数的和
func digits(number int) int {
    sum := 0
    no := number
    for no != 0 {
        digit := no % 10
        sum += digit
        no /= 10
    }
    time.Sleep(2 * time.Second)
    return sum
}

// 执行一项工作 一项工作启用 一个协程 工作完毕后,等待组减一 多个协程同时调用 这个方法 会对 同一个信道 jobs 产生竞争,谁先拿到值,谁先执行
func 3.worker(wg *sync.WaitGroup) {
    for job := range jobs {
        output := Result{job, digits(job.randomno)}
        results <- output
    }
    wg.Done()
}

// 4.创建执行数量的工作组
func createWorkerPool(noOfWorkers int) {
    var wg sync.WaitGroup
    for i := 0; i < noOfWorkers; i++ {
        wg.Add(1)
        go worker(&wg) 
    }
    wg.Wait()
    // 当所有任务执行完毕后,关闭通道
    close(results)
}



// 5.对结果进行输出
func result(done chan bool) {
    for result := range results {
        fmt.Printf("Job id %d, input random no %d , sum of digits %d\n", result.job.id, result.job.randomno, result.sumofdigits)
    }
    done <- true
}

func main() {
    startTime := time.Now()

    noOfJobs := 100

    go allocate(noOfJobs)

    done := make(chan bool)
    // 完成一项任务执行一次输出
    go result(done)
    
    // 创建工作池开始做任务
    noOfWorkers := 50

    createWorkerPool(noOfWorkers)
    
    // 等待所有任务完成
    <-done
    endTime := time.Now()
    diff := endTime.Sub(startTime)
    fmt.Println("total time taken ", diff.Seconds(), "seconds")
}

案例二

我们要下载100 张图片,模拟每张图片下载需要500ms,请使用工作池实现这个下载图片的任务

download.go

package util

import (
    "sync"
    "fmt"
)

var results = make(chan string)
var tasks = make(chan string)

type Download struct {
    urls []string
    results []string
}

func (d *Download)CreateTasks(urls []string){
    d.urls = urls
    for _,v := range d.urls{
        tasks <- v
    }
    close(tasks)
}

func (d *Download)DownloadBy(url string){
    d.results = append(d.results,"下载完成了")

}

func (d *Download)startDownload(ws * sync.WaitGroup){
    for task :=  range tasks{
        d.DownloadBy(task)
    }
    ws.Done()
}

func (d *Download)CreateWorkerPool(num int, finish chan bool){
    var ws sync.WaitGroup
    for i := 0;i <num;i++{
        ws.Add(1)
        go d.startDownload(&ws)
    }
    ws.Wait()
    close(results)
    finish <- true
}



func (d *Download) Download(urls []string,finish func([]string)){
    result := make(chan bool)
    go d.CreateWorkerPool(10,result)
    go d.CreateTasks(urls)
    <- result
    fmt.Println("最后执行的")
    if(finish != nil){
        finish(d.results)
    }
}

main.go 文件

package main

import (
        "fmt"
    "strconv"
     "./util"

)



func main() {
    download := util.Download{}
    var urls = []string{}
    for i:=0;i < 100;i++{
        urls = append(urls, strconv.Itoa(i))
    }
    download.Download(urls, func(result []string) {
       fmt.Println(result)
    }) 
}
image.png
上一篇下一篇

猜你喜欢

热点阅读