wal-g
106 строк · 2.4 Кб
1package splitmerge
2
3import (
4"bytes"
5"context"
6"fmt"
7"io"
8"math/rand"
9"testing"
10
11"github.com/stretchr/testify/assert"
12"golang.org/x/sync/errgroup"
13
14"github.com/wal-g/wal-g/internal/abool"
15)
16
17type BufferCloser struct {
18bytes.Buffer
19io.Closer
20closed abool.AtomicBool
21}
22
23func (b *BufferCloser) Write(p []byte) (n int, err error) {
24if b.closed.IsSet() {
25return 0, io.ErrClosedPipe
26}
27return b.Buffer.Write(p)
28}
29
30func (b *BufferCloser) Close() error {
31if !b.closed.SetToIf(false, true) {
32// The behavior of Close after the first call is undefined.
33// Specific implementations may document their own behavior.
34panic("UB")
35}
36return nil
37}
38
39// ┌─> copy data per 1 byte ─>┐
40//
41// data ─> split ├─> copy data per ... bytes ─>├─> merge
42//
43// └─> copy data per 42 bytes ─>┘
44func TestSplitMerge(t *testing.T) {
45const blockSize = 128
46const dataSize = 115249 // some prime number
47bufSizes := []int{1, blockSize + 1, blockSize - 1, 2*blockSize + 1, 4, 8, 15, 16, 23, 42}
48partitions := len(bufSizes)
49
50// in:
51inputData := generateDataset(dataSize)
52dataReader := bytes.NewReader(inputData)
53readers := SplitReader(context.Background(), dataReader, partitions, blockSize)
54
55// out:
56var sink BufferCloser
57writers := MergeWriter(&sink, partitions, blockSize)
58
59errGroup := new(errgroup.Group)
60for i := 0; i < partitions; i++ {
61// idx := i
62reader := readers[i]
63writer := writers[i]
64buffSize := bufSizes[i%len(bufSizes)]
65
66errGroup.Go(func() error {
67defer writer.Close()
68// read _all_ data first and only then send it to MergeWriter:
69allData, err := io.ReadAll(reader)
70if err != nil {
71return err
72}
73
74offset := 0
75for {
76data := make([]byte, buffSize, buffSize)
77rbytes := copy(data, allData[offset:])
78offset += rbytes
79// tracelog.InfoLogger.Printf("goroutine #%d: %d bytes fetched, err=%v", idx, rbytes, rerr)
80if rbytes == 0 {
81return nil
82}
83_, werr := writer.Write(data[:rbytes])
84if werr != nil {
85return werr
86} else {
87// tracelog.InfoLogger.Printf("goroutine #%d: %d bytes copied", idx, rbytes)
88}
89}
90})
91}
92
93// Wait for upload finished:
94assert.NoError(t, errGroup.Wait())
95
96fmt.Printf("%d\n", len(inputData))
97fmt.Printf("%d\n", sink.Len())
98
99assert.ElementsMatch(t, inputData, sink.Bytes())
100}
101
// generateDataset returns size pseudo-random bytes. A fixed seed makes the
// dataset identical on every run, so a failing test can be reproduced. The
// package-level math/rand source is randomly seeded since Go 1.20, and
// rand.Read is deprecated.
func generateDataset(size int) []byte {
	const seed = 42
	result := make([]byte, size)
	// (*rand.Rand).Read always fills the whole slice and returns a nil error.
	_, _ = rand.New(rand.NewSource(seed)).Read(result)
	return result
}
107