7
"github.com/Khan/genqlient/graphql"
8
"github.com/golang/mock/gomock"
9
"github.com/stretchr/testify/assert"
10
"github.com/wandb/wandb/core/internal/coretest"
11
"github.com/wandb/wandb/core/internal/gql"
12
"github.com/wandb/wandb/core/pkg/observability"
13
"github.com/wandb/wandb/core/pkg/server"
14
"github.com/wandb/wandb/core/pkg/service"
15
"google.golang.org/protobuf/types/known/wrapperspb"
18
func makeSender(client graphql.Client, resultChan chan *service.Result) *server.Sender {
19
ctx, cancel := context.WithCancel(context.Background())
20
logger := observability.NewNoOpLogger()
21
sender := server.NewSender(
26
RunId: &wrapperspb.StringValue{Value: "run1"},
28
server.WithSenderFwdChannel(make(chan *service.Record, 1)),
29
server.WithSenderOutChannel(resultChan),
31
sender.SetGraphqlClient(client)
35
func TestSendRun(t *testing.T) {
36
// Verify that project and entity are properly passed through to graphql
37
to := coretest.MakeTestObject(t)
38
defer to.TeardownTest()
40
sender := makeSender(to.MockClient, make(chan *service.Result, 1))
42
run := &service.Record{
43
RecordType: &service.Record_Run{
44
Run: &service.RunRecord{
45
Config: to.MakeConfig(),
46
Project: "testProject",
49
Control: &service.Control{
54
respEncode := &graphql.Response{
55
Data: &gql.UpsertBucketResponse{
56
UpsertBucket: &gql.UpsertBucketUpsertBucketUpsertBucketPayload{
57
Bucket: &gql.UpsertBucketUpsertBucketUpsertBucketPayloadBucketRun{
58
DisplayName: coretest.StrPtr("FakeName"),
59
Project: &gql.UpsertBucketUpsertBucketUpsertBucketPayloadBucketRunProject{
61
Entity: gql.UpsertBucketUpsertBucketUpsertBucketPayloadBucketRunProjectEntity{
70
to.MockClient.EXPECT().MakeRequest(
71
gomock.Any(), // context.Context
72
gomock.Any(), // *graphql.Request
73
gomock.Any(), // *graphql.Response
74
).Return(nil).Do(coretest.InjectResponse(
76
func(vars coretest.RequestVars) {
77
assert.Equal(t, "testEntity", vars["entity"])
78
assert.Equal(t, "testProject", vars["project"])
82
sender.SendRecord(run)
83
<-sender.GetOutboundChannel()
86
func TestSendLinkArtifact(t *testing.T) {
87
// Verify that arguments are properly passed through to graphql
88
to := coretest.MakeTestObject(t)
89
defer to.TeardownTest()
91
sender := makeSender(to.MockClient, make(chan *service.Result, 1))
93
respEncode := &graphql.Response{
94
Data: &gql.LinkArtifactResponse{
95
LinkArtifact: &gql.LinkArtifactLinkArtifactLinkArtifactPayload{
96
VersionIndex: coretest.IntPtr(0),
100
// 1. When both clientId and serverId are sent, serverId is used
101
linkArtifact := &service.Record{
102
RecordType: &service.Record_LinkArtifact{
103
LinkArtifact: &service.LinkArtifactRecord{
104
ClientId: "clientId",
105
ServerId: "serverId",
106
PortfolioName: "portfolioName",
107
PortfolioEntity: "portfolioEntity",
108
PortfolioProject: "portfolioProject",
110
Control: &service.Control{
115
to.MockClient.EXPECT().MakeRequest(
116
gomock.Any(), // context.Context
117
gomock.Any(), // *graphql.Request
118
gomock.Any(), // *graphql.Response
119
).Return(nil).Do(coretest.InjectResponse(
121
func(vars coretest.RequestVars) {
122
assert.Equal(t, "portfolioProject", vars["projectName"])
123
assert.Equal(t, "portfolioEntity", vars["entityName"])
124
assert.Equal(t, "portfolioName", vars["artifactPortfolioName"])
125
assert.Nil(t, vars["clientId"])
126
assert.Equal(t, "serverId", vars["artifactId"])
130
sender.SendRecord(linkArtifact)
131
<-sender.GetOutboundChannel()
133
// 2. When only clientId is sent, clientId is used
134
linkArtifact = &service.Record{
135
RecordType: &service.Record_LinkArtifact{
136
LinkArtifact: &service.LinkArtifactRecord{
137
ClientId: "clientId",
139
PortfolioName: "portfolioName",
140
PortfolioEntity: "portfolioEntity",
141
PortfolioProject: "portfolioProject",
143
Control: &service.Control{
148
to.MockClient.EXPECT().MakeRequest(
149
gomock.Any(), // context.Context
150
gomock.Any(), // *graphql.Request
151
gomock.Any(), // *graphql.Response
152
).Return(nil).Do(coretest.InjectResponse(
154
func(vars coretest.RequestVars) {
155
assert.Equal(t, "portfolioProject", vars["projectName"])
156
assert.Equal(t, "portfolioEntity", vars["entityName"])
157
assert.Equal(t, "portfolioName", vars["artifactPortfolioName"])
158
assert.Equal(t, "clientId", vars["clientId"])
159
assert.Nil(t, vars["artifactId"])
163
sender.SendRecord(linkArtifact)
164
<-sender.GetOutboundChannel()
166
// 2. When only serverId is sent, serverId is used
167
linkArtifact = &service.Record{
168
RecordType: &service.Record_LinkArtifact{
169
LinkArtifact: &service.LinkArtifactRecord{
171
ServerId: "serverId",
172
PortfolioName: "portfolioName",
173
PortfolioEntity: "portfolioEntity",
174
PortfolioProject: "portfolioProject",
176
Control: &service.Control{
181
to.MockClient.EXPECT().MakeRequest(
182
gomock.Any(), // context.Context
183
gomock.Any(), // *graphql.Request
184
gomock.Any(), // *graphql.Response
185
).Return(nil).Do(coretest.InjectResponse(
187
func(vars coretest.RequestVars) {
188
assert.Equal(t, "portfolioProject", vars["projectName"])
189
assert.Equal(t, "portfolioEntity", vars["entityName"])
190
assert.Equal(t, "portfolioName", vars["artifactPortfolioName"])
191
assert.Nil(t, vars["clientId"])
192
assert.Equal(t, "serverId", vars["artifactId"])
196
sender.SendRecord(linkArtifact)
197
<-sender.GetOutboundChannel()
200
func TestSendUseArtifact(t *testing.T) {
201
to := coretest.MakeTestObject(t)
202
defer to.TeardownTest()
204
sender := makeSender(to.MockClient, make(chan *service.Result, 1))
206
useArtifact := &service.Record{
207
RecordType: &service.Record_UseArtifact{
208
UseArtifact: &service.UseArtifactRecord{
211
Name: "artifactName",
216
// verify doesn't panic if used job artifact
217
sender.SendRecord(useArtifact)
219
// verify doesn't panic if partial job is broken
220
useArtifact = &service.Record{
221
RecordType: &service.Record_UseArtifact{
222
UseArtifact: &service.UseArtifactRecord{
225
Name: "artifactName",
226
Partial: &service.PartialJobArtifact{
228
SourceInfo: &service.JobSource{
230
Source: &service.Source{
231
Git: &service.GitSource{
232
GitInfo: &service.GitInfo{
243
sender.SendRecord(useArtifact)
246
func TestSendArtifact(t *testing.T) {
247
// Verify that arguments are properly passed through to graphql
248
to := coretest.MakeTestObject(t)
249
defer to.TeardownTest()
251
sender := makeSender(to.MockClient, make(chan *service.Result, 1))
253
// 1. When both clientId and serverId are sent, serverId is used
254
artifact := &service.Record{
255
RecordType: &service.Record_Artifact{
256
Artifact: &service.ArtifactRecord{
257
RunId: "test-run-id",
258
Project: "test-project",
259
Entity: "test-entity",
261
Name: "test-artifact",
262
Digest: "test-digest",
263
Aliases: []string{"latest"},
264
Manifest: &service.ArtifactManifest{
266
StoragePolicy: "wandb-storage-policy-v1",
267
Contents: []*service.ArtifactManifestEntry{{
269
Digest: "test1-digest",
271
LocalPath: "/test/local/path",
276
ClientId: "client-id",
277
SequenceClientId: "sequence-client-id",
280
createArtifactRespEncode := &graphql.Response{
281
Data: &gql.CreateArtifactResponse{
282
CreateArtifact: &gql.CreateArtifactCreateArtifactCreateArtifactPayload{
283
Artifact: gql.CreateArtifactCreateArtifactCreateArtifactPayloadArtifact{
288
to.MockClient.EXPECT().MakeRequest(
289
gomock.Any(), // context.Context
290
gomock.Any(), // *graphql.Request
291
gomock.Any(), // *graphql.Response
292
).Return(nil).Do(coretest.InjectResponse(
293
createArtifactRespEncode,
294
func(vars coretest.RequestVars) {
295
assert.Equal(t, "test-entity", vars["entityName"])
298
sender.SendRecord(artifact)