Dragonfly2
51 lines · 1.6 KB
/*
 * Copyright 2023 The Dragonfly Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package rpcserver
18
19import (
20trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1"
21
22"d7y.io/dragonfly/v2/trainer/config"
23"d7y.io/dragonfly/v2/trainer/metrics"
24"d7y.io/dragonfly/v2/trainer/service"
25storage "d7y.io/dragonfly/v2/trainer/storage"
26"d7y.io/dragonfly/v2/trainer/training"
27)
28
// trainerServerV1 is the v1 version of the trainer gRPC server.
// Its RPC handlers (e.g. Train) record metrics and delegate the actual
// work to the wrapped service.V1.
type trainerServerV1 struct {
	// service is the v1 trainer service that the RPC handlers delegate to.
	service *service.V1
}
34
35// newTrainerServerV1 returns a new trainerServerV1 instance.
36func newTrainerServerV1(cfg *config.Config, storage storage.Storage, training training.Training) trainerv1.TrainerServer {
37return &trainerServerV1{service.NewV1(cfg, storage, training)}
38}
39
40// Train handles the training request from scheduler.
41func (t *trainerServerV1) Train(stream trainerv1.Trainer_TrainServer) error {
42// Collect TrainCount metrics.
43metrics.TrainCount.Inc()
44if err := t.service.Train(stream); err != nil {
45// Collect TrainFailureCount metrics.
46metrics.TrainFailureCount.Inc()
47return err
48}
49
50return nil
51}
52