// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements.  See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
= Cross-Validation

Cross-validation functionality in Apache Ignite is provided by the `CrossValidation` class. It is a score calculator parameterized by the model type and by the key and value types of the training data. Its constructor takes no arguments; after instantiating it, configure it with the `with*` methods and call `scoreByFolds()` to perform cross-validation.

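Conceptually, k-fold cross-validation splits the training set into k folds and, for each fold, trains on the other k - 1 folds and evaluates on the held-out one. The plain-Java sketch below illustrates only that index bookkeeping. It does not use Ignite, `KFoldSketch` is a hypothetical name, and the round-robin assignment (row `i` goes to fold `i % k`) is an assumption chosen for illustration, not the way `CrossValidation` actually assigns rows to folds.

[source, java]
----
import java.util.ArrayList;
import java.util.List;

/** Hypothetical k-fold splitter for illustration only; not part of Apache Ignite. */
final class KFoldSketch {
    /** Returns, for each fold, the row indices held out as that fold's test set. */
    static List<List<Integer>> testFolds(int rowCount, int folds) {
        if (folds < 2 || folds > rowCount)
            throw new IllegalArgumentException("need 2 <= folds <= rowCount");

        List<List<Integer>> result = new ArrayList<>();
        for (int f = 0; f < folds; f++)
            result.add(new ArrayList<>());

        // Assumed round-robin assignment: row i is held out in fold i % folds.
        for (int i = 0; i < rowCount; i++)
            result.get(i % folds).add(i);

        return result;
    }

    /** Training rows for a fold are all rows that this fold does not hold out. */
    static List<Integer> trainIndices(int rowCount, List<Integer> testFold) {
        List<Integer> train = new ArrayList<>();
        for (int i = 0; i < rowCount; i++)
            if (!testFold.contains(i))
                train.add(i);
        return train;
    }

    public static void main(String[] args) {
        int rows = 10;
        List<List<Integer>> folds = testFolds(rows, 4);

        // Each row appears in exactly one test fold, so 4 folds give 4 evaluations.
        for (int f = 0; f < folds.size(); f++)
            System.out.println("fold " + f + ": test=" + folds.get(f)
                + " train=" + trainIndices(rows, folds.get(f)));
    }
}
----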
Suppose we have a trainer and a training set, and we want to run cross-validation with accuracy as the metric and 4 folds. Apache Ignite lets us do this as shown in the following example:

== Cross-Validation (without Pipeline API usage)

[source, java]
----
// `ignite` (a started Ignite instance), `trainingSet` (the cache holding
// the training data) and `vectorizer` are assumed to be defined beforehand.

// Create classification trainer (max depth 4, min impurity decrease 0)
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

// Create cross-validation instance
CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
  = new CrossValidation<>();

// Set up the cross-validation process
scoreCalculator
    .withIgnite(ignite)
    .withUpstreamCache(trainingSet)
    .withTrainer(trainer)
    .withMetric(MetricName.ACCURACY)
    .withPreprocessor(vectorizer)
    .withAmountOfFolds(4)
    .isRunningOnPipeline(false);

// Calculate accuracy for each fold
double[] accuracyScores = scoreCalculator.scoreByFolds();
----

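In this example, the `with*` methods set the trainer and the metric, the common training arguments (the Ignite instance, the cache holding the training set, and the vectorizer), and finally the number of folds. `isRunningOnPipeline(false)` makes the calculator train the given trainer directly instead of a pipeline. `scoreByFolds()` returns an array with the chosen metric computed for each fold: every fold serves once as the test set while the model is trained on the remaining folds, so four folds yield four scores.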

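`scoreByFolds()` returns the raw per-fold values; a common way to report them is their mean together with the sample standard deviation. The helper below is a plain-Java sketch, not part of the Ignite API: `FoldScoreSummary` is a hypothetical name, and the numbers in `main` are illustrative placeholders, not real results.

[source, java]
----
/** Hypothetical helper, not an Ignite API: summarizes an array of per-fold scores. */
final class FoldScoreSummary {
    /** Arithmetic mean of the per-fold scores. */
    static double mean(double[] scores) {
        if (scores.length == 0)
            throw new IllegalArgumentException("no scores");
        double sum = 0;
        for (double s : scores)
            sum += s;
        return sum / scores.length;
    }

    /** Sample standard deviation (n - 1 denominator); 0 when fewer than two scores. */
    static double sampleStdDev(double[] scores) {
        if (scores.length < 2)
            return 0.0;
        double m = mean(scores);
        double squares = 0;
        for (double s : scores)
            squares += (s - m) * (s - m);
        return Math.sqrt(squares / (scores.length - 1));
    }

    public static void main(String[] args) {
        // Illustrative placeholder values standing in for scoreCalculator.scoreByFolds().
        double[] accuracyScores = {0.9, 0.8, 0.85, 0.95};
        System.out.printf("accuracy = %.3f +/- %.3f%n",
            mean(accuracyScores), sampleStdDev(accuracyScores));
    }
}
----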
== Cross-Validation (with Pipeline API usage)

To run cross-validation on a pipeline, define the pipeline and pass it to the `CrossValidation` instance with `withPipeline()`.

CAUTION: The Pipeline API is experimental and may change in future releases.

[source, java]
----
// `ignite`, `trainingSet` and `vectorizer` are assumed to be defined beforehand.

// Create classification trainer (max depth 4, min impurity decrease 0)
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

// Build the pipeline: vectorization, imputing, min-max scaling, then training
Pipeline<Integer, Vector, Integer, Double> pipeline
  = new Pipeline<Integer, Vector, Integer, Double>()
    .addVectorizer(vectorizer)
    .addPreprocessingTrainer(new ImputerTrainer<Integer, Vector>())
    .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Vector>())
    .addTrainer(trainer);

// Create cross-validation instance
CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
  = new CrossValidation<>();

// Set up the cross-validation process
scoreCalculator
    .withIgnite(ignite)
    .withUpstreamCache(trainingSet)
    .withPipeline(pipeline)
    .withMetric(MetricName.ACCURACY)
    .withPreprocessor(vectorizer)
    .withAmountOfFolds(4)
    .isRunningOnPipeline(true); // evaluate the pipeline, not a bare trainer

// Calculate accuracy for each fold
double[] accuracyScores = scoreCalculator.scoreByFolds();
----

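Putting preprocessing trainers such as `ImputerTrainer` and `MinMaxScalerTrainer` into the pipeline means they are trained as part of the pipeline rather than beforehand on the whole cache. We assume (the text above does not state it) that this lets each fold fit them on its own training rows, so statistics from the held-out fold do not leak into preprocessing. The plain-Java sketch below is not Ignite code, and `MinMaxSketch` is a hypothetical single-feature scaler: it fits the minimum and maximum on a training fold and applies them to a test fold, which is why a held-out value can land outside [0, 1].

[source, java]
----
/** Hypothetical single-feature min-max scaler (not Ignite code): fit on train, apply to test. */
final class MinMaxSketch {
    final double min;
    final double max;

    private MinMaxSketch(double min, double max) {
        this.min = min;
        this.max = max;
    }

    /** Learns min and max from the training fold only. */
    static MinMaxSketch fit(double[] train) {
        if (train.length == 0)
            throw new IllegalArgumentException("empty training fold");
        double lo = Double.POSITIVE_INFINITY, hi = Double.NEGATIVE_INFINITY;
        for (double v : train) {
            lo = Math.min(lo, v);
            hi = Math.max(hi, v);
        }
        return new MinMaxSketch(lo, hi);
    }

    /** Maps v to (v - min) / (max - min); a constant training feature maps to 0. */
    double apply(double v) {
        return max == min ? 0.0 : (v - min) / (max - min);
    }

    public static void main(String[] args) {
        double[] trainFold = {2.0, 4.0, 6.0};
        double[] testFold = {3.0, 8.0};

        // Statistics come from the training fold only; 8.0 maps above 1.
        MinMaxSketch scaler = fit(trainFold);
        for (double v : testFold)
            System.out.println(v + " -> " + scaler.apply(v));
    }
}
----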

== Example

To see how cross-validation can be used in practice, try https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java[this example] and step https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_pipeline.java[8 of the ML Tutorial], both available on GitHub and shipped with every Apache Ignite distribution.
