wandb

Форк
0
/
handler.go 
944 строки · 24.6 Кб
1
package server
2

3
import (
4
	"context"
5
	"fmt"
6
	"os"
7
	"path/filepath"
8

9
	"github.com/segmentio/encoding/json"
10
	"github.com/wandb/wandb/core/pkg/monitor"
11
	"google.golang.org/protobuf/encoding/protojson"
12
	"google.golang.org/protobuf/proto"
13

14
	"github.com/wandb/wandb/core/internal/corelib"
15
	"github.com/wandb/wandb/core/internal/version"
16
	"github.com/wandb/wandb/core/internal/watcher"
17
	"github.com/wandb/wandb/core/pkg/observability"
18
	"github.com/wandb/wandb/core/pkg/service"
19
)
20

21
// File names used by the handler inside the run's files directory, followed by
// the tuning values for the summary debouncer.
const (
	// MetaFileName is the JSON file handleMetadata writes the run metadata to.
	MetaFileName = "wandb-metadata.json"

	// SummaryFileName is the JSON file writeAndSendSummaryFile writes the
	// consolidated summary to.
	SummaryFileName = "wandb-summary.json"

	// OutputFileName is the file handleRunStart registers with the files
	// handler when the console setting is not "off".
	OutputFileName = "output.log"

	// DiffFileName is the patch file handlePatchSave generates against HEAD.
	DiffFileName = "diff.patch"

	// RequirementsFileName is the file handlePythonPackages writes the package
	// list to.
	RequirementsFileName = "requirements.txt"

	// ConfigFileName is not referenced in this file.
	// NOTE(review): presumably the run config file name; confirm with its users.
	ConfigFileName = "config.yaml"

	// summaryDebouncerRateLimit is not referenced in this file.
	// NOTE(review): presumably events per second (one every 30 seconds); confirm.
	// todo: audit rate limit
	summaryDebouncerRateLimit = 1 / 30.0

	// summaryDebouncerBurstSize is not referenced in this file.
	// todo: audit burst size
	summaryDebouncerBurstSize = 1
)
31

32
// HandlerOption is a functional option that configures a Handler; NewHandler
// applies each option, in order, to the handler it is building.
type HandlerOption func(*Handler)
33

34
// WithHandlerFwdChannel returns an option that sets the channel on which the
// handler forwards records.
func WithHandlerFwdChannel(fwd chan *service.Record) HandlerOption {
	return func(target *Handler) { target.fwdChan = fwd }
}
39

40
// WithHandlerOutChannel returns an option that sets the channel on which the
// handler publishes results.
func WithHandlerOutChannel(out chan *service.Result) HandlerOption {
	return func(target *Handler) { target.outChan = out }
}
45

46
// WithHandlerSettings returns an option that sets the settings the handler
// reads from.
func WithHandlerSettings(settings *service.Settings) HandlerOption {
	return func(target *Handler) { target.settings = settings }
}
51

52
// WithHandlerSystemMonitor returns an option that sets the handler's system
// monitor.
func WithHandlerSystemMonitor(monitor *monitor.SystemMonitor) HandlerOption {
	return func(target *Handler) { target.systemMonitor = monitor }
}
57

58
// WithHandlerWatcher returns an option that sets the handler's watcher.
func WithHandlerWatcher(watcher *watcher.Watcher) HandlerOption {
	return func(target *Handler) { target.watcher = watcher }
}
63

64
// WithHandlerTBHandler returns an option that sets the handler's tensorboard
// handler.
func WithHandlerTBHandler(handler *TBHandler) HandlerOption {
	return func(target *Handler) { target.tbHandler = handler }
}
69

70
// WithHandlerFileHandler returns an option that sets the handler's files
// handler.
func WithHandlerFileHandler(handler *FilesHandler) HandlerOption {
	return func(target *Handler) { target.filesHandler = handler }
}
75

76
// WithHandlerFilesInfoHandler returns an option that sets the handler's file
// transfer info handler.
func WithHandlerFilesInfoHandler(handler *FilesInfoHandler) HandlerOption {
	return func(target *Handler) { target.filesInfoHandler = handler }
}
81

82
// WithHandlerMetricHandler returns an option that sets the handler's metric
// handler.
func WithHandlerMetricHandler(handler *MetricHandler) HandlerOption {
	return func(target *Handler) { target.metricHandler = handler }
}
87

88
// WithHandlerSummaryHandler returns an option that sets the handler's summary
// handler.
func WithHandlerSummaryHandler(handler *SummaryHandler) HandlerOption {
	return func(target *Handler) { target.summaryHandler = handler }
}
93

94
// Handler is the handler for a stream it handles the incoming messages, processes them
95
// and passes them to the writer
96
type Handler struct {
97
	// ctx is the context for the handler
98
	ctx context.Context
99

100
	// settings is the settings for the handler
101
	settings *service.Settings
102

103
	// logger is the logger for the handler
104
	logger *observability.CoreLogger
105

106
	// fwdChan is the channel for forwarding messages to the next component
107
	fwdChan chan *service.Record
108

109
	// outChan is the channel for sending results to the client
110
	outChan chan *service.Result
111

112
	// timer is used to track the run start and execution times
113
	timer Timer
114

115
	// runRecord is the runRecord record received from the server
116
	runRecord *service.RunRecord
117

118
	// summaryHandler is the summary handler for the stream
119
	summaryHandler *SummaryHandler
120

121
	// activeHistory is the history record used to track
122
	// current active history record for the stream
123
	activeHistory *ActiveHistory
124

125
	// sampledHistory is the sampled history for the stream
126
	// TODO fix this to be generic type
127
	sampledHistory map[string]*ReservoirSampling[float32]
128

129
	// metricHandler is the metric handler for the stream
130
	metricHandler *MetricHandler
131

132
	// systemMonitor is the system monitor for the stream
133
	systemMonitor *monitor.SystemMonitor
134

135
	// watcher is the watcher for the stream
136
	watcher *watcher.Watcher
137

138
	// tbHandler is the tensorboard handler
139
	tbHandler *TBHandler
140

141
	// filesHandler is the file handler for the stream
142
	filesHandler *FilesHandler
143

144
	// filesInfoHandler is the file transfer info for the stream
145
	filesInfoHandler *FilesInfoHandler
146
}
147

148
// NewHandler creates a new handler
149
func NewHandler(
150
	ctx context.Context,
151
	logger *observability.CoreLogger,
152
	opts ...HandlerOption,
153
) *Handler {
154
	h := &Handler{
155
		ctx:    ctx,
156
		logger: logger,
157
	}
158
	for _, opt := range opts {
159
		opt(h)
160
	}
161
	return h
162
}
163

164
// Do starts the handler
165
func (h *Handler) Do(inChan <-chan *service.Record) {
166
	defer h.logger.Reraise()
167
	h.logger.Info("handler: started", "stream_id", h.settings.RunId)
168
	for record := range inChan {
169
		h.logger.Debug("handling record", "record", record)
170
		h.handleRecord(record)
171
	}
172
	h.Close()
173
}
174

175
// Close closes the handler's result channel and then its forward channel, and
// logs that the handler is closed.
func (h *Handler) Close() {
	close(h.outChan)
	close(h.fwdChan)
	h.logger.Debug("handler: closed", "stream_id", h.settings.RunId)
}
180

181
func (h *Handler) sendResponse(record *service.Record, response *service.Response) {
182
	result := &service.Result{
183
		ResultType: &service.Result_Response{Response: response},
184
		Control:    record.Control,
185
		Uuid:       record.Uuid,
186
	}
187
	h.outChan <- result
188
}
189

190
// sendRecordWithControl applies controlOptions, in order, to the record's
// control (creating an empty control first when the record has none) and then
// forwards the record. A nil record is ignored.
func (h *Handler) sendRecordWithControl(record *service.Record, controlOptions ...func(*service.Control)) {
	if record == nil {
		return
	}
	if record.GetControl() == nil {
		record.Control = &service.Control{}
	}
	ctl := record.GetControl()
	for i := range controlOptions {
		controlOptions[i](ctl)
	}
	h.sendRecord(record)
}
205

206
// sendRecord forwards record on the forward channel; a nil record is dropped.
func (h *Handler) sendRecord(record *service.Record) {
	if record != nil {
		h.fwdChan <- record
	}
}
212

213
//gocyclo:ignore
214
func (h *Handler) handleRecord(record *service.Record) {
215
	h.summaryHandler.Debounce(h.sendSummary)
216
	recordType := record.GetRecordType()
217
	h.logger.Debug("handle: got a message", "record_type", recordType)
218
	switch x := record.RecordType.(type) {
219
	case *service.Record_Alert:
220
		h.handleAlert(record)
221
	case *service.Record_Artifact:
222
		h.handleArtifact(record)
223
	case *service.Record_Config:
224
		h.handleConfig(record)
225
	case *service.Record_Exit:
226
		h.handleExit(record, x.Exit)
227
	case *service.Record_Files:
228
		h.handleFiles(record)
229
	case *service.Record_Final:
230
		h.handleFinal()
231
	case *service.Record_Footer:
232
		h.handleFooter()
233
	case *service.Record_Header:
234
		h.handleHeader(record)
235
	case *service.Record_History:
236
		h.handleHistory(x.History)
237
	case *service.Record_LinkArtifact:
238
		h.handleLinkArtifact(record)
239
	case *service.Record_Metric:
240
		h.handleMetric(record, x.Metric)
241
	case *service.Record_Output:
242
	case *service.Record_OutputRaw:
243
		h.handleOutputRaw(record)
244
	case *service.Record_Preempting:
245
		h.handlePreempting(record)
246
	case *service.Record_Request:
247
		h.handleRequest(record)
248
	case *service.Record_Run:
249
		h.handleRun(record)
250
	case *service.Record_Stats:
251
		h.handleSystemMetrics(record)
252
	case *service.Record_Summary:
253
		h.handleSummary(record, x.Summary)
254
	case *service.Record_Tbrecord:
255
		h.handleTBrecord(record)
256
	case *service.Record_Telemetry:
257
		h.handleTelemetry(record)
258
	case *service.Record_UseArtifact:
259
		h.handleUseArtifact(record)
260
	case nil:
261
		err := fmt.Errorf("handleRecord: record type is nil")
262
		h.logger.CaptureFatalAndPanic("error handling record", err)
263
	default:
264
		err := fmt.Errorf("handleRecord: unknown record type %T", x)
265
		h.logger.CaptureFatalAndPanic("error handling record", err)
266
	}
267
}
268

269
//gocyclo:ignore
270
func (h *Handler) handleRequest(record *service.Record) {
271
	request := record.GetRequest()
272
	response := &service.Response{}
273
	switch x := request.RequestType.(type) {
274
	case *service.Request_CheckVersion:
275
	case *service.Request_Defer:
276
		h.handleDefer(record, x.Defer)
277
		response = nil
278
	case *service.Request_GetSummary:
279
		h.handleGetSummary(record, response)
280
	case *service.Request_Keepalive:
281
	case *service.Request_NetworkStatus:
282
	case *service.Request_PartialHistory:
283
		h.handlePartialHistory(record, x.PartialHistory)
284
		response = nil
285
	case *service.Request_PollExit:
286
		h.handlePollExit(record)
287
		response = nil
288
	case *service.Request_RunStart:
289
		h.handleRunStart(record, x.RunStart)
290
	case *service.Request_SampledHistory:
291
		h.handleSampledHistory(record, response)
292
	case *service.Request_ServerInfo:
293
		h.handleServerInfo(record)
294
		response = nil
295
	case *service.Request_PythonPackages:
296
		h.handlePythonPackages(record, x.PythonPackages)
297
		response = nil
298
	case *service.Request_Shutdown:
299
	case *service.Request_StopStatus:
300
	case *service.Request_LogArtifact:
301
		h.handleLogArtifact(record)
302
		response = nil
303
	case *service.Request_DownloadArtifact:
304
		h.handleDownloadArtifact(record)
305
		response = nil
306
	case *service.Request_JobInfo:
307
	case *service.Request_Attach:
308
		h.handleAttach(record, response)
309
	case *service.Request_Pause:
310
		h.handlePause()
311
	case *service.Request_Resume:
312
		h.handleResume()
313
	case *service.Request_Cancel:
314
		h.handleCancel(record)
315
	case *service.Request_GetSystemMetrics:
316
		h.handleGetSystemMetrics(record, response)
317
	case *service.Request_FileTransferInfo:
318
		h.handleFileTransferInfo(record)
319
	case *service.Request_InternalMessages:
320
	case *service.Request_Sync:
321
		h.handleSync(record)
322
		response = nil
323
	case *service.Request_SenderRead:
324
		h.handleSenderRead(record)
325
		response = nil
326
	default:
327
		err := fmt.Errorf("handleRequest: unknown request type %T", x)
328
		h.logger.CaptureFatalAndPanic("error handling request", err)
329
	}
330
	if response != nil {
331
		h.sendResponse(record, response)
332
	}
333
}
334

335
func (h *Handler) handleDefer(record *service.Record, request *service.DeferRequest) {
336
	switch request.State {
337
	case service.DeferRequest_BEGIN:
338
	case service.DeferRequest_FLUSH_RUN:
339
	case service.DeferRequest_FLUSH_STATS:
340
		// stop the system monitor to ensure that we don't send any more system metrics
341
		// after the run has exited
342
		h.systemMonitor.Stop()
343
	case service.DeferRequest_FLUSH_PARTIAL_HISTORY:
344
		h.activeHistory.Flush()
345
	case service.DeferRequest_FLUSH_TB:
346
		h.tbHandler.Close()
347
	case service.DeferRequest_FLUSH_SUM:
348
		h.handleSummary(nil, &service.SummaryRecord{})
349
		h.summaryHandler.Flush(h.sendSummary)
350
		h.writeAndSendSummaryFile()
351
	case service.DeferRequest_FLUSH_DEBOUNCER:
352
	case service.DeferRequest_FLUSH_OUTPUT:
353
	case service.DeferRequest_FLUSH_JOB:
354
	case service.DeferRequest_FLUSH_DIR:
355
		h.watcher.Close()
356
	case service.DeferRequest_FLUSH_FP:
357
		h.filesHandler.Flush()
358
	case service.DeferRequest_JOIN_FP:
359
	case service.DeferRequest_FLUSH_FS:
360
	case service.DeferRequest_FLUSH_FINAL:
361
		h.handleFinal()
362
		h.handleFooter()
363
	case service.DeferRequest_END:
364
	default:
365
		err := fmt.Errorf("handleDefer: unknown defer state %v", request.State)
366
		h.logger.CaptureError("unknown defer state", err)
367
	}
368
	// Need to clone the record to avoid race condition with the writer
369
	record = proto.Clone(record).(*service.Record)
370
	h.sendRecordWithControl(record,
371
		func(control *service.Control) {
372
			control.AlwaysSend = true
373
		},
374
		func(control *service.Control) {
375
			control.Local = true
376
		},
377
	)
378
}
379
// handleArtifact forwards the artifact record unchanged.
func (h *Handler) handleArtifact(record *service.Record) {
	h.sendRecord(record)
}
382

383
// handleLogArtifact forwards the log-artifact request record unchanged.
func (h *Handler) handleLogArtifact(record *service.Record) {
	h.sendRecord(record)
}
386

387
// handleDownloadArtifact forwards the download-artifact request record
// unchanged.
func (h *Handler) handleDownloadArtifact(record *service.Record) {
	h.sendRecord(record)
}
390

391
// handleLinkArtifact forwards the link-artifact record unchanged.
func (h *Handler) handleLinkArtifact(record *service.Record) {
	h.sendRecord(record)
}
394

395
func (h *Handler) handlePollExit(record *service.Record) {
396
	result := &service.Result{
397
		ResultType: &service.Result_Response{
398
			Response: &service.Response{
399
				ResponseType: &service.Response_PollExitResponse{
400
					PollExitResponse: &service.PollExitResponse{
401
						PusherStats: h.filesInfoHandler.GetFilesStats(),
402
						FileCounts:  h.filesInfoHandler.GetFilesCount(),
403
						Done:        h.filesInfoHandler.GetDone(),
404
					},
405
				},
406
			},
407
		},
408
		Control: record.Control,
409
		Uuid:    record.Uuid,
410
	}
411
	h.outChan <- result
412
}
413

414
func (h *Handler) handleHeader(record *service.Record) {
415
	// populate with version info
416
	versionString := fmt.Sprintf("%s+%s", version.Version, h.ctx.Value(observability.Commit("commit")))
417
	record.GetHeader().VersionInfo = &service.VersionInfo{
418
		Producer:    versionString,
419
		MinConsumer: version.MinServerVersion,
420
	}
421
	h.sendRecordWithControl(
422
		record,
423
		func(control *service.Control) {
424
			control.AlwaysSend = false
425
		},
426
	)
427
}
428

429
func (h *Handler) handleFinal() {
430
	if h.settings.GetXSync().GetValue() {
431
		// if sync is enabled, we don't need to do all this
432
		return
433
	}
434
	record := &service.Record{
435
		RecordType: &service.Record_Final{
436
			Final: &service.FinalRecord{},
437
		},
438
	}
439
	h.sendRecordWithControl(
440
		record,
441
		func(control *service.Control) {
442
			control.AlwaysSend = false
443
		},
444
	)
445
}
446

447
func (h *Handler) handleFooter() {
448
	if h.settings.GetXSync().GetValue() {
449
		// if sync is enabled, we don't need to do all this
450
		return
451
	}
452
	record := &service.Record{
453
		RecordType: &service.Record_Footer{
454
			Footer: &service.FooterRecord{},
455
		},
456
	}
457
	h.sendRecordWithControl(
458
		record,
459
		func(control *service.Control) {
460
			control.AlwaysSend = false
461
		},
462
	)
463
}
464

465
func (h *Handler) handleServerInfo(record *service.Record) {
466
	h.sendRecordWithControl(record,
467
		func(control *service.Control) {
468
			control.AlwaysSend = true
469
		},
470
	)
471
}
472

473
func (h *Handler) handleRunStart(record *service.Record, request *service.RunStartRequest) {
474
	var ok bool
475
	run := request.Run
476

477
	// start the run timer
478
	h.timer = Timer{}
479
	startTime := run.StartTime.AsTime()
480
	h.timer.Start(&startTime)
481

482
	if h.runRecord, ok = proto.Clone(run).(*service.RunRecord); !ok {
483
		err := fmt.Errorf("handleRunStart: failed to clone run")
484
		h.logger.CaptureFatalAndPanic("error handling run start", err)
485
	}
486
	h.sendRecord(record)
487

488
	// start the tensorboard handler
489
	h.watcher.Start()
490

491
	h.filesHandler = h.filesHandler.With(
492
		WithFilesHandlerHandleFn(h.sendRecord),
493
	)
494

495
	if h.settings.GetConsole().GetValue() != "off" {
496
		h.filesHandler.Handle(&service.Record{
497
			RecordType: &service.Record_Files{
498
				Files: &service.FilesRecord{
499
					Files: []*service.FilesItem{
500
						{
501
							Path:   OutputFileName,
502
							Type:   service.FilesItem_WANDB,
503
							Policy: service.FilesItem_END,
504
						},
505
					},
506
				},
507
			},
508
		})
509
	}
510

511
	// start the system monitor
512
	if !h.settings.GetXDisableStats().GetValue() {
513
		h.systemMonitor.Do()
514
	}
515

516
	// save code and patch
517
	if h.settings.GetSaveCode().GetValue() {
518
		h.handleCodeSave()
519
		h.handlePatchSave()
520
	}
521

522
	// NOTE: once this request arrives in the sender,
523
	// the latter will start its filestream and uploader
524
	// initialize the run metadata from settings
525
	var git *service.GitRepoRecord
526
	if run.GetGit().GetRemoteUrl() != "" || run.GetGit().GetCommit() != "" {
527
		git = &service.GitRepoRecord{
528
			RemoteUrl: run.GetGit().GetRemoteUrl(),
529
			Commit:    run.GetGit().GetCommit(),
530
		}
531
	}
532

533
	metadata := &service.MetadataRequest{
534
		Os:            h.settings.GetXOs().GetValue(),
535
		Python:        h.settings.GetXPython().GetValue(),
536
		Host:          h.settings.GetHost().GetValue(),
537
		Cuda:          h.settings.GetXCuda().GetValue(),
538
		Program:       h.settings.GetProgram().GetValue(),
539
		CodePath:      h.settings.GetProgramRelpath().GetValue(),
540
		CodePathLocal: h.settings.GetXCodePathLocal().GetValue(),
541
		Email:         h.settings.GetEmail().GetValue(),
542
		Root:          h.settings.GetRootDir().GetValue(),
543
		Username:      h.settings.GetUsername().GetValue(),
544
		Docker:        h.settings.GetDocker().GetValue(),
545
		Executable:    h.settings.GetXExecutable().GetValue(),
546
		Args:          h.settings.GetXArgs().GetValue(),
547
		Colab:         h.settings.GetColabUrl().GetValue(),
548
		StartedAt:     run.GetStartTime(),
549
		Git:           git,
550
	}
551

552
	if !h.settings.GetXDisableStats().GetValue() {
553
		systemInfo := h.systemMonitor.Probe()
554
		if systemInfo != nil {
555
			proto.Merge(metadata, systemInfo)
556
		}
557
	}
558
	h.handleMetadata(metadata)
559
}
560

561
func (h *Handler) handlePythonPackages(_ *service.Record, request *service.PythonPackagesRequest) {
562
	// write all requirements to a file
563
	// send the file as a Files record
564
	filename := filepath.Join(h.settings.GetFilesDir().GetValue(), RequirementsFileName)
565
	file, err := os.Create(filename)
566
	if err != nil {
567
		h.logger.Error("error creating requirements file", "error", err)
568
		return
569
	}
570
	defer file.Close()
571

572
	for _, pkg := range request.Package {
573
		line := fmt.Sprintf("%s==%s\n", pkg.Name, pkg.Version)
574
		_, err := file.WriteString(line)
575
		if err != nil {
576
			h.logger.Error("error writing requirements file", "error", err)
577
			return
578
		}
579
	}
580
	record := &service.Record{
581
		RecordType: &service.Record_Files{
582
			Files: &service.FilesRecord{
583
				Files: []*service.FilesItem{
584
					{
585
						Path: RequirementsFileName,
586
						Type: service.FilesItem_WANDB,
587
					},
588
				},
589
			},
590
		},
591
	}
592
	h.handleFiles(record)
593
}
594

595
func (h *Handler) handleCodeSave() {
596
	programRelative := h.settings.GetProgramRelpath().GetValue()
597
	if programRelative == "" {
598
		h.logger.Warn("handleCodeSave: program relative path is empty")
599
		return
600
	}
601

602
	programAbsolute := h.settings.GetProgramAbspath().GetValue()
603
	if _, err := os.Stat(programAbsolute); err != nil {
604
		h.logger.Warn("handleCodeSave: program absolute path does not exist", "path", programAbsolute)
605
		return
606
	}
607

608
	codeDir := filepath.Join(h.settings.GetFilesDir().GetValue(), "code")
609
	if err := os.MkdirAll(filepath.Join(codeDir, filepath.Dir(programRelative)), os.ModePerm); err != nil {
610
		return
611
	}
612
	savedProgram := filepath.Join(codeDir, programRelative)
613
	if _, err := os.Stat(savedProgram); err != nil {
614
		if err = copyFile(programAbsolute, savedProgram); err != nil {
615
			return
616
		}
617
	}
618
	record := &service.Record{
619
		RecordType: &service.Record_Files{
620
			Files: &service.FilesRecord{
621
				Files: []*service.FilesItem{
622
					{
623
						Path: filepath.Join("code", programRelative),
624
						Type: service.FilesItem_WANDB,
625
					},
626
				},
627
			},
628
		},
629
	}
630
	h.handleFiles(record)
631
}
632

633
func (h *Handler) handlePatchSave() {
634
	// capture git state
635
	if h.settings.GetDisableGit().GetValue() {
636
		return
637
	}
638

639
	git := NewGit(h.settings.GetRootDir().GetValue(), h.logger)
640
	if !git.IsAvailable() {
641
		return
642
	}
643

644
	files := []*service.FilesItem{}
645

646
	filesDirPath := h.settings.GetFilesDir().GetValue()
647
	file := filepath.Join(filesDirPath, DiffFileName)
648
	if err := git.SavePatch("HEAD", file); err != nil {
649
		h.logger.Error("error generating diff", "error", err)
650
	} else {
651
		files = append(files, &service.FilesItem{Path: DiffFileName, Type: service.FilesItem_WANDB})
652
	}
653

654
	if output, err := git.LatestCommit("@{u}"); err != nil {
655
		h.logger.Error("error getting latest commit", "error", err)
656
	} else {
657
		diffFileName := fmt.Sprintf("diff_%s.patch", output)
658
		file = filepath.Join(filesDirPath, diffFileName)
659
		if err := git.SavePatch("@{u}", file); err != nil {
660
			h.logger.Error("error generating diff", "error", err)
661
		} else {
662
			files = append(files, &service.FilesItem{Path: diffFileName, Type: service.FilesItem_WANDB})
663
		}
664
	}
665

666
	if len(files) == 0 {
667
		return
668
	}
669

670
	record := &service.Record{
671
		RecordType: &service.Record_Files{
672
			Files: &service.FilesRecord{
673
				Files: files,
674
			},
675
		},
676
	}
677
	h.handleFiles(record)
678
}
679

680
func (h *Handler) handleMetadata(request *service.MetadataRequest) {
681
	// TODO: Sending metadata as a request for now, eventually this should be turned into
682
	//  a record and stored in the transaction log
683
	if h.settings.GetXDisableMeta().GetValue() {
684
		return
685
	}
686

687
	mo := protojson.MarshalOptions{
688
		Indent: "  ",
689
		// EmitUnpopulated: true,
690
	}
691
	jsonBytes, err := mo.Marshal(request)
692
	if err != nil {
693
		h.logger.CaptureError("error marshalling metadata", err)
694
		return
695
	}
696
	filePath := filepath.Join(h.settings.GetFilesDir().GetValue(), MetaFileName)
697
	if err := os.WriteFile(filePath, jsonBytes, 0644); err != nil {
698
		h.logger.CaptureError("error writing metadata file", err)
699
		return
700
	}
701

702
	record := &service.Record{
703
		RecordType: &service.Record_Files{
704
			Files: &service.FilesRecord{
705
				Files: []*service.FilesItem{
706
					{
707
						Path: MetaFileName,
708
						Type: service.FilesItem_WANDB,
709
					},
710
				},
711
			},
712
		},
713
	}
714

715
	h.handleFiles(record)
716
}
717

718
func (h *Handler) handleAttach(_ *service.Record, response *service.Response) {
719

720
	response.ResponseType = &service.Response_AttachResponse{
721
		AttachResponse: &service.AttachResponse{
722
			Run: h.runRecord,
723
		},
724
	}
725
}
726

727
// handleCancel forwards the cancel request record unchanged.
func (h *Handler) handleCancel(record *service.Record) {
	h.sendRecord(record)
}
730

731
// handlePause pauses the run timer and then stops the system monitor.
func (h *Handler) handlePause() {
	h.timer.Pause()
	h.systemMonitor.Stop()
}
735

736
// handleResume resumes the run timer and then starts the system monitor again.
func (h *Handler) handleResume() {
	h.timer.Resume()
	h.systemMonitor.Do()
}
740

741
// handleSystemMetrics forwards the stats record unchanged.
func (h *Handler) handleSystemMetrics(record *service.Record) {
	h.sendRecord(record)
}
744

745
// handleOutputRaw forwards the raw output record unchanged.
func (h *Handler) handleOutputRaw(record *service.Record) {
	h.sendRecord(record)
}
748

749
// handlePreempting forwards the preempting record unchanged.
func (h *Handler) handlePreempting(record *service.Record) {
	h.sendRecord(record)
}
752

753
func (h *Handler) handleRun(record *service.Record) {
754
	h.sendRecordWithControl(record,
755
		func(control *service.Control) {
756
			control.AlwaysSend = true
757
		},
758
	)
759
}
760

761
// handleConfig forwards the config record unchanged.
func (h *Handler) handleConfig(record *service.Record) {
	h.sendRecord(record)
}
764

765
// handleAlert forwards the alert record unchanged.
func (h *Handler) handleAlert(record *service.Record) {
	h.sendRecord(record)
}
768

769
func (h *Handler) handleExit(record *service.Record, exit *service.RunExitRecord) {
770
	// stop the run timer and set the runtime
771
	h.timer.Pause()
772
	runtime := int32(h.timer.Elapsed().Seconds())
773
	exit.Runtime = runtime
774

775
	// update summary with runtime
776
	if !h.settings.GetXSync().GetValue() {
777
		summaryRecord := corelib.ConsolidateSummaryItems(h.summaryHandler.consolidatedSummary, []*service.SummaryItem{
778
			{
779
				Key: "_wandb", ValueJson: fmt.Sprintf(`{"runtime": %d}`, runtime),
780
			},
781
		})
782
		h.summaryHandler.updateSummaryDelta(summaryRecord)
783
	}
784

785
	// send the exit record
786
	h.sendRecordWithControl(record,
787
		func(control *service.Control) {
788
			control.AlwaysSend = true
789
			// do not write to the transaction log when syncing an offline run
790
			if h.settings.GetXSync().GetValue() {
791
				control.Local = true
792
			}
793
		},
794
	)
795
}
796

797
// handleFiles passes record to the files handler, provided it actually carries
// a files payload; otherwise it is ignored.
func (h *Handler) handleFiles(record *service.Record) {
	if record.GetFiles() != nil {
		h.filesHandler.Handle(record)
	}
}
803

804
func (h *Handler) handleGetSummary(_ *service.Record, response *service.Response) {
805
	var items []*service.SummaryItem
806

807
	for key, element := range h.summaryHandler.consolidatedSummary {
808
		items = append(items, &service.SummaryItem{Key: key, ValueJson: element})
809
	}
810
	response.ResponseType = &service.Response_GetSummaryResponse{
811
		GetSummaryResponse: &service.GetSummaryResponse{
812
			Item: items,
813
		},
814
	}
815
}
816

817
func (h *Handler) handleGetSystemMetrics(_ *service.Record, response *service.Response) {
818
	sm := h.systemMonitor.GetBuffer()
819

820
	response.ResponseType = &service.Response_GetSystemMetricsResponse{
821
		GetSystemMetricsResponse: &service.GetSystemMetricsResponse{
822
			SystemMetrics: make(map[string]*service.SystemMetricsBuffer),
823
		},
824
	}
825

826
	for key, samples := range sm {
827
		buffer := make([]*service.SystemMetricSample, 0, len(samples.GetElements()))
828

829
		// convert samples to buffer:
830
		for _, sample := range samples.GetElements() {
831
			buffer = append(buffer, &service.SystemMetricSample{
832
				Timestamp: sample.Timestamp,
833
				Value:     float32(sample.Value),
834
			})
835
		}
836
		// add to response as map key: buffer
837
		response.GetGetSystemMetricsResponse().SystemMetrics[key] = &service.SystemMetricsBuffer{
838
			Record: buffer,
839
		}
840
	}
841
}
842

843
// handleFileTransferInfo passes the record to the file transfer info handler.
func (h *Handler) handleFileTransferInfo(record *service.Record) {
	h.filesInfoHandler.Handle(record)
}
846

847
// handleSync forwards the sync request record unchanged.
func (h *Handler) handleSync(record *service.Record) {
	h.sendRecord(record)
}
850

851
// handleSenderRead forwards the sender-read request record unchanged.
func (h *Handler) handleSenderRead(record *service.Record) {
	h.sendRecord(record)
}
854

855
// handleTelemetry forwards the telemetry record unchanged.
func (h *Handler) handleTelemetry(record *service.Record) {
	h.sendRecord(record)
}
858

859
// handleUseArtifact forwards the use-artifact record unchanged.
func (h *Handler) handleUseArtifact(record *service.Record) {
	h.sendRecord(record)
}
862

863
func (h *Handler) writeAndSendSummaryFile() {
864
	if h.settings.GetXSync().GetValue() {
865
		// if sync is enabled, we don't need to do all this
866
		return
867
	}
868

869
	// write summary to file
870
	summaryFile := filepath.Join(h.settings.GetFilesDir().GetValue(), SummaryFileName)
871

872
	jsonBytes, err := json.MarshalIndent(h.summaryHandler.consolidatedSummary, "", "  ")
873
	if err != nil {
874
		h.logger.Error("handler: writeAndSendSummaryFile: error marshalling summary", "error", err)
875
		return
876
	}
877

878
	if err := os.WriteFile(summaryFile, []byte(jsonBytes), 0644); err != nil {
879
		h.logger.Error("handler: writeAndSendSummaryFile: failed to write config file", "error", err)
880
	}
881

882
	// send summary file
883
	h.filesHandler.Handle(&service.Record{
884
		RecordType: &service.Record_Files{
885
			Files: &service.FilesRecord{
886
				Files: []*service.FilesItem{
887
					{
888
						Path: SummaryFileName,
889
						Type: service.FilesItem_WANDB,
890
					},
891
				},
892
			},
893
		},
894
	})
895
}
896

897
func (h *Handler) sendSummary() {
898
	summaryRecord := &service.SummaryRecord{
899
		Update: []*service.SummaryItem{},
900
	}
901

902
	for key, value := range h.summaryHandler.summaryDelta {
903
		summaryRecord.Update = append(summaryRecord.Update, &service.SummaryItem{
904
			Key: key, ValueJson: value,
905
		})
906
	}
907

908
	record := &service.Record{
909
		RecordType: &service.Record_Summary{
910
			Summary: summaryRecord,
911
		},
912
	}
913
	h.sendRecord(record)
914
	// reset delta summary
915
	clear(h.summaryHandler.summaryDelta)
916
}
917

918
func (h *Handler) handleSummary(_ *service.Record, summary *service.SummaryRecord) {
919
	if h.settings.GetXSync().GetValue() {
920
		// if sync is enabled, we don't need to do all this
921
		return
922
	}
923

924
	runtime := int32(h.timer.Elapsed().Seconds())
925

926
	// update summary with runtime
927
	summary.Update = append(summary.Update, &service.SummaryItem{
928
		Key: "_wandb", ValueJson: fmt.Sprintf(`{"runtime": %d}`, runtime),
929
	})
930

931
	summaryRecord := corelib.ConsolidateSummaryItems(h.summaryHandler.consolidatedSummary, summary.Update)
932
	h.summaryHandler.updateSummaryDelta(summaryRecord)
933
}
934

935
// handleTBrecord passes the record to the tensorboard handler and captures any
// error it returns.
func (h *Handler) handleTBrecord(record *service.Record) {
	if err := h.tbHandler.Handle(record); err != nil {
		h.logger.CaptureError("error handling tbrecord", err)
	}
}
941

942
// GetRun returns the run record stored by handleRunStart, or nil if none has
// been stored yet.
func (h *Handler) GetRun() *service.RunRecord {
	return h.runRecord
}
945

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.