take4s5i DEV

25 Jul 2022

Goのchannel

Goのchannelについて勉強したのでまとめます。 Go 1.18 で調べています。

読んだ資料はこの当たり

channel の作り方

go の channel は goroutine 間でデータのやり取りを行うためのプリミティブ。 マップやスライスと同様に make で値を生成する。

<- ch で値を読み出し、 ch <- で値を書き込むことができる

channel は使い終わったら close する。 closeした channel に読み書きすると panic する

main.go [src]

package main

import (
	"fmt"
	"time"
)

func main() {
	ch := make(chan int)

	go func() {
		ch <- 1
	}()

	time.Sleep(500 * time.Millisecond)
	fmt.Println(<-ch)
	close(ch)
}

channel の読み書き

channel は goroutine safe なので複数の goroutine から読み書きできる。 Rust の mpsc と異なり、同時に複数書き込んだり、読み込んだりしても良い。 Mutex 使って同期する必要もない。

<-chan T のように <- を型名につけることで、receive only channel にできる。 同様に chan<- T で send only channel になる。

main.go [src]

package main

import (
	"fmt"
	"sync"
)

func startSend() <-chan string {
	var wg sync.WaitGroup
	ch := make(chan string)
	nSender := 2
	nSend := 3

	wg.Add(nSender)

	// nSender 個の goroutine を起動
	for n := 0; n < nSender; n++ {
		n := n
		go func() {
			defer wg.Done()

			// nSend 回のメッセージを ch に送信
			for v := 0; v < nSend; v++ {
				ch <- fmt.Sprintf("%d from sender %d", v, n)
			}
		}()
	}

	// すべての sender goroutine が終了したら close(ch) する
	go func() {
		wg.Wait()
		close(ch)
	}()

	return ch
}

func startReceive(ch <-chan string) {
	var wg sync.WaitGroup
	nReceiver := 2

	wg.Add(nReceiver)

	// nReceiver 個の goroutine を起動
	for n := 0; n < nReceiver; n++ {
		n := n
		go func() {
			defer wg.Done()

			for v := range ch {
				fmt.Printf("receive '%s' at receiver %d\n", v, n)
			}
		}()
	}

	wg.Wait()
}

func main() {
	ch := startSend()
	startReceive(ch)
}

出力:

receive '1 from sender 0' at receiver 0
receive '2 from sender 0' at receiver 0
receive '0 from sender 1' at receiver 0
receive '1 from sender 1' at receiver 0
receive '2 from sender 1' at receiver 0
receive '0 from sender 0' at receiver 1

channel を使ったパターン

channel を扱うときはどこで close するのか責任を明確にしたほうが良い気がする。 基本的に sender 側が close したほうがきれいに書けると思う。

Generator

他言語のジェネレーターに相当するようなもの。 go にはジェネレータ構文はないが、 goroutine と channel でジェネレータを作れる。 生成する値がなくなったら close することで、後続の処理にデータがもう来ないことを伝えられる。

main.go [src]

package main

import "fmt"

func fib(n int) <-chan int {
	ch := make(chan int)
	a, b := 0, 1

	go func() {
		defer close(ch)
		for {
			v := a
			a = b
			b = v + b
			n--

			if n < 0 {
				return
			}
			ch <- v
		}
	}()

	return ch
}

func main() {
	f := fib(10)

	for v := range f {
		fmt.Println(v)
	}
}

quit/done

goroutine を外から止めたいときに使う quitdone という名前でよく使われている気がする。

main.go [src]

package main

import (
	"fmt"
	"time"
)

func fib(quit <-chan struct{}) <-chan int {
	ch := make(chan int)
	a, b := 0, 1

	go func() {
		defer close(ch)
		for {
			v := a
			a = b
			b = v + b

			select {
			case ch <- v:
			case <-quit:
				// case <-quit: は quit が close された場合も実行される。
				// こう書いておくことで quit を close すると一括で goroutine を終了されられる
				return
			}
		}
	}()

	return ch
}

func main() {
	quit := make(chan struct{})
	f := fib(quit)

	go func() {
		time.Sleep(100 * time.Millisecond)
		close(quit)
	}()

	for v := range f {
		fmt.Println(v)
	}
}

context でほぼ同様の機能 ctx.Done()context.WithCancel(ctx) が提供されているので、 こちらを使ったほうがいいかもしれない。

main.go [src]

package main

import (
	"context"
	"fmt"
	"time"
)

func fib(ctx context.Context) <-chan int {
	ch := make(chan int)
	a, b := 0, 1

	go func() {
		defer close(ch)
		for {
			v := a
			a = b
			b = v + b

			select {
			case ch <- v:
			case <-ctx.Done():
				return
			}
		}
	}()

	return ch
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	f := fib(ctx)

	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel()
	}()

	for v := range f {
		fmt.Println(v)
	}
}

FanIn / FanOut

複数の channel を1つに束ねるのを FanIn、 1つのchannel を複数に分岐するのを FanOut という。

channel は Multi-Producer / Multi-Consumer なので特に気にせず 複数の goroutine から読み書きすればよい。

main.go [src]

package main

import (
	"fmt"
	"sync"
)

type MyInt int

func (x MyInt) String() string {
	return fmt.Sprint(int(x))
}

func collatz(n int) <-chan MyInt {
	ch := make(chan MyInt)

	go func() {
		defer close(ch)

		for {
			ch <- MyInt(n)
			switch {
			case n == 1:
				return
			case n%2 == 0:
				n = n / 2
			default:
				n = n*3 + 1
			}
		}
	}()

	return ch
}

func fanIn[T any](cs ...<-chan T) <-chan T {
	var wg sync.WaitGroup
	out := make(chan T)

	wg.Add(len(cs))

	for _, ch := range cs {
		ch := ch

		// 元となる channel から値を受け取り out にわたすだけの goroutine
		// 元channelが close されると wg.Done()する
		go func() {
			defer wg.Done()
			for v := range ch {
				out <- v
			}
		}()
	}

	// すべての元 channel が close されると close(out) する
	go func() {
		wg.Wait()
		close(out)
	}()

	return out
}

func fanOut[T fmt.Stringer](ch <-chan T, n int) {
	var wg sync.WaitGroup
	wg.Add(n)

	for i := 0; i < n; i++ {
		i := i

		// ch から読み取って出力するだけの goroutine
		// ch は goroutine safe なので 複数の goroutine から読み書きしても問題ない
		go func() {
			defer wg.Done()
			for v := range ch {
				fmt.Printf("'%s' from %d\n", v.String(), i)
			}
		}()
	}

	wg.Wait()
}

func main() {
	c := fanIn(collatz(10), collatz(20), collatz(30))
	fanOut(c, 3)
}

とりあえず今回はここまで