// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements.  See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
= Pipelines API

Apache Ignite ML standardizes APIs for machine learning algorithms, making it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Pipelines API; the pipeline concept is largely inspired by the scikit-learn and Apache Spark projects.

* **Preprocessor Model** - an algorithm that transforms one DataSet into another DataSet.

* **Preprocessor Trainer** - an algorithm that can be fit on a DataSet to produce a Preprocessor Model.

* **Pipeline** - chains multiple trainers and preprocessors together to specify an ML workflow.

* **Parameter** - all ML trainers and preprocessor trainers share a common API for specifying parameters.

CAUTION: The Pipeline API is experimental and may change in future releases.
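Conceptually, a preprocessor is a dataset-to-dataset transform, and a pipeline is the composition of such transforms: each stage is fit on the output of the previous one. The following toy, Ignite-free sketch illustrates that chaining idea only — all class and method names here are illustrative stand-ins, not Ignite APIs:

[source, java]
----
import java.util.Arrays;
import java.util.function.Function;

/** Toy sketch: a "preprocessor" is a row-to-row transform, and a pipeline
 *  is simply the left-to-right composition of such transforms. */
public class PipelineSketch {
    /** Chains stages left to right, mirroring how each preprocessor
     *  consumes the output of the previous one. */
    @SafeVarargs
    public static Function<double[], double[]> chain(Function<double[], double[]>... stages) {
        Function<double[], double[]> pipeline = Function.identity();
        for (Function<double[], double[]> stage : stages)
            pipeline = pipeline.andThen(stage);
        return pipeline;
    }

    /** Imputer stand-in: replaces missing values (NaN) with 0.0. */
    public static double[] impute(double[] row) {
        double[] out = row.clone();
        for (int i = 0; i < out.length; i++)
            if (Double.isNaN(out[i])) out[i] = 0.0;
        return out;
    }

    /** Min-max scaler stand-in for a fixed [0, 10] feature range. */
    public static double[] minMaxScale(double[] row) {
        double[] out = new double[row.length];
        for (int i = 0; i < row.length; i++)
            out[i] = row[i] / 10.0;
        return out;
    }

    public static void main(String[] args) {
        Function<double[], double[]> pipeline = chain(PipelineSketch::impute, PipelineSketch::minMaxScale);
        // Missing value is imputed first, then the row is scaled.
        System.out.println(Arrays.toString(pipeline.apply(new double[] {Double.NaN, 5.0, 10.0})));
        // → [0.0, 0.5, 1.0]
    }
}
----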

A Pipeline can replace sequences of `.fit()` method calls, as in the following examples:

[tabs]
--
tab:Without Pipeline API[]

[source, java]
----
final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 3, 4, 5, 6, 8, 10).labeled(1);

TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
  .split(0.75);

Preprocessor<Integer, Vector> imputingPreprocessor = new ImputerTrainer<Integer, Vector>()
  .fit(ignite,
       dataCache,
       vectorizer
      );

Preprocessor<Integer, Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Vector>()
  .fit(ignite,
       dataCache,
       imputingPreprocessor
      );

Preprocessor<Integer, Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Vector>()
  .withP(1)
  .fit(ignite,
       dataCache,
       minMaxScalerPreprocessor
      );

// Tune hyper-parameters with k-fold cross-validation on the split training set.
DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();

CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator = new CrossValidation<>();

ParamGrid paramGrid = new ParamGrid()
  .addHyperParam("maxDeep", trainerCV::withMaxDeep, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0})
  .addHyperParam("minImpurityDecrease", trainerCV::withMinImpurityDecrease, new Double[] {0.0, 0.25, 0.5});

scoreCalculator
  .withIgnite(ignite)
  .withUpstreamCache(dataCache)
  .withTrainer(trainerCV)
  .withMetric(MetricName.ACCURACY)
  .withFilter(split.getTrainFilter())
  .isRunningOnPipeline(false)
  .withPreprocessor(normalizationPreprocessor)
  .withAmountOfFolds(3)
  .withParamGrid(paramGrid);

CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters();
----

tab:With Pipeline API[]

[source, java]
----
final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 4, 5, 6, 8).labeled(1);

TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
  .split(0.75);

DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();

Pipeline<Integer, Vector, Integer, Double> pipeline = new Pipeline<Integer, Vector, Integer, Double>()
  .addVectorizer(vectorizer)
  .addPreprocessingTrainer(new ImputerTrainer<Integer, Vector>())
  .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Vector>())
  .addTrainer(trainerCV);

CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator = new CrossValidation<>();

ParamGrid paramGrid = new ParamGrid()
  .addHyperParam("maxDeep", trainerCV::withMaxDeep, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0})
  .addHyperParam("minImpurityDecrease", trainerCV::withMinImpurityDecrease, new Double[] {0.0, 0.25, 0.5});

scoreCalculator
  .withIgnite(ignite)
  .withUpstreamCache(dataCache)
  .withPipeline(pipeline)
  .withMetric(MetricName.ACCURACY)
  .withFilter(split.getTrainFilter())
  .withAmountOfFolds(3)
  .withParamGrid(paramGrid);

CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters();
----
--
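
In both examples above, `ParamGrid` enumerates every combination of the listed hyper-parameter values, and `tuneHyperParameters()` scores each combination by cross-validation and keeps the best one. The following toy, Ignite-free sketch shows the shape of that search; the `score` function and all names here are illustrative stand-ins (in Ignite the score would be the cross-validated metric of a trainer configured with those values), not Ignite APIs:

[source, java]
----
import java.util.LinkedHashMap;
import java.util.Map;

/** Toy sketch of grid search: exhaustively score every combination of
 *  two hyper-parameter grids and keep the best-scoring pair. */
public class GridSearchSketch {
    /** Stand-in score; a real implementation would run k-fold
     *  cross-validation for these hyper-parameter values. Here we just
     *  prefer moderate depth and a small impurity threshold. */
    static double score(double maxDeep, double minImpurityDecrease) {
        return -Math.abs(maxDeep - 4.0) - minImpurityDecrease;
    }

    /** Exhaustive search over the cartesian product of the two grids. */
    public static Map<String, Double> tune(double[] maxDeeps, double[] minImpurities) {
        double bestScore = Double.NEGATIVE_INFINITY;
        Map<String, Double> best = new LinkedHashMap<>();
        for (double d : maxDeeps)
            for (double m : minImpurities) {
                double s = score(d, m);
                if (s > bestScore) {
                    bestScore = s;
                    best.clear();
                    best.put("maxDeep", d);
                    best.put("minImpurityDecrease", m);
                }
            }
        return best;
    }

    public static void main(String[] args) {
        // Same grids as the ParamGrid in the examples above.
        Map<String, Double> best = tune(new double[] {1, 2, 3, 4, 5, 10},
                                        new double[] {0.0, 0.25, 0.5});
        System.out.println(best); // → {maxDeep=4.0, minImpurityDecrease=0.0}
    }
}
----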

The full code can be found in the https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_pipeline.java[Titanic tutorial].