// Dragonfly2
// 505 lines · 18.2 KB
1/*
2* Copyright 2023 The Dragonfly Authors
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*/
16
17package service
18
19import (
20"context"
21"errors"
22"fmt"
23"io"
24"os"
25"path/filepath"
26"reflect"
27"sync"
28"testing"
29
30"github.com/stretchr/testify/assert"
31"go.uber.org/mock/gomock"
32"google.golang.org/protobuf/types/known/emptypb"
33
34trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1"
35trainerv1mocks "d7y.io/api/v2/pkg/apis/trainer/v1/mocks"
36
37"d7y.io/dragonfly/v2/pkg/idgen"
38"d7y.io/dragonfly/v2/trainer/config"
39storagemocks "d7y.io/dragonfly/v2/trainer/storage/mocks"
40trainingmocks "d7y.io/dragonfly/v2/trainer/training/mocks"
41)
42
43var (
44mockHostName = "localhost"
45mockIP = "127.0.0.1"
46mockHostID = idgen.HostIDV2(mockIP, mockHostName)
47mockDataset = []byte("foo")
48)
49
50func TestService_NewV1(t *testing.T) {
51tests := []struct {
52name string
53run func(t *testing.T, s any)
54}{
55{
56name: "new service",
57run: func(t *testing.T, s any) {
58assert := assert.New(t)
59assert.Equal(reflect.TypeOf(s).Elem().Name(), "V1")
60},
61},
62}
63for _, tc := range tests {
64t.Run(tc.name, func(t *testing.T) {
65ctl := gomock.NewController(t)
66defer ctl.Finish()
67storage := storagemocks.NewMockStorage(ctl)
68training := trainingmocks.NewMockTraining(ctl)
69tc.run(t, NewV1(config.New(), storage, training))
70})
71}
72}
73
74func TestV1_Train(t *testing.T) {
75tests := []struct {
76name string
77run func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
78ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder)
79}{
80{
81name: "receive GNN and MLP train requests success",
82run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
83ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
84networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(),
85fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
86if err != nil {
87t.Fatal(err)
88}
89
90downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(),
91fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
92if err != nil {
93t.Fatal(err)
94}
95
96var wg sync.WaitGroup
97wg.Add(1)
98defer wg.Wait()
99gomock.InOrder(
100mtts.Recv().Return(&trainerv1.TrainRequest{
101Hostname: mockHostName,
102Ip: mockIP,
103Request: &trainerv1.TrainRequest_TrainGnnRequest{
104TrainGnnRequest: &trainerv1.TrainGNNRequest{
105Dataset: mockDataset,
106},
107},
108}, nil).Times(1),
109
110ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
111ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1),
112mtts.Recv().Return(&trainerv1.TrainRequest{
113Hostname: mockHostName,
114Ip: mockIP,
115Request: &trainerv1.TrainRequest_TrainMlpRequest{
116TrainMlpRequest: &trainerv1.TrainMLPRequest{
117Dataset: mockDataset,
118},
119},
120}, nil).Times(1),
121mtts.Recv().Return(nil, io.EOF).Times(1),
122mtts.SendAndClose(new(emptypb.Empty)).Return(nil).Times(1),
123mt.Train(context.Background(), mockIP, mockHostName).DoAndReturn(func(ctx context.Context, ip, hostName string) error {
124wg.Done()
125return nil
126}).Times(1),
127)
128
129assert := assert.New(t)
130assert.NoError(svc.Train(stream))
131},
132},
133{
134name: "receive error",
135run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
136ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
137gomock.InOrder(
138mtts.Recv().Return(nil, errors.New("receive error")).Times(1),
139)
140
141assert := assert.New(t)
142assert.EqualError(svc.Train(stream), "receive error")
143},
144},
145{
146name: "open network topology file error",
147run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
148ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
149gomock.InOrder(
150mtts.Recv().Return(&trainerv1.TrainRequest{
151Hostname: mockHostName,
152Ip: mockIP,
153Request: &trainerv1.TrainRequest_TrainGnnRequest{
154TrainGnnRequest: &trainerv1.TrainGNNRequest{
155Dataset: mockDataset,
156},
157},
158}, nil).Times(1),
159
160ms.OpenNetworkTopology(mockHostID).Return(nil, errors.New("open network topology file error")).Times(1),
161)
162
163assert := assert.New(t)
164assert.EqualError(svc.Train(stream),
165"rpc error: code = Internal desc = open network topology failed: open network topology file error")
166},
167},
168{
169name: "open download file error",
170run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
171ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
172networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
173if err != nil {
174t.Fatal(err)
175}
176
177gomock.InOrder(
178mtts.Recv().Return(&trainerv1.TrainRequest{
179Hostname: mockHostName,
180Ip: mockIP,
181Request: &trainerv1.TrainRequest_TrainGnnRequest{
182TrainGnnRequest: &trainerv1.TrainGNNRequest{
183Dataset: mockDataset,
184},
185},
186}, nil).Times(1),
187
188ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
189ms.OpenDownload(mockHostID).Return(nil, errors.New("open download file error")).Times(1),
190ms.ClearNetworkTopology(mockHostID).Do(func(id string) {
191networktopologyFile.Close()
192if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", id, "csv"))); err != nil {
193t.Fatal(err)
194}
195}).Return(nil).Times(1),
196)
197
198assert := assert.New(t)
199assert.EqualError(svc.Train(stream),
200"rpc error: code = Internal desc = open download failed: open download file error")
201},
202},
203{
204name: "clear network topology file error",
205run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
206ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
207networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
208if err != nil {
209t.Fatal(err)
210}
211
212downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
213if err != nil {
214t.Fatal(err)
215}
216
217gomock.InOrder(
218mtts.Recv().Return(&trainerv1.TrainRequest{
219Hostname: mockHostName,
220Ip: mockIP,
221Request: &trainerv1.TrainRequest_TrainGnnRequest{
222TrainGnnRequest: &trainerv1.TrainGNNRequest{
223Dataset: mockDataset,
224},
225},
226}, nil).Times(1),
227
228ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
229ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1),
230mtts.Recv().Return(nil, errors.New("receive error")).Times(1),
231ms.ClearDownload(mockHostID).Do(func(id string) {
232downloadFile.Close()
233if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", id, "csv"))); err != nil {
234t.Fatal(err)
235}
236}).Return(nil).Times(1),
237
238ms.ClearNetworkTopology(mockHostID).Do(func(id string) {
239networktopologyFile.Close()
240if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", id, "csv"))); err != nil {
241t.Fatal(err)
242}
243}).Return(errors.New("clear network topology file error")).Times(1),
244)
245
246assert := assert.New(t)
247assert.EqualError(svc.Train(stream), "receive error")
248},
249},
250{
251name: "clear download file error",
252run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
253ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
254networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
255if err != nil {
256t.Fatal(err)
257}
258
259downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
260if err != nil {
261t.Fatal(err)
262}
263
264gomock.InOrder(
265mtts.Recv().Return(&trainerv1.TrainRequest{
266Hostname: mockHostName,
267Ip: mockIP,
268Request: &trainerv1.TrainRequest_TrainGnnRequest{
269TrainGnnRequest: &trainerv1.TrainGNNRequest{
270Dataset: mockDataset,
271},
272},
273}, nil).Times(1),
274
275ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
276ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1),
277mtts.Recv().Return(nil, errors.New("receive error")).Times(1),
278ms.ClearDownload(mockHostID).Do(func(id string) {
279downloadFile.Close()
280if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", id, "csv"))); err != nil {
281t.Fatal(err)
282}
283}).Return(errors.New("clear download file error")).Times(1),
284
285ms.ClearNetworkTopology(mockHostID).Do(func(id string) {
286networktopologyFile.Close()
287if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", id, "csv"))); err != nil {
288t.Fatal(err)
289}
290}).Return(nil).Times(1),
291)
292
293assert := assert.New(t)
294assert.EqualError(svc.Train(stream), "receive error")
295},
296},
297{
298name: "store network topology error",
299run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
300ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
301networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
302if err != nil {
303t.Fatal(err)
304}
305
306downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
307if err != nil {
308t.Fatal(err)
309}
310
311gomock.InOrder(
312mtts.Recv().Return(&trainerv1.TrainRequest{
313Hostname: mockHostName,
314Ip: mockIP,
315Request: &trainerv1.TrainRequest_TrainGnnRequest{
316TrainGnnRequest: &trainerv1.TrainGNNRequest{
317Dataset: mockDataset,
318},
319},
320}, nil).Times(1),
321
322ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
323ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1),
324)
325
326networktopologyFile.Close()
327assert := assert.New(t)
328assert.EqualError(svc.Train(stream),
329"rpc error: code = Internal desc = write network topology failed: write /tmp/networktopology-52fa2eb710c71cc3e6ba7be6ca82453fcfe59e1c5da358ab3df8b72fd4d2a2cf.csv: file already closed")
330},
331},
332{
333name: "store download error",
334run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
335ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
336networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
337if err != nil {
338t.Fatal(err)
339}
340
341downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
342if err != nil {
343t.Fatal(err)
344}
345
346gomock.InOrder(
347mtts.Recv().Return(&trainerv1.TrainRequest{
348Hostname: mockHostName,
349Ip: mockIP,
350Request: &trainerv1.TrainRequest_TrainMlpRequest{
351TrainMlpRequest: &trainerv1.TrainMLPRequest{
352Dataset: mockDataset,
353},
354},
355}, nil).Times(1),
356
357ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
358ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1),
359)
360
361downloadFile.Close()
362assert := assert.New(t)
363assert.EqualError(svc.Train(stream),
364"rpc error: code = Internal desc = write download failed: write /tmp/download-52fa2eb710c71cc3e6ba7be6ca82453fcfe59e1c5da358ab3df8b72fd4d2a2cf.csv: file already closed")
365},
366},
367{
368name: "receive unknown request",
369run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
370ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
371networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
372if err != nil {
373t.Fatal(err)
374}
375
376downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
377if err != nil {
378t.Fatal(err)
379}
380
381gomock.InOrder(
382mtts.Recv().Return(&trainerv1.TrainRequest{
383Hostname: mockHostName,
384Ip: mockIP,
385Request: nil,
386}, nil).Times(1),
387
388ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
389ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1),
390)
391
392assert := assert.New(t)
393assert.EqualError(svc.Train(stream), "rpc error: code = FailedPrecondition desc = receive unknown request: <nil>")
394},
395},
396{
397name: "send and close error",
398run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
399ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
400networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
401if err != nil {
402t.Fatal(err)
403}
404
405downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
406if err != nil {
407t.Fatal(err)
408}
409
410gomock.InOrder(
411mtts.Recv().Return(&trainerv1.TrainRequest{
412Hostname: mockHostName,
413Ip: mockIP,
414Request: &trainerv1.TrainRequest_TrainGnnRequest{
415TrainGnnRequest: &trainerv1.TrainGNNRequest{
416Dataset: mockDataset,
417},
418},
419}, nil).Times(1),
420
421ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
422ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1),
423mtts.Recv().Return(&trainerv1.TrainRequest{
424Hostname: mockHostName,
425Ip: mockIP,
426Request: &trainerv1.TrainRequest_TrainMlpRequest{
427TrainMlpRequest: &trainerv1.TrainMLPRequest{
428Dataset: mockDataset,
429},
430},
431}, nil).Times(1),
432mtts.Recv().Return(nil, io.EOF).Times(1),
433mtts.SendAndClose(new(emptypb.Empty)).Return(errors.New("send and close error")).Times(1),
434)
435
436assert := assert.New(t)
437assert.EqualError(svc.Train(stream), "send and close error")
438},
439},
440{
441name: "training error",
442run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder,
443ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) {
444networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
445if err != nil {
446t.Fatal(err)
447}
448
449downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
450if err != nil {
451t.Fatal(err)
452}
453
454var wg sync.WaitGroup
455wg.Add(1)
456defer wg.Wait()
457gomock.InOrder(
458mtts.Recv().Return(&trainerv1.TrainRequest{
459Hostname: mockHostName,
460Ip: mockIP,
461Request: &trainerv1.TrainRequest_TrainGnnRequest{
462TrainGnnRequest: &trainerv1.TrainGNNRequest{
463Dataset: mockDataset,
464},
465},
466}, nil).Times(1),
467
468ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1),
469ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1),
470mtts.Recv().Return(&trainerv1.TrainRequest{
471Hostname: mockHostName,
472Ip: mockIP,
473Request: &trainerv1.TrainRequest_TrainMlpRequest{
474TrainMlpRequest: &trainerv1.TrainMLPRequest{
475Dataset: mockDataset,
476},
477},
478}, nil).Times(1),
479mtts.Recv().Return(nil, io.EOF).Times(1),
480mtts.SendAndClose(new(emptypb.Empty)).Return(nil).Times(1),
481mt.Train(context.Background(), mockIP, mockHostName).DoAndReturn(func(ctx context.Context, ip, hostName string) error {
482wg.Done()
483return errors.New("training error")
484}).Times(1),
485)
486
487assert := assert.New(t)
488assert.NoError(svc.Train(stream))
489},
490},
491}
492
493for _, tc := range tests {
494t.Run(tc.name, func(t *testing.T) {
495ctl := gomock.NewController(t)
496defer ctl.Finish()
497storage := storagemocks.NewMockStorage(ctl)
498training := trainingmocks.NewMockTraining(ctl)
499stream := trainerv1mocks.NewMockTrainer_TrainServer(ctl)
500
501svc := NewV1(config.New(), storage, training)
502tc.run(t, svc, stream, stream.EXPECT(), storage.EXPECT(), training.EXPECT())
503})
504}
505}
506