15
"github.com/hashicorp/go-retryablehttp"
16
"github.com/segmentio/encoding/json"
18
"github.com/Khan/genqlient/graphql"
19
"google.golang.org/protobuf/proto"
20
"google.golang.org/protobuf/types/known/wrapperspb"
22
"github.com/wandb/wandb/core/internal/api"
23
"github.com/wandb/wandb/core/internal/clients"
24
"github.com/wandb/wandb/core/internal/debounce"
25
"github.com/wandb/wandb/core/internal/filetransfer"
26
"github.com/wandb/wandb/core/internal/gql"
27
"github.com/wandb/wandb/core/internal/runconfig"
28
"github.com/wandb/wandb/core/internal/shared"
29
"github.com/wandb/wandb/core/internal/version"
30
"github.com/wandb/wandb/core/pkg/artifacts"
31
fs "github.com/wandb/wandb/core/pkg/filestream"
32
"github.com/wandb/wandb/core/pkg/launch"
33
"github.com/wandb/wandb/core/pkg/observability"
34
"github.com/wandb/wandb/core/pkg/service"
35
"github.com/wandb/wandb/core/pkg/utils"
39
// RFC3339Micro is modified from time.RFC3339Nano: it always emits exactly
// six fractional-second digits (microsecond precision) instead of trimming
// trailing zeros. sendOutputRaw uses it, minus the trailing "Z", to
// timestamp console output lines.
40
RFC3339Micro = "2006-01-02T15:04:05.000000Z07:00"
41
// configDebouncerRateLimit and configDebouncerBurstSize are passed to
// debounce.NewDebouncer in NewSender to throttle config upserts
// (presumably a rate in events per second, i.e. one upsert per 30s —
// confirm against the debounce package).
configDebouncerRateLimit = 1 / 30.0 // todo: audit rate limit
42
configDebouncerBurstSize = 1 // todo: audit burst size
45
// SenderOption is a functional option that customizes a Sender. NewSender
// iterates over the options it is given after setting up its defaults.
type SenderOption func(*Sender)
47
// WithSenderFwdChannel returns a SenderOption that sets the loopback channel
// used to send records from the sender back to the handler (presumably
// assigned to s.fwdChan).
func WithSenderFwdChannel(fwd chan *service.Record) SenderOption {
48
return func(s *Sender) {
53
// WithSenderOutChannel returns a SenderOption that sets the channel on which
// the sender publishes results for the dispatcher (presumably assigned to
// s.outChan).
func WithSenderOutChannel(out chan *service.Result) SenderOption {
54
return func(s *Sender) {
59
// Sender is the sender for a stream: it handles the incoming records and
60
// sends them to the server and/or to the dispatcher/handler.
62
// ctx is the context for the handler
65
// cancel is the cancel function for the handler
66
cancel context.CancelFunc
68
// logger is the logger for the sender
69
logger *observability.CoreLogger
71
// settings is the settings for the sender
72
settings *service.Settings
74
// fwdChan is the channel for loopback messages (messages from the sender to the handler)
75
fwdChan chan *service.Record
77
// outChan is the channel for dispatcher messages
78
outChan chan *service.Result
80
// graphqlClient is the graphql client; it may be nil (NewSender only
// creates it when not offline), so callers check for nil before use.
81
graphqlClient graphql.Client
83
// fileStream is the file stream
84
fileStream *fs.FileStream
86
// fileTransferManager is the file uploader/downloader
87
fileTransferManager *filetransfer.FileTransferManager
89
// RunRecord is the run record
90
// TODO: remove this and use properly updated settings
91
// + a flag indicating whether the run has started
92
RunRecord *service.RunRecord
94
// resumeState is the resume state; it is only created by
// checkAndUpdateResumeState when a resume mode is configured.
95
resumeState *ResumeState
97
// telemetry accumulates telemetry merged from run and telemetry records;
// updateConfigPrivate writes it into the run config.
telemetry *service.TelemetryRecord
99
// metricSender is created lazily by sendMetric; its configMetrics are
// written into the run config by updateConfigPrivate.
metricSender *MetricSender
101
// configDebouncer rate-limits config upserts (upsertConfig); sendConfig
// only marks it as needing a debounce.
configDebouncer *debounce.Debouncer
103
// Keep track of summary which is being updated incrementally
104
summaryMap map[string]*service.SummaryItem
106
// Keep track of config which is being updated incrementally
107
runConfig *runconfig.RunConfig
109
// Info about the (local) server we are talking to
110
serverInfo *gql.ServerInfoServerInfo
112
// Keep track of exit record to pass to file stream when the time comes
113
exitRecord *service.Record
115
// syncService is set by sendSync when replaying a sync file.
syncService *SyncService
119
// jobBuilder builds the launch job artifact; nil when job creation is
// disabled.
jobBuilder *launch.JobBuilder
121
// wgFileTransfer tracks the per-file goroutines started by sendFiles; the
// FLUSH_FP defer state waits on it before closing the file transfer manager.
wgFileTransfer sync.WaitGroup
124
// NewSender creates a new Sender with the given settings.
//
// Unless the _offline setting is true, it also builds the network side:
// a GraphQL client for <base_url>/graphql, a file-stream client, a
// retryablehttp client for file transfers, and a FileTransferManager fed
// into the file stream's input channel. It then fetches server info.
// An unparsable base URL is fatal (CaptureFatalAndPanic). A job builder is
// created unless job creation is disabled. The config debouncer is always
// created. The supplied opts are iterated after these defaults are set.
127
cancel context.CancelFunc,
128
logger *observability.CoreLogger,
129
settings *service.Settings,
130
opts ...SenderOption,
138
summaryMap: make(map[string]*service.SummaryItem),
139
runConfig: runconfig.New(),
140
telemetry: &service.TelemetryRecord{CoreVersion: version.Version},
141
wgFileTransfer: sync.WaitGroup{},
143
// Network clients are only built when not running offline.
if !settings.GetXOffline().GetValue() {
144
baseURL, err := url.Parse(settings.GetBaseUrl().GetValue())
146
logger.CaptureFatalAndPanic("sender: failed to parse base URL", err)
148
backend := api.New(api.BackendOptions{
150
Logger: logger.Logger,
151
APIKey: settings.GetApiKey().GetValue(),
154
// Identify the user on every GraphQL request; maps.Copy below merges
// _extra_http_headers on top, overwriting these keys on collision.
graphqlHeaders := map[string]string{
155
"X-WANDB-USERNAME": settings.GetUsername().GetValue(),
156
"X-WANDB-USER-EMAIL": settings.GetEmail().GetValue(),
158
maps.Copy(graphqlHeaders, settings.GetXExtraHttpHeaders().GetValue())
160
graphqlClient := backend.NewClient(api.ClientOptions{
161
RetryPolicy: clients.CheckRetry,
162
RetryMax: int(settings.GetXGraphqlRetryMax().GetValue()),
163
RetryWaitMin: clients.SecondsToDuration(settings.GetXGraphqlRetryWaitMinSeconds().GetValue()),
164
RetryWaitMax: clients.SecondsToDuration(settings.GetXGraphqlRetryWaitMaxSeconds().GetValue()),
165
NonRetryTimeout: clients.SecondsToDuration(settings.GetXGraphqlTimeoutSeconds().GetValue()),
166
ExtraHeaders: graphqlHeaders,
168
url := fmt.Sprintf("%s/graphql", settings.GetBaseUrl().GetValue())
169
sender.graphqlClient = graphql.NewClient(url, graphqlClient)
171
// In shared mode, add X-WANDB-USE-ASYNC-FILESTREAM to file-stream
// requests (presumably selecting the backend's async file-stream path).
fileStreamHeaders := map[string]string{}
172
if settings.GetXShared().GetValue() {
173
fileStreamHeaders["X-WANDB-USE-ASYNC-FILESTREAM"] = "true"
176
fileStreamRetryClient := backend.NewClient(api.ClientOptions{
177
RetryMax: int(settings.GetXFileStreamRetryMax().GetValue()),
178
RetryWaitMin: clients.SecondsToDuration(settings.GetXFileStreamRetryWaitMinSeconds().GetValue()),
179
RetryWaitMax: clients.SecondsToDuration(settings.GetXFileStreamRetryWaitMaxSeconds().GetValue()),
180
NonRetryTimeout: clients.SecondsToDuration(settings.GetXFileStreamTimeoutSeconds().GetValue()),
181
ExtraHeaders: fileStreamHeaders,
184
sender.fileStream = fs.NewFileStream(
185
fs.WithSettings(settings),
186
fs.WithLogger(logger),
187
fs.WithAPIClient(fileStreamRetryClient),
188
fs.WithClientId(shared.ShortID(32)),
191
// File uploads/downloads use their own retryablehttp client, tuned by the
// _file_transfer_* settings, with exponential backoff plus jitter.
fileTransferRetryClient := retryablehttp.NewClient()
192
fileTransferRetryClient.Logger = logger
193
fileTransferRetryClient.CheckRetry = clients.CheckRetry
194
fileTransferRetryClient.RetryMax = int(settings.GetXFileTransferRetryMax().GetValue())
195
fileTransferRetryClient.RetryWaitMin = clients.SecondsToDuration(settings.GetXFileTransferRetryWaitMinSeconds().GetValue())
196
fileTransferRetryClient.RetryWaitMax = clients.SecondsToDuration(settings.GetXFileTransferRetryWaitMaxSeconds().GetValue())
197
fileTransferRetryClient.HTTPClient.Timeout = clients.SecondsToDuration(settings.GetXFileTransferTimeoutSeconds().GetValue())
198
fileTransferRetryClient.Backoff = clients.ExponentialBackoffWithJitter
200
defaultFileTransfer := filetransfer.NewDefaultFileTransfer(
202
fileTransferRetryClient,
204
// The transfer manager receives the file stream's input channel, so the
// file stream can be told about transfer results.
sender.fileTransferManager = filetransfer.NewFileTransferManager(
205
filetransfer.WithLogger(logger),
206
filetransfer.WithSettings(settings),
207
filetransfer.WithFileTransfer(defaultFileTransfer),
208
filetransfer.WithFSCChan(sender.fileStream.GetInputChan()),
211
sender.getServerInfo()
213
// jobBuilder stays nil when job creation is disabled; users check for nil.
if !settings.GetDisableJobCreation().GetValue() {
214
sender.jobBuilder = launch.NewJobBuilder(settings, logger)
217
sender.configDebouncer = debounce.NewDebouncer(
218
configDebouncerRateLimit,
219
configDebouncerBurstSize,
223
for _, opt := range opts {
230
// Do is the sender's main loop: it consumes records from inChan until the
// channel is closed, logging start and stop with the run ID as stream_id.
231
func (s *Sender) Do(inChan <-chan *service.Record) {
232
// Reraise is deferred so a panic in this goroutine goes through the
// logger (presumably reported and re-panicked — confirm in observability).
defer s.logger.Reraise()
233
s.logger.Info("sender: started", "stream_id", s.settings.RunId)
235
for record := range inChan {
237
// TODO: reevaluate the logic here
// Give the config debouncer a chance to run upsertConfig on each
// iteration (presumably only when sendConfig marked a pending change).
238
s.configDebouncer.Debounce(s.upsertConfig)
241
s.logger.Info("sender: closed", "stream_id", s.settings.RunId)
244
// Close is called when the sender has finished processing data. It closes
// the dispatch channel (presumably s.outChan), after which no further
// results may be sent on it.
func (s *Sender) Close() {
245
// sender is done processing data, close our dispatch channel
249
// GetOutboundChannel returns the channel on which the sender publishes
// results for the dispatcher (presumably s.outChan).
func (s *Sender) GetOutboundChannel() chan *service.Result {
253
// SetGraphqlClient replaces the GraphQL client used for all server calls
// (presumably so tests can inject a fake client).
func (s *Sender) SetGraphqlClient(client graphql.Client) {
254
s.graphqlClient = client
257
// SendRecord is the exported entry point for sending a single record
// (presumably a thin wrapper around sendRecord).
func (s *Sender) SendRecord(record *service.Record) {
258
// this is for testing purposes only yet
262
// sendRecord dispatches a record to the send* handler for its record type.
// Footer, Header and Final records are accepted and ignored. A record with a
// nil or unrecognized type is treated as a fatal error (CaptureFatalAndPanic).
263
func (s *Sender) sendRecord(record *service.Record) {
264
s.logger.Debug("sender: sendRecord", "record", record, "stream_id", s.settings.RunId)
265
switch x := record.RecordType.(type) {
266
case *service.Record_Run:
267
s.sendRun(record, x.Run)
268
// Footer/Header/Final: nothing to do on the sender side.
case *service.Record_Footer:
269
case *service.Record_Header:
270
case *service.Record_Final:
271
case *service.Record_Exit:
272
s.sendExit(record, x.Exit)
273
case *service.Record_Alert:
274
s.sendAlert(record, x.Alert)
275
case *service.Record_Metric:
276
s.sendMetric(record, x.Metric)
277
case *service.Record_Files:
278
s.sendFiles(record, x.Files)
279
case *service.Record_History:
280
s.sendHistory(record, x.History)
281
case *service.Record_Summary:
282
s.sendSummary(record, x.Summary)
283
case *service.Record_Config:
284
s.sendConfig(record, x.Config)
285
case *service.Record_Stats:
286
s.sendSystemMetrics(record, x.Stats)
287
case *service.Record_OutputRaw:
288
s.sendOutputRaw(record, x.OutputRaw)
289
case *service.Record_Telemetry:
290
s.sendTelemetry(record, x.Telemetry)
291
case *service.Record_Preempting:
292
s.sendPreempting(record)
293
case *service.Record_Request:
294
s.sendRequest(record, x.Request)
295
case *service.Record_LinkArtifact:
296
s.sendLinkArtifact(record)
297
case *service.Record_UseArtifact:
298
s.sendUseArtifact(record)
299
case *service.Record_Artifact:
300
s.sendArtifact(record, x.Artifact)
302
// A missing or unknown record type indicates a bug upstream: fail loudly.
err := fmt.Errorf("sender: sendRecord: nil RecordType")
303
s.logger.CaptureFatalAndPanic("sender: sendRecord: nil RecordType", err)
305
err := fmt.Errorf("sender: sendRecord: unexpected type %T", x)
306
s.logger.CaptureFatalAndPanic("sender: sendRecord: unexpected type", err)
310
// sendRequest dispatches a request record to the handler for its request
// type. PollExit requests need no action here. A nil or unrecognized
// request type is treated as a fatal error (CaptureFatalAndPanic).
311
func (s *Sender) sendRequest(record *service.Record, request *service.Request) {
313
switch x := request.RequestType.(type) {
314
case *service.Request_RunStart:
315
s.sendRunStart(x.RunStart)
316
case *service.Request_NetworkStatus:
317
s.sendNetworkStatusRequest(x.NetworkStatus)
318
case *service.Request_Defer:
320
case *service.Request_LogArtifact:
321
s.sendLogArtifact(record, x.LogArtifact)
322
// PollExit: nothing to do on the sender side.
case *service.Request_PollExit:
323
case *service.Request_ServerInfo:
324
s.sendServerInfo(record, x.ServerInfo)
325
case *service.Request_DownloadArtifact:
326
s.sendDownloadArtifact(record, x.DownloadArtifact)
327
case *service.Request_Sync:
328
s.sendSync(record, x.Sync)
329
case *service.Request_SenderRead:
330
s.sendSenderRead(record, x.SenderRead)
331
case *service.Request_Cancel:
334
err := fmt.Errorf("sender: sendRequest: nil RequestType")
335
s.logger.CaptureFatalAndPanic("sender: sendRequest: nil RequestType", err)
337
err := fmt.Errorf("sender: sendRequest: unexpected type %T", x)
338
s.logger.CaptureFatalAndPanic("sender: sendRequest: unexpected type", err)
342
// updateSettings updates the settings from the run record upon a run start
343
// with the information from the server
344
func (s *Sender) updateSettings() {
345
// Nothing can be copied until both settings and the run record exist.
if s.settings == nil || s.RunRecord == nil {
349
// Only fill in a start time if settings lack one: convert the protobuf
// timestamp (seconds + nanos) into fractional seconds.
if s.settings.XStartTime == nil && s.RunRecord.StartTime != nil {
350
startTime := float64(s.RunRecord.StartTime.Seconds) + float64(s.RunRecord.StartTime.Nanos)/1e9
351
s.settings.XStartTime = &wrapperspb.DoubleValue{Value: startTime}
354
// TODO: verify that this is the correct update logic
// NOTE(review): a non-empty entity always overwrites settings.Entity,
// whereas project and run name below only fill in unset values —
// confirm this asymmetry is intended.
355
if s.RunRecord.GetEntity() != "" {
356
s.settings.Entity = &wrapperspb.StringValue{Value: s.RunRecord.Entity}
358
if s.RunRecord.GetProject() != "" && s.settings.Project == nil {
359
s.settings.Project = &wrapperspb.StringValue{Value: s.RunRecord.Project}
361
if s.RunRecord.GetDisplayName() != "" && s.settings.RunName == nil {
362
s.settings.RunName = &wrapperspb.StringValue{Value: s.RunRecord.DisplayName}
366
// sendRunStart starts up all the resources for a run: it points the file
// stream at this run's "files/.../file_stream" endpoint, restores file-stream
// offsets from the resume state, and starts the file transfer manager.
367
func (s *Sender) sendRunStart(_ *service.RunStartRequest) {
368
// The three path segments are presumably entity, project and run ID.
fsPath := fmt.Sprintf(
369
"files/%s/%s/%s/file_stream",
375
// Options are applied after construction because the path is only known
// once the run has been set up.
fs.WithPath(fsPath)(s.fileStream)
376
// NOTE(review): resumeState is only created when a resume mode is set;
// this relies on GetFileStreamOffset tolerating a nil receiver — confirm.
fs.WithOffsets(s.resumeState.GetFileStreamOffset())(s.fileStream)
380
s.fileTransferManager.Start()
383
// sendNetworkStatusRequest handles a NetworkStatus request; it appears to be
// a no-op placeholder (TODO: confirm or implement).
func (s *Sender) sendNetworkStatusRequest(_ *service.NetworkStatusRequest) {
386
// sendJobFlush builds the launch job artifact from the run config (input)
// and the JSON-decoded summary values (output, presumably keyed by summary
// key) and saves it with an ArtifactSaver. It does nothing when job creation
// is disabled. Failures are logged, not propagated. The messages use the
// "sendDefer" prefix because this presumably runs from the FLUSH_JOB defer
// state.
func (s *Sender) sendJobFlush() {
387
if s.jobBuilder == nil {
390
input := s.runConfig.Tree()
391
output := make(map[string]interface{})
394
for k, v := range s.summaryMap {
395
// Each summary item stores its value as JSON; decode it back.
bytes := []byte(v.GetValueJson())
396
err := json.Unmarshal(bytes, &out)
398
s.logger.Error("sender: sendDefer: failed to unmarshal summary", "error", err)
404
artifact, err := s.jobBuilder.Build(input, output)
406
s.logger.Error("sender: sendDefer: failed to build job artifact", "error", err)
410
s.logger.Info("sender: sendDefer: no job artifact to save")
413
saver := artifacts.NewArtifactSaver(
414
s.ctx, s.graphqlClient, s.fileTransferManager, artifact, 0, "",
416
if _, err = saver.Save(s.fwdChan); err != nil {
417
s.logger.Error("sender: sendDefer: failed to save job artifact", "error", err)
421
// sendDefer advances the shutdown ("defer") state machine. For each state it
// does that state's sender-side flush work, if any, and calls
// sendRequestDefer to hand the request back to the handler. Notable states:
//   - FLUSH_DEBOUNCER: flush pending config upserts and write the config file.
//   - FLUSH_FP: wait for in-flight sendFiles goroutines, then close the file
//     transfer manager.
//   - END: flush the sync service and answer the saved exit record.
// An unknown state is treated as a fatal error.
func (s *Sender) sendDefer(request *service.DeferRequest) {
422
switch request.State {
423
case service.DeferRequest_BEGIN:
425
s.sendRequestDefer(request)
426
case service.DeferRequest_FLUSH_RUN:
428
s.sendRequestDefer(request)
429
case service.DeferRequest_FLUSH_STATS:
431
s.sendRequestDefer(request)
432
case service.DeferRequest_FLUSH_PARTIAL_HISTORY:
434
s.sendRequestDefer(request)
435
case service.DeferRequest_FLUSH_TB:
437
s.sendRequestDefer(request)
438
case service.DeferRequest_FLUSH_SUM:
440
s.sendRequestDefer(request)
441
case service.DeferRequest_FLUSH_DEBOUNCER:
442
// Push any pending config change now instead of waiting for the
// debouncer, then persist the final config as a file.
s.configDebouncer.Flush(s.upsertConfig)
443
s.writeAndSendConfigFile()
445
s.sendRequestDefer(request)
446
case service.DeferRequest_FLUSH_OUTPUT:
448
s.sendRequestDefer(request)
449
case service.DeferRequest_FLUSH_JOB:
452
s.sendRequestDefer(request)
453
case service.DeferRequest_FLUSH_DIR:
455
s.sendRequestDefer(request)
456
case service.DeferRequest_FLUSH_FP:
457
// Wait for sendFiles goroutines first so no upload task is added to the
// manager after it is closed.
s.wgFileTransfer.Wait()
458
s.fileTransferManager.Close()
460
s.sendRequestDefer(request)
461
case service.DeferRequest_JOIN_FP:
463
s.sendRequestDefer(request)
464
case service.DeferRequest_FLUSH_FS:
467
s.sendRequestDefer(request)
468
case service.DeferRequest_FLUSH_FINAL:
470
s.sendRequestDefer(request)
471
case service.DeferRequest_END:
473
// NOTE(review): syncService is only assigned in sendSync; confirm that
// Flush tolerates a nil receiver when not syncing.
s.syncService.Flush()
474
s.respondExit(s.exitRecord)
475
// cancel tells the stream to close the loopback channel
478
err := fmt.Errorf("sender: sendDefer: unexpected state %v", request.State)
479
s.logger.CaptureFatalAndPanic("sender: sendDefer: unexpected state", err)
483
// sendRequestDefer wraps a DeferRequest in a Request record, presumably
// forwarded to the handler over the loopback channel. AlwaysSend is set on
// its Control (presumably so downstream filtering never drops it).
func (s *Sender) sendRequestDefer(request *service.DeferRequest) {
484
rec := &service.Record{
485
RecordType: &service.Record_Request{Request: &service.Request{
486
RequestType: &service.Request_Defer{Defer: request},
488
Control: &service.Control{AlwaysSend: true},
493
// sendTelemetry merges a telemetry record into the accumulated telemetry
// (proto.Merge semantics), refreshes the private part of the run config,
// and schedules a config upsert.
func (s *Sender) sendTelemetry(_ *service.Record, telemetry *service.TelemetryRecord) {
494
proto.Merge(s.telemetry, telemetry)
495
s.updateConfigPrivate()
496
// TODO(perf): improve when debounce config is added, for now this sends all the time
// NOTE(review): sendConfig(nil, nil) only calls SetNeedsDebounce, so the
// upsert is already rate-limited by the debouncer; this TODO may be stale.
497
s.sendConfig(nil, nil /*configRecord*/)
500
// sendPreempting forwards a preemption record unchanged to the file stream,
// which sends it to the server.
func (s *Sender) sendPreempting(record *service.Record) {
501
s.fileStream.StreamRecord(record)
504
// sendLinkArtifact links an artifact via an ArtifactLinker built from the
// record's LinkArtifact payload and the GraphQL client. It then replies with a
// Result carrying the record's Control. A link error is treated as fatal.
// NOTE(review): panicking on a server-side link failure is harsher than
// sendArtifact, which only logs; consider reporting the error instead.
func (s *Sender) sendLinkArtifact(record *service.Record) {
505
linker := artifacts.ArtifactLinker{
508
LinkArtifact: record.GetLinkArtifact(),
509
GraphqlClient: s.graphqlClient,
513
s.logger.CaptureFatalAndPanic("sender: sendLinkArtifact: link failure", err)
516
result := &service.Result{
517
Control: record.Control,
523
// sendUseArtifact passes a use-artifact record to the job builder. When job
// creation is disabled (no job builder) it logs a warning and does nothing.
func (s *Sender) sendUseArtifact(record *service.Record) {
524
if s.jobBuilder == nil {
525
s.logger.Warn("sender: sendUseArtifact: job builder disabled, skipping")
528
s.jobBuilder.HandleUseArtifactRecord(record)
531
// updateConfig applies the change record to the in-memory run configuration.
// Errors are reported through the callback and captured, not treated as
// fatal.
532
func (s *Sender) updateConfig(configRecord *service.ConfigRecord) {
533
s.runConfig.ApplyChangeRecord(configRecord, func(err error) {
534
s.logger.CaptureError("Error updating run config", err)
538
// Inserts W&B-internal information into the run configuration.
540
// It uses the accumulated s.telemetry and, once sendMetric has created the
// metric sender, its configMetrics (nil before that).
541
func (s *Sender) updateConfigPrivate() {
542
metrics := []map[int]interface{}(nil)
543
if s.metricSender != nil {
544
metrics = s.metricSender.configMetrics
547
s.runConfig.AddTelemetryAndMetrics(s.telemetry, metrics)
550
// serializeConfig serializes the run configuration in the given format.
// JSON is used for upsertBucket mutations and YAML for the config file.
// A serialization failure is fatal.
// NOTE(review): the fatal message says "sendRun" although upsertConfig and
// writeAndSendConfigFile call this too.
551
func (s *Sender) serializeConfig(format runconfig.ConfigFormat) string {
552
serializedConfig, err := s.runConfig.Serialize(format)
555
err = fmt.Errorf("failed to marshal config: %s", err)
556
s.logger.CaptureFatalAndPanic("sender: sendRun: ", err)
559
return string(serializedConfig)
562
// sendRunResult wraps a RunUpdateResult in a Result that carries the original
// record's Control, so the reply can be matched to the waiting client
// (presumably published on s.outChan).
func (s *Sender) sendRunResult(record *service.Record, runResult *service.RunUpdateResult) {
563
result := &service.Result{
564
ResultType: &service.Result_RunResult{
565
RunResult: runResult,
567
Control: record.Control,
574
// checkAndUpdateResumeState handles run resumption. It does nothing without
// a GraphQL client or when no resume mode is configured. Otherwise it:
//  1. creates a fresh ResumeState,
//  2. queries the server for the run's resume status, and
//  3. lets the resume state merge that status.
// If steps 2 or 3 fail, it replies to the client with a RunUpdateResult
// describing the error (and presumably returns the error to sendRun, which
// logs it).
func (s *Sender) checkAndUpdateResumeState(record *service.Record) error {
575
if s.graphqlClient == nil {
578
// There was no resume status set, so we don't need to do anything
579
if s.settings.GetResume().GetValue() == "" {
583
// init resume state (this unconditionally replaces any previous state)
584
s.resumeState = NewResumeState(s.logger, s.settings.GetResume().GetValue())
586
// If we couldn't get the resume status, we should fail if resume is set
// (run is presumably s.RunRecord.)
587
data, err := gql.RunResumeStatus(s.ctx, s.graphqlClient, &run.Project, utils.NilIfZero(run.Entity), run.RunId)
589
err = fmt.Errorf("failed to get run resume status: %s", err)
590
s.logger.Error("sender:", "error", err)
591
result := &service.RunUpdateResult{
592
Error: &service.ErrorInfo{
593
Message: err.Error(),
594
Code: service.ErrorInfo_COMMUNICATION,
596
s.sendRunResult(record, result)
600
if result, err := s.resumeState.Update(
605
s.sendRunResult(record, result)
612
// sendRun handles a run record. With a GraphQL client it:
//   - merges the record's config and telemetry into the sender's state;
//   - on the first run record, keeps a clone in s.RunRecord and checks
//     resume state;
//   - upserts the run ("bucket") on the server with the serialized config,
//     tags and git info;
//   - copies server-assigned fields (storage ID, display name, project,
//     entity, sweep) back into s.RunRecord;
//   - if the client expects a reply (ReqResp or a mailbox slot), responds
//     with the updated run or with the upsert error.
func (s *Sender) sendRun(record *service.Record, run *service.RunRecord) {
613
if s.graphqlClient != nil {
614
// The first run record sent by the client is encoded incorrectly,
615
// causing it to overwrite the entire "_wandb" config key rather than
616
// just the necessary part ("_wandb/code_path"). This can overwrite
617
// the config from a resumed run, so we have to do this first.
619
// Logically, it would make more sense to instead start with the
620
// resumed config and apply updates on top of it.
621
s.updateConfig(run.Config)
622
proto.Merge(s.telemetry, run.Telemetry)
623
s.updateConfigPrivate()
625
if s.RunRecord == nil {
627
// Keep a private copy so later server-side updates do not mutate the
// caller's record. A failed type assertion is fatal.
s.RunRecord, ok = proto.Clone(run).(*service.RunRecord)
629
err := fmt.Errorf("failed to clone RunRecord")
630
s.logger.CaptureFatalAndPanic("sender: sendRun: ", err)
633
if err := s.checkAndUpdateResumeState(record); err != nil {
634
s.logger.Error("sender: sendRun: failed to checkAndUpdateResumeState", "error", err)
639
config := s.serializeConfig(runconfig.FormatJson)
642
tags = append(tags, run.Tags...)
644
// commit/repo stay empty (sent as nil via NilIfZero) unless git info is
// available.
var commit, repo string
647
commit = git.GetCommit()
648
repo = git.GetRemoteUrl()
651
program := s.settings.GetProgram().GetValue()
652
// start a new context with an additional argument from the parent context
653
// this is used to pass the retry function to the graphql client
654
ctx := context.WithValue(s.ctx, clients.CtxRetryPolicyKey, clients.UpsertBucketRetryPolicy)
655
// utils.NilIfZero turns empty strings into nil for the mutation arguments.
data, err := gql.UpsertBucket(
657
s.graphqlClient, // client
660
utils.NilIfZero(run.Project), // project
661
utils.NilIfZero(run.Entity), // entity
662
utils.NilIfZero(run.RunGroup), // groupName
664
utils.NilIfZero(run.DisplayName), // displayName
665
utils.NilIfZero(run.Notes), // notes
666
utils.NilIfZero(commit), // commit
668
utils.NilIfZero(run.Host), // host
670
utils.NilIfZero(program), // program
671
utils.NilIfZero(repo), // repo
672
utils.NilIfZero(run.JobType), // jobType
674
utils.NilIfZero(run.SweepId), // sweep
675
tags, // tags []string,
676
nil, // summaryMetrics
679
err = fmt.Errorf("failed to upsert bucket: %s", err)
680
s.logger.Error("sender: sendRun:", "error", err)
681
// TODO(run update): handle error communication back to the client
// NOTE(review): this also prints directly to stdout, not just the log.
682
fmt.Println("ERROR: failed to upsert bucket", err.Error())
683
// TODO(sync): make this more robust in case of a failed UpsertBucket request.
684
// Need to inform the sync service that this ops failed.
685
if record.GetControl().GetReqResp() || record.GetControl().GetMailboxSlot() != "" {
686
result := &service.Result{
687
ResultType: &service.Result_RunResult{
688
RunResult: &service.RunUpdateResult{
689
Error: &service.ErrorInfo{
690
Message: err.Error(),
691
Code: service.ErrorInfo_COMMUNICATION,
695
Control: record.Control,
703
// Copy the server-assigned identity of the run back into s.RunRecord.
bucket := data.GetUpsertBucket().GetBucket()
704
project := bucket.GetProject()
705
entity := project.GetEntity()
706
s.RunRecord.StorageId = bucket.GetId()
707
// s.RunRecord.RunId = bucket.GetName()
708
s.RunRecord.DisplayName = utils.ZeroIfNil(bucket.GetDisplayName())
709
s.RunRecord.Project = project.GetName()
710
s.RunRecord.Entity = entity.GetName()
711
s.RunRecord.SweepId = utils.ZeroIfNil(bucket.GetSweepName())
714
// Reply only when the client is waiting for a response.
if record.GetControl().GetReqResp() || record.GetControl().GetMailboxSlot() != "" {
715
runResult := s.RunRecord
716
if runResult == nil {
719
result := &service.Result{
720
ResultType: &service.Result_RunResult{
721
RunResult: &service.RunUpdateResult{Run: runResult},
723
Control: record.Control,
730
// sendHistory sends a history record to the file stream,
731
// which will then send it to the server. The record is forwarded unchanged.
732
func (s *Sender) sendHistory(record *service.Record, _ *service.HistoryRecord) {
733
s.fileStream.StreamRecord(record)
736
// sendSummary merges the summary updates into the in-memory summary, where a
// later value for a key replaces the earlier one. It then streams the
// complete summary, not just the delta, through the file stream. Items are
// emitted in map iteration order, which Go leaves unspecified.
func (s *Sender) sendSummary(_ *service.Record, summary *service.SummaryRecord) {
737
// TODO(network): buffer summary sending for network efficiency until we can send only updates
738
// TODO(compat): handle deletes, nested keys
739
// TODO(compat): write summary file
741
// track each key in the in memory summary store
742
// TODO(memory): avoid keeping summary for all distinct keys
743
for _, item := range summary.Update {
744
s.summaryMap[item.Key] = item
747
// build list of summary items from the map
748
var summaryItems []*service.SummaryItem
749
for _, v := range s.summaryMap {
750
summaryItems = append(summaryItems, v)
753
// build a full summary record to send
754
record := &service.Record{
755
RecordType: &service.Record_Summary{
756
Summary: &service.SummaryRecord{
757
Update: summaryItems,
762
s.fileStream.StreamRecord(record)
765
// upsertConfig pushes the current run config, serialized as JSON, to the
// server with an UpsertBucket mutation. It uses the same upsert retry policy
// as sendRun. This is the function run by the config debouncer (Do and the
// FLUSH_DEBOUNCER defer state). It does nothing without a GraphQL client;
// errors are only logged.
// NOTE(review): s.RunRecord is dereferenced without a nil check; this relies
// on the run record being processed before any debounced upsert fires.
func (s *Sender) upsertConfig() {
766
if s.graphqlClient == nil {
769
config := s.serializeConfig(runconfig.FormatJson)
771
ctx := context.WithValue(s.ctx, clients.CtxRetryPolicyKey, clients.UpsertBucketRetryPolicy)
772
_, err := gql.UpsertBucket(
774
s.graphqlClient, // client
776
&s.RunRecord.RunId, // name
777
utils.NilIfZero(s.RunRecord.Project), // project
778
utils.NilIfZero(s.RunRecord.Entity), // entity
792
nil, // tags []string,
793
nil, // summaryMetrics
796
s.logger.Error("sender: sendConfig:", "error", err)
800
// writeAndSendConfigFile writes the run config as YAML to ConfigFileName in
// the run's files directory (mode 0644). It then builds a Files record of
// type WANDB for that path (presumably handed on so the file gets uploaded).
// It is skipped in sync mode; a write failure is logged.
func (s *Sender) writeAndSendConfigFile() {
801
if s.settings.GetXSync().GetValue() {
802
// if sync is enabled, we don't need to do all this
806
config := s.serializeConfig(runconfig.FormatYaml)
807
configFile := filepath.Join(s.settings.GetFilesDir().GetValue(), ConfigFileName)
808
if err := os.WriteFile(configFile, []byte(config), 0644); err != nil {
809
s.logger.Error("sender: writeAndSendConfigFile: failed to write config file", "error", err)
812
record := &service.Record{
813
RecordType: &service.Record_Files{
814
Files: &service.FilesRecord{
815
Files: []*service.FilesItem{
817
Path: ConfigFileName,
818
Type: service.FilesItem_WANDB,
827
// sendConfig applies a config change record (if any) to the in-memory config
828
// and marks the config debouncer as needing a debounce. The upsertBucket
// mutation itself is sent later by upsertConfig, when the debouncer fires or
// is flushed. A nil configRecord (as passed by sendTelemetry and sendMetric)
// only schedules an upsert.
829
func (s *Sender) sendConfig(_ *service.Record, configRecord *service.ConfigRecord) {
830
if configRecord != nil {
831
s.updateConfig(configRecord)
833
s.configDebouncer.SetNeedsDebounce()
836
// sendSystemMetrics sends a system metrics record via the file stream
// (the record is forwarded unchanged).
837
func (s *Sender) sendSystemMetrics(record *service.Record, _ *service.StatsRecord) {
838
s.fileStream.StreamRecord(record)
841
// sendOutputRaw handles one captured line of console output. Lines that are
// exactly "\n" are ignored. Otherwise it:
//   - appends the raw line plus "\n" to OutputFileName in the files
//     directory (errors are logged), then
//   - streams a copy whose line is prefixed with a UTC timestamp, and with
//     "ERROR " in front of that for stderr ("ERROR <ts> <line>").
// NOTE(review): the output file is reopened and closed for every line.
func (s *Sender) sendOutputRaw(record *service.Record, _ *service.OutputRawRecord) {
842
// TODO: match logic handling of lines to the one in the python version
843
// - handle carriage returns (for tqdm-like progress bars)
844
// - handle caching multiple (non-new lines) and sending them in one chunk
845
// - handle lines longer than ~60_000 characters
847
// copy the record to avoid mutating the original
848
recordCopy := proto.Clone(record).(*service.Record)
849
outputRaw := recordCopy.GetOutputRaw()
851
// ignore empty "new lines"
852
if outputRaw.Line == "\n" {
856
outputFile := filepath.Join(s.settings.GetFilesDir().GetValue(), OutputFileName)
857
// append line to file
858
f, err := os.OpenFile(outputFile, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
860
s.logger.Error("sender: sendOutputRaw: failed to open output file", "error", err)
862
if _, err := f.WriteString(outputRaw.Line + "\n"); err != nil {
863
s.logger.Error("sender: sendOutputRaw: failed to write to output file", "error", err)
866
if err := f.Close(); err != nil {
867
s.logger.Error("sender: sendOutputRaw: failed to close output file", "error", err)
871
// generate compatible timestamp to python iso-format (microseconds without Z)
872
t := strings.TrimSuffix(time.Now().UTC().Format(RFC3339Micro), "Z")
873
outputRaw.Line = fmt.Sprintf("%s %s", t, outputRaw.Line)
874
if outputRaw.OutputType == service.OutputRawRecord_STDERR {
875
outputRaw.Line = fmt.Sprintf("ERROR %s", outputRaw.Line)
877
s.fileStream.StreamRecord(recordCopy)
880
// sendAlert sends a user-defined run alert to the server via the
// NotifyScriptableRunAlert mutation. It is skipped without a GraphQL client.
// A missing run record is fatal. A failed mutation is captured but not
// fatal.
func (s *Sender) sendAlert(_ *service.Record, alert *service.AlertRecord) {
881
if s.graphqlClient == nil {
885
if s.RunRecord == nil {
886
err := fmt.Errorf("sender: sendAlert: RunRecord not set")
887
s.logger.CaptureFatalAndPanic("sender received error", err)
889
// TODO: handle invalid alert levels
// The level is converted to gql.AlertSeverity without validation.
890
severity := gql.AlertSeverity(alert.Level)
892
data, err := gql.NotifyScriptableRunAlert(
904
err = fmt.Errorf("sender: sendAlert: failed to notify scriptable run alert: %s", err)
905
s.logger.CaptureError("sender received error", err)
907
s.logger.Info("sender: sendAlert: notified scriptable run alert", "data", data)
912
// respondExit is called from the end of the defer state machine (END state).
// It answers the client's exit record with an empty RunExitResult, but only
// if the client expects a response (ReqResp or a mailbox slot). It does
// nothing when there is no exit record or in sync mode.
913
func (s *Sender) respondExit(record *service.Record) {
914
if record == nil || s.settings.GetXSync().GetValue() {
917
// record.Control is non-nil here: sendExit fills it in on the same
// record it stores in s.exitRecord.
if record.Control.ReqResp || record.Control.MailboxSlot != "" {
918
result := &service.Result{
919
ResultType: &service.Result_ExitResult{ExitResult: &service.RunExitResult{}},
920
Control: record.Control,
927
// sendExit sends an exit record to the server and triggers the shutdown of the stream.
// It saves the record for respondExit, streams it through the file stream,
// and builds a DeferRequest in state BEGIN to start the defer state machine.
928
func (s *Sender) sendExit(record *service.Record, _ *service.RunExitRecord) {
929
// response is done by respondExit() and called when defer state machine is complete
930
s.exitRecord = record
932
s.fileStream.StreamRecord(record)
934
// send a defer request to the handler to indicate that the user requested to finish the stream
935
// and the defer state machine can kick in triggering the shutdown process
936
request := &service.Request{RequestType: &service.Request_Defer{
937
Defer: &service.DeferRequest{State: service.DeferRequest_BEGIN}},
939
// Ensure the exit record has a Control. This mutates the caller's record,
// which is also s.exitRecord, so respondExit can rely on it being non-nil.
if record.Control == nil {
940
record.Control = &service.Control{AlwaysSend: true}
943
rec := &service.Record{
944
RecordType: &service.Record_Request{Request: request},
945
Control: record.Control,
951
// sendMetric encodes a metric definition's hints into the lazily created
952
// metric sender (encodeMetricHints), refreshes the private config with them
// and schedules a config upsert. Nothing is sent through the file stream
// here. Metrics with a glob name are not supported by the backend: a
// warning is logged (and presumably the metric is skipped).
953
func (s *Sender) sendMetric(record *service.Record, metric *service.MetricRecord) {
954
if s.metricSender == nil {
955
s.metricSender = NewMetricSender()
958
if metric.GetGlobName() != "" {
959
s.logger.Warn("sender: sendMetric: glob name is not supported in the backend", "globName", metric.GetGlobName())
963
s.encodeMetricHints(record, metric)
964
s.updateConfigPrivate()
965
s.sendConfig(nil, nil /*configRecord*/)
968
// sendFiles iterates over the files in the FilesRecord and sends them to
// the server, each in its own goroutine (presumably via sendFile). Each
// goroutine is tracked by wgFileTransfer so the FLUSH_FP defer state can wait
// for it. No limit is placed on the number of goroutines.
969
func (s *Sender) sendFiles(_ *service.Record, filesRecord *service.FilesRecord) {
970
files := filesRecord.GetFiles()
971
for _, file := range files {
972
// Tag media files so completion counts are reported as MEDIA.
// NOTE(review): HasPrefix also matches names like "media_notes.txt", not
// only the media/ directory; the type is written into the caller's item.
if strings.HasPrefix(file.GetPath(), "media") {
973
file.Type = service.FilesItem_MEDIA
975
s.wgFileTransfer.Add(1)
976
go func(file *service.FilesItem) {
978
s.wgFileTransfer.Done()
983
// sendFile sends a file to the server
984
// TODO: improve this to handle multiple files and send them in one request
//
// It asks the server for upload URLs (CreateRunFiles), then queues one
// upload task per returned file with the file transfer manager. Progress and
// completion are reported as FileTransferInfo request records. It is skipped
// without a GraphQL client or transfer manager, and files missing on disk
// are skipped with a warning. A missing run record is fatal; a failed URL
// request is captured but not fatal.
985
func (s *Sender) sendFile(file *service.FilesItem) {
986
if s.graphqlClient == nil || s.fileTransferManager == nil {
990
if s.RunRecord == nil {
991
err := fmt.Errorf("sender: sendFile: RunRecord not set")
992
s.logger.CaptureFatalAndPanic("sender received error", err)
995
fullPath := filepath.Join(s.settings.GetFilesDir().GetValue(), file.GetPath())
996
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
997
s.logger.Warn("sender: sendFile: file does not exist", "path", fullPath)
1001
data, err := gql.CreateRunFiles(
1005
s.RunRecord.Project,
1007
[]string{file.GetPath()},
1010
err = fmt.Errorf("sender: sendFile: failed to get upload urls: %s", err)
1011
s.logger.CaptureError("sender received error", err)
1014
headers := data.GetCreateRunFiles().GetUploadHeaders()
1015
for _, f := range data.GetCreateRunFiles().GetFiles() {
1016
// This fullPath shadows the one above and uses the server-reported name.
fullPath := filepath.Join(s.settings.GetFilesDir().GetValue(), f.Name)
1017
task := &filetransfer.Task{
1018
Type: filetransfer.UploadTask,
1025
task.SetProgressCallback(
1026
func(processed, total int) {
1030
record := &service.Record{
1031
RecordType: &service.Record_Request{
1032
Request: &service.Request{
1033
RequestType: &service.Request_FileTransferInfo{
1034
FileTransferInfo: &service.FileTransferInfoRequest{
1035
Type: service.FileTransferInfoRequest_Upload,
1038
Processed: int64(processed),
1047
task.SetCompletionCallback(
1048
func(t *filetransfer.Task) {
1049
s.fileTransferManager.FileStreamCallback(t)
1050
// Report exactly one finished file, bucketed by the FilesItem type.
fileCounts := &service.FileCounts{}
1051
switch file.GetType() {
1052
case service.FilesItem_MEDIA:
1053
fileCounts.MediaCount = 1
1054
case service.FilesItem_OTHER:
1055
fileCounts.OtherCount = 1
1056
case service.FilesItem_WANDB:
1057
fileCounts.WandbCount = 1
1060
record := &service.Record{
1061
RecordType: &service.Record_Request{
1062
Request: &service.Request{
1063
RequestType: &service.Request_FileTransferInfo{
1064
FileTransferInfo: &service.FileTransferInfoRequest{
1065
Type: service.FileTransferInfoRequest_Upload,
1069
FileCounts: fileCounts,
1078
s.fileTransferManager.AddTask(task)
1082
// sendArtifact saves an artifact record via an ArtifactSaver (history step
// 0, no staging directory), which is given the loopback channel. Failures
// are only logged and no response Result is built.
func (s *Sender) sendArtifact(record *service.Record, msg *service.ArtifactRecord) {
1083
saver := artifacts.NewArtifactSaver(
1084
s.ctx, s.graphqlClient, s.fileTransferManager, msg, 0, "",
1086
artifactID, err := saver.Save(s.fwdChan)
1088
err = fmt.Errorf("sender: sendArtifact: failed to log artifact ID: %s; error: %s", artifactID, err)
1089
s.logger.Error("sender: sendArtifact:", "error", err)
1094
// sendLogArtifact saves the requested artifact (with the request's history
// step and staging directory) and answers with a LogArtifactResponse. The
// response carries either the error message or the new artifact ID. The
// response is also passed to the job builder.
func (s *Sender) sendLogArtifact(record *service.Record, msg *service.LogArtifactRequest) {
1095
var response service.LogArtifactResponse
1096
saver := artifacts.NewArtifactSaver(
1097
s.ctx, s.graphqlClient, s.fileTransferManager, msg.Artifact, msg.HistoryStep, msg.StagingDir,
1099
artifactID, err := saver.Save(s.fwdChan)
1101
response.ErrorMessage = err.Error()
1103
response.ArtifactId = artifactID
1106
result := &service.Result{
1107
ResultType: &service.Result_Response{
1108
Response: &service.Response{
1109
ResponseType: &service.Response_LogArtifactResponse{
1110
LogArtifactResponse: &response,
1114
Control: record.Control,
1117
// NOTE(review): jobBuilder is nil when job creation is disabled; confirm
// HandleLogArtifactResult tolerates a nil receiver (sendUseArtifact and
// sendJobFlush check for nil explicitly).
s.jobBuilder.HandleLogArtifactResult(&response, msg.Artifact)
1121
// sendDownloadArtifact downloads an artifact into the requested root
// directory and answers with a DownloadArtifactResponse. On failure the
// response carries the error message and the error is captured.
func (s *Sender) sendDownloadArtifact(record *service.Record, msg *service.DownloadArtifactRequest) {
1122
// TODO: this should be handled by a separate service startup mechanism
1123
s.fileTransferManager.Start()
1125
var response service.DownloadArtifactResponse
1126
downloader := artifacts.NewArtifactDownloader(s.ctx, s.graphqlClient, s.fileTransferManager, msg.ArtifactId, msg.DownloadRoot, &msg.AllowMissingReferences)
1127
err := downloader.Download()
1129
// NOTE(review): the message contains a "%v" verb but is presumably not
// treated as a format string; the error is passed separately.
s.logger.CaptureError("senderError: downloadArtifact: failed to download artifact: %v", err)
1130
response.ErrorMessage = err.Error()
1133
result := &service.Result{
1134
ResultType: &service.Result_Response{
1135
Response: &service.Response{
1136
ResponseType: &service.Response_DownloadArtifactResponse{
1137
DownloadArtifactResponse: &response,
1141
Control: record.Control,
1147
// sendSync starts replaying a previously written sync file. It creates a
// SyncService that re-dispatches records through sendRecord, honouring the
// request's overwrite and skip options, and starts it. It then builds a
// SenderRead request for the requested offset range (presumably forwarded
// to the handler).
//
// When the sync service flushes, the callback answers the original sync
// request with a SyncResponse. The response carries any error (code UNKNOWN)
// and, if a run record exists, the run's web URL.
func (s *Sender) sendSync(record *service.Record, request *service.SyncRequest) {
1149
s.syncService = NewSyncService(s.ctx,
1150
WithSyncServiceLogger(s.logger),
1151
WithSyncServiceSenderFunc(s.sendRecord),
1152
WithSyncServiceOverwrite(request.GetOverwrite()),
1153
WithSyncServiceSkip(request.GetSkip()),
1154
WithSyncServiceFlushCallback(func(err error) {
1155
var errorInfo *service.ErrorInfo
1157
errorInfo = &service.ErrorInfo{
1158
Message: err.Error(),
1159
Code: service.ErrorInfo_UNKNOWN,
1164
if s.RunRecord != nil {
1165
// Derive the web URL from the API URL by dropping the first "api."
// (e.g. https://api.example.com -> https://example.com).
// NOTE(review): strings.Replace removes the first "api." anywhere in the
// URL, not only a host prefix.
baseUrl := s.settings.GetBaseUrl().GetValue()
1166
baseUrl = strings.Replace(baseUrl, "api.", "", 1)
1167
url = fmt.Sprintf("%s/%s/%s/runs/%s", baseUrl, s.RunRecord.Entity, s.RunRecord.Project, s.RunRecord.RunId)
1169
result := &service.Result{
1170
ResultType: &service.Result_Response{
1171
Response: &service.Response{
1172
ResponseType: &service.Response_SyncResponse{
1173
SyncResponse: &service.SyncResponse{
1180
Control: record.Control,
1186
s.syncService.Start()
1188
rec := &service.Record{
1189
RecordType: &service.Record_Request{
1190
Request: &service.Request{
1191
RequestType: &service.Request_SenderRead{
1192
SenderRead: &service.SenderReadRequest{
1193
StartOffset: request.GetStartOffset(),
1194
FinalOffset: request.GetFinalOffset(),
1199
Control: record.Control,
1205
// sendSenderRead reads records back from the sync file, which is opened
// read-only (presumably once, then kept in s.store). In sync mode each
// (record, error) pair is handed to the sync service. Otherwise non-nil
// records are re-dispatched via sendRecord. Read errors are captured.
// Seeking to StartOffset is currently commented out, so reading presumably
// starts at the beginning of the file.
func (s *Sender) sendSenderRead(record *service.Record, request *service.SenderReadRequest) {
1207
store := NewStore(s.ctx, s.settings.GetSyncFile().GetValue(), s.logger)
1208
err := store.Open(os.O_RDONLY)
1210
s.logger.CaptureError("sender: sendSenderRead: failed to create store", err)
1216
// 1. seek to startOffset
1218
// if err := s.store.reader.SeekRecord(request.GetStartOffset()); err != nil {
1219
// s.logger.CaptureError("sender: sendSenderRead: failed to seek record", err)
1222
// 2. read records until finalOffset
1225
// Note: this record shadows the function's record parameter.
record, err := s.store.Read()
1226
if s.settings.GetXSync().GetValue() {
1227
s.syncService.SyncRecord(record, err)
1228
} else if record != nil {
1229
s.sendRecord(record)
1235
s.logger.CaptureError("sender: sendSenderRead: failed to read record", err)
1241
// getServerInfo queries the server's info (NewSender calls it) and caches the
// result in s.serverInfo for later use, e.g. by sendServerInfo. It does
// nothing without a GraphQL client. A failure is captured, not fatal
// (presumably leaving serverInfo nil).
func (s *Sender) getServerInfo() {
1242
if s.graphqlClient == nil {
1246
data, err := gql.ServerInfo(s.ctx, s.graphqlClient)
1248
err = fmt.Errorf("sender: getServerInfo: failed to get server info: %s", err)
1249
s.logger.CaptureError("sender received error", err)
1252
s.serverInfo = data.GetServerInfo()
1254
s.logger.Info("sender: getServerInfo: got server info", "serverInfo", s.serverInfo)
1257
// TODO: this function is for deciding which GraphQL query/mutation versions to use
1258
// func (s *Sender) getServerVersion() string {
1259
// if s.serverInfo == nil {
1262
// return s.serverInfo.GetLatestLocalVersionInfo().GetVersionOnThisInstanceString()
1265
func (s *Sender) sendServerInfo(record *service.Record, _ *service.ServerInfoRequest) {
1267
localInfo := &service.LocalInfo{}
1268
if s.serverInfo != nil && s.serverInfo.GetLatestLocalVersionInfo() != nil {
1269
localInfo = &service.LocalInfo{
1270
Version: s.serverInfo.GetLatestLocalVersionInfo().GetLatestVersionString(),
1271
OutOfDate: s.serverInfo.GetLatestLocalVersionInfo().GetOutOfDate(),
1275
result := &service.Result{
1276
ResultType: &service.Result_Response{
1277
Response: &service.Response{
1278
ResponseType: &service.Response_ServerInfoResponse{
1279
ServerInfoResponse: &service.ServerInfoResponse{
1280
LocalInfo: localInfo,
1285
Control: record.Control,