7
"github.com/wandb/wandb/core/internal/corelib"
8
"github.com/wandb/wandb/core/pkg/service"
9
"google.golang.org/protobuf/proto"
12
// MetricHandler keeps the metric definitions seen by the handler: metrics
// registered under an explicit name, and metrics registered under a glob
// pattern.
type MetricHandler struct {
13
// definedMetrics maps a metric name to its metric record.
definedMetrics map[string]*service.MetricRecord
14
// globMetrics maps a glob pattern (tested with filepath.Match) to the metric
// record used as a template for names matching that pattern.
globMetrics map[string]*service.MetricRecord
17
// NewMetricHandler returns a MetricHandler whose metric maps are allocated
// and empty.
func NewMetricHandler() *MetricHandler {
18
// Both maps are assigned into directly (see addMetric), so they must be
// non-nil: writing to a nil map panics.
return &MetricHandler{
19
definedMetrics: make(map[string]*service.MetricRecord),
20
globMetrics: make(map[string]*service.MetricRecord),
24
// addMetric stores a metric record in the target map under key.
//
// If the overwrite flag is set on the metric's control, the metric replaces any
// existing entry. Otherwise, if an entry already exists under key, the metric is
// merged into it with proto.Merge; if no entry exists, the metric is stored as-is.
25
// arg is either a *service.MetricRecord or, as the call addMetric(key, key, ...)
// in handleStepMetric suggests, a plain metric name from which a new
// MetricRecord is built. Any other type makes addMetric return a nil record
// and an "invalid input" error.
//
// NOTE(review): target is a pointer to a map. Go maps are already reference
// types, so passing the map itself would behave the same.
26
func addMetric(arg interface{}, key string, target *map[string]*service.MetricRecord) (*service.MetricRecord, error) {
27
// metric is the record to store; it is derived from arg below.
var metric *service.MetricRecord
29
switch v := arg.(type) {
31
// Build a fresh record from arg. This case's label is not visible here;
// it is presumably the string-name case.
metric = &service.MetricRecord{
34
case *service.MetricRecord:
37
// Any other argument type is rejected.
38
return nil, errors.New("invalid input")
41
// Overwrite requested: replace whatever is stored under key.
if metric.GetXControl().GetOverwrite() {
42
(*target)[key] = metric
44
// An entry already exists: merge the incoming metric into it. proto.Merge
// copies populated scalar fields and appends repeated fields from metric into
// existingMetric, so fields that metric leaves unset keep their stored values.
if existingMetric, ok := (*target)[key]; ok {
45
proto.Merge(existingMetric, metric)
47
// No existing entry (presumably the else branch of the check above):
// store the metric as-is.
(*target)[key] = metric
53
// createMatchingGlobMetric checks whether key matches one of the registered
// glob patterns and, if it does, builds a new metric for key from that
// pattern's glob metric and returns it.
54
// Patterns are tested with filepath.Match. They are iterated from a map, so the
// order is unspecified. If the loop returns on the first match (the return is
// not visible here), the winning pattern is nondeterministic when several
// patterns match key. Confirm that patterns cannot overlap.
55
func (mh *MetricHandler) createMatchingGlobMetric(key string) *service.MetricRecord {
56
for pattern, globMetric := range mh.globMetrics {
57
// filepath.Match only fails for a malformed pattern (filepath.ErrBadPattern).
if match, err := filepath.Match(pattern, key); err != nil {
58
// NOTE(review): the error is dropped. MetricHandler, as visible here, has no
// logger to report it with.
// h.logger.CaptureError("error matching metric", err)
61
// Clone the glob metric so the template kept in globMetrics is not mutated.
metric := proto.Clone(globMetric).(*service.MetricRecord)
63
// Mark the derived metric as not explicitly defined.
// NOTE(review): if Options is a message (pointer) field that is unset on the
// glob metric, this assignment panics. Confirm glob metrics always carry Options.
metric.Options.Defined = false
71
// handleStepMetric makes sure the metric named key, which another metric uses
// as its step metric, is registered. If key is not yet in the defined metrics
// map, a bare metric with that name is added and a metric record for it is
// passed to h.sendRecord.
72
// Errors from addMetric are logged rather than returned.
73
func (h *Handler) handleStepMetric(key string) {
78
// Already defined: nothing to add or send.
79
if _, defined := h.metricHandler.definedMetrics[key]; defined {
83
// Register a bare metric named key (addMetric presumably builds the record
// from the string).
metric, err := addMetric(key, key, &h.metricHandler.definedMetrics)
86
h.logger.CaptureError("error adding metric to map", err)
90
// Wrap the new metric definition in a record and hand it to sendRecord.
stepRecord := &service.Record{
91
RecordType: &service.Record_Metric{
94
Control: &service.Control{
98
h.sendRecord(stepRecord)
101
// handleMetric registers the metric definition carried by a metric record.
//
// A metric with a glob name is stored in the glob metrics map under that
// pattern. Otherwise, a metric with a plain name is stored in the defined
// metrics map, and its step metric is handed to handleStepMetric. A metric with
// neither name is reported as invalid (presumably via the switch's default
// case). Failures are logged; nothing is returned to the caller.
func (h *Handler) handleMetric(record *service.Record, metric *service.MetricRecord) {
102
// metric can have a glob name or a name; the glob name is checked first.
103
// TODO: replace glob-name/name with one-of field
105
// Glob metric: kept as a template for names matching the pattern.
case metric.GetGlobName() != "":
106
if _, err := addMetric(metric, metric.GetGlobName(), &h.metricHandler.globMetrics); err != nil {
107
h.logger.CaptureError("error adding metric to map", err)
111
// Named metric: stored under its own name.
case metric.GetName() != "":
112
if _, err := addMetric(metric, metric.GetName(), &h.metricHandler.definedMetrics); err != nil {
113
h.logger.CaptureError("error adding metric to map", err)
116
// Ensure the metric's step metric is registered as well.
h.handleStepMetric(metric.GetStepMetric())
119
// Neither a glob name nor a name was given.
h.logger.CaptureError("invalid metric", errors.New("invalid metric"))
123
// MetricSender keeps the sender-side metric state: the metric definitions seen
// so far and their dict-encoded metric hints (see encodeMetricHints).
type MetricSender struct {
124
// definedMetrics maps a metric name to its metric record.
definedMetrics map[string]*service.MetricRecord
125
// metricIndex maps a metric name to the position of its entry in configMetrics.
metricIndex map[string]int32
126
// configMetrics holds the dict-encoded metric records, in the order the metrics
// were first encoded.
configMetrics []map[int]interface{}
129
// NewMetricSender returns a MetricSender with empty, allocated metric state.
func NewMetricSender() *MetricSender {
130
// The maps are assigned into directly, so they must be non-nil.
// configMetrics starts as a non-nil empty slice.
return &MetricSender{
131
definedMetrics: make(map[string]*service.MetricRecord),
132
metricIndex: make(map[string]int32),
133
configMetrics: make([]map[int]interface{}, 0),
137
// encodeMetricHints encodes the metric hints for the given metric record and
// stores them in the sender's metric state. The metric hints are used to
// configure the plots in the UI. The record argument is ignored.
138
// The metric is first recorded in the sender's defined metrics via addMetric.
// If it names a step metric, a clone is encoded with StepMetric cleared and
// StepMetricIndex set to the step metric's configMetrics index plus one. This
// presumably happens only when that step metric already has an index; the ok
// check is not visible here. The dict-encoded result replaces the metric's
// existing configMetrics entry, or is appended and indexed if the metric has
// not been encoded before.
139
func (s *Sender) encodeMetricHints(_ *service.Record, metric *service.MetricRecord) {
141
// Track the definition, stored or merged according to addMetric's rules.
_, err := addMetric(metric, metric.GetName(), &s.metricSender.definedMetrics)
146
if metric.GetStepMetric() != "" {
147
// Find where the step metric's entry lives in configMetrics.
index, ok := s.metricSender.metricIndex[metric.GetStepMetric()]
149
// Clone before editing: addMetric may have stored this very pointer in
// definedMetrics, and the caller's record should not change either.
metric = proto.Clone(metric).(*service.MetricRecord)
150
// Refer to the step metric by position rather than by name.
metric.StepMetric = ""
151
// NOTE(review): the +1 presumably keeps 0 free to mean "no step metric"
// (the proto3 default). Confirm against how the UI decodes this field.
metric.StepMetricIndex = index + 1
155
encodeMetric := corelib.ProtoEncodeToDict(metric)
156
// Already encoded before: replace its entry in place.
if index, ok := s.metricSender.metricIndex[metric.GetName()]; ok {
157
s.metricSender.configMetrics[index] = encodeMetric
159
// Not seen yet (presumably the else branch): append the entry and remember
// its position.
nextIndex := len(s.metricSender.configMetrics)
160
s.metricSender.configMetrics = append(s.metricSender.configMetrics, encodeMetric)
161
s.metricSender.metricIndex[metric.GetName()] = int32(nextIndex)