optuna

Форк
0
/
test_cli.py 
1744 строки · 59.4 Кб
1
import json
2
import os
3
import platform
4
import re
5
import subprocess
6
from subprocess import CalledProcessError
7
import tempfile
8
from typing import Any
9
from typing import Callable
10
from typing import Optional
11
from typing import Tuple
12
from unittest.mock import MagicMock
13
from unittest.mock import patch
14

15
import fakeredis
16
import numpy as np
17
from pandas import Timedelta
18
from pandas import Timestamp
19
import pytest
20
import yaml
21

22
import optuna
23
import optuna.cli
24
from optuna.exceptions import CLIUsageError
25
from optuna.exceptions import ExperimentalWarning
26
from optuna.storages import JournalFileStorage
27
from optuna.storages import JournalRedisStorage
28
from optuna.storages import JournalStorage
29
from optuna.storages import RDBStorage
30
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
31
from optuna.study import StudyDirection
32
from optuna.testing.storages import StorageSupplier
33
from optuna.testing.tempfile_pool import NamedTemporaryFilePool
34
from optuna.trial import Trial
35
from optuna.trial import TrialState
36

37

38
# An example of objective functions
def objective_func(trial: Trial) -> float:
    """Evaluate a single-objective parabola in ``x``.

    Args:
        trial:
            The trial used to suggest the float parameter ``x`` in ``[-10, 10]``.

    Returns:
        ``(x + 5) ** 2`` for the suggested ``x``.
    """
    shifted = trial.suggest_float("x", -10, 10) + 5
    return shifted ** 2
42

43

44
# An example of objective functions for branched search spaces
def objective_func_branched_search_space(trial: Trial) -> float:
    """Evaluate a parabola whose float parameter depends on a categorical choice.

    Args:
        trial:
            The trial used to suggest the categorical ``c`` (``"A"`` or ``"B"``) and then
            exactly one float parameter: ``x`` for ``"A"``, ``y`` otherwise.

    Returns:
        ``(p + 5) ** 2`` where ``p`` is the float parameter of the selected branch.
    """
    branch = trial.suggest_categorical("c", ("A", "B"))
    # Only the parameter of the selected branch is suggested, so each trial holds
    # either ``x`` or ``y`` but never both.
    param_name = "x" if branch == "A" else "y"
    sampled = trial.suggest_float(param_name, -10, 10)
    return (sampled + 5) ** 2
53

54

55
# An example of objective functions for multi-objective optimization
def objective_func_multi_objective(trial: Trial) -> Tuple[float, float]:
    """Evaluate two parabolas over the same parameter ``x``.

    Args:
        trial:
            The trial used to suggest the float parameter ``x`` in ``[-10, 10]``.

    Returns:
        The pair ``((x + 5) ** 2, (x - 5) ** 2)``.
    """
    x = trial.suggest_float("x", -10, 10)
    first_objective = (x + 5) ** 2
    second_objective = (x - 5) ** 2
    return first_objective, second_objective
59

60

61
def _parse_output(output: str, output_format: str) -> Any:
    """Parse CLI output.

    Args:
        output:
            The output of command.
        output_format:
            The format of output specified by command. One of ``"value"``, ``"table"``,
            ``"json"`` or ``"yaml"``.

    Returns:
        For value format, a list of ``{"name": line}`` dicts, one per output line.
        For table format, a list of dict formatted rows.
        For JSON or YAML format, a list or a dict corresponding to ``output``.

    Raises:
        AssertionError:
            If the table layout is inconsistent or ``output_format`` is not supported.
    """
    if output_format == "value":
        # Currently, _parse_output with output_format="value" is used only for
        # `study-names` command.
        return [{"name": values} for values in output.split(os.linesep)]
    elif output_format == "table":
        rows = output.split(os.linesep)
        # Every line of a well-formed table has the same width.
        assert all(len(rows[0]) == len(row) for row in rows)
        # Check ruled lines.
        assert rows[0] == rows[2] == rows[-1]

        # The header sits between the first two ruled lines; the outermost pieces of the
        # "|" split are the empty strings outside the table border.
        keys = [r.strip() for r in rows[1].split("|")[1:-1]]
        return [
            {key: attr.strip() for key, attr in zip(keys, record.split("|")[1:-1])}
            for record in rows[3:-1]
        ]
    elif output_format == "json":
        return json.loads(output)
    elif output_format == "yaml":
        return yaml.safe_load(output)
    else:
        # Raise explicitly instead of ``assert False``: a bare assert is removed under
        # ``python -O`` and the function would then silently return ``None``.
        raise AssertionError(f"Unsupported output format: {output_format!r}")
98

99

100
@pytest.mark.skip_coverage
def test_create_study_command() -> None:
    """Check that ``optuna create-study`` prints a generated name and stores the study."""
    generated_name_re = r"^no-name-[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$"

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        create_cmd = ["optuna", "create-study", "--storage", str(storage.engine.url)]

        # First invocation: only the exit status is checked.
        subprocess.check_call(create_cmd)

        # Second invocation: the printed name must have the "no-name-<UUID>" shape.
        raw_stdout = subprocess.check_output(create_cmd)
        printed_name = str(raw_stdout.decode().strip())
        assert re.match(generated_name_re, printed_name) is not None

        # The printed name must resolve to study ID 2 in the storage.
        assert storage.get_study_id_from_name(printed_name) == 2
118

119

120
@pytest.mark.skip_coverage
def test_create_study_command_with_study_name() -> None:
    """Check that a study created with ``--study-name`` can be looked up in the storage."""
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)

        # Create study with name.
        create_cmd = [
            "optuna",
            "create-study",
            "--storage",
            str(storage.engine.url),
            "--study-name",
            "test_study",
        ]
        printed_name = str(subprocess.check_output(create_cmd).decode().strip())

        # The printed name must round-trip through the storage: name -> ID -> name.
        stored_id = storage.get_study_id_from_name(printed_name)
        assert storage.get_study_name_from_id(stored_id) == printed_name
134

135

136
@pytest.mark.skip_coverage
def test_create_study_command_without_storage_url() -> None:
    """Check that ``optuna create-study`` fails with a usage message when no storage is set."""
    # Run with a copy of the environment that lacks OPTUNA_STORAGE, so no storage URL is
    # supplied through the environment either.
    env_without_storage = {
        name: value for name, value in os.environ.items() if name != "OPTUNA_STORAGE"
    }

    with pytest.raises(subprocess.CalledProcessError) as exc_info:
        subprocess.check_output(["optuna", "create-study"], env=env_without_storage)

    captured = exc_info.value.output.decode()
    assert captured.startswith("usage:")
145

146

147
@pytest.mark.skip_coverage
def test_create_study_command_with_storage_env() -> None:
    """Check ``optuna create-study`` with the storage URL given only via ``OPTUNA_STORAGE``."""
    generated_name_re = r"^no-name-[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$"
    create_cmd = ["optuna", "create-study"]

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)

        # No --storage option is passed; the URL travels through the environment.
        env = dict(os.environ)
        env["OPTUNA_STORAGE"] = str(storage.engine.url)

        # First invocation: only the exit status is checked.
        subprocess.check_call(create_cmd, env=env)

        # Second invocation: the printed name must have the "no-name-<UUID>" shape.
        raw_stdout = subprocess.check_output(create_cmd, env=env)
        printed_name = str(raw_stdout.decode().strip())
        assert re.match(generated_name_re, printed_name) is not None

        # The printed name must resolve to study ID 2 in the storage.
        assert storage.get_study_id_from_name(printed_name) == 2
166

167

168
@pytest.mark.skip_coverage
def test_create_study_command_with_direction() -> None:
    """Check that ``--direction`` is stored for valid values and rejected otherwise."""
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        base_cmd = [
            "optuna",
            "create-study",
            "--storage",
            str(storage.engine.url),
            "--direction",
        ]

        accepted = (
            ("minimize", StudyDirection.MINIMIZE),
            ("maximize", StudyDirection.MAXIMIZE),
        )
        for cli_value, expected_direction in accepted:
            printed_name = str(subprocess.check_output(base_cmd + [cli_value]).decode().strip())
            stored_id = storage.get_study_id_from_name(printed_name)
            assert storage.get_study_directions(stored_id) == [expected_direction]

        # --direction should be either 'minimize' or 'maximize'.
        with pytest.raises(subprocess.CalledProcessError):
            subprocess.check_call(base_cmd + ["test"])
189

190

191
@pytest.mark.skip_coverage
def test_create_study_command_with_multiple_directions() -> None:
    """Check how ``optuna create-study`` handles the ``--directions`` option.

    Three cases are exercised:

    * a valid list of directions is stored in the given order,
    * an unknown direction value makes the command fail,
    * passing ``--direction`` together with ``--directions`` makes the command fail.
    """
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)
        base_command = ["optuna", "create-study", "--storage", storage_url]

        command = base_command + ["--directions", "minimize", "maximize"]
        study_name = str(subprocess.check_output(command).decode().strip())
        study_id = storage.get_study_id_from_name(study_name)
        expected_directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]
        assert storage.get_study_directions(study_id) == expected_directions

        # Each direction in --directions should be either `minimize` or `maximize`.
        command = base_command + ["--directions", "minimize", "maximize", "test"]
        with pytest.raises(subprocess.CalledProcessError):
            subprocess.check_call(command)

        # It can't specify both --direction and --directions. Every value below is valid
        # on its own, so the failure can only come from combining the two options (an
        # invalid value here would make the command fail regardless of the conflict).
        command = base_command + [
            "--direction",
            "minimize",
            "--directions",
            "minimize",
            "maximize",
        ]
        with pytest.raises(subprocess.CalledProcessError):
            subprocess.check_call(command)
242

243

244
@pytest.mark.skip_coverage
def test_delete_study_command() -> None:
    """Check that ``optuna delete-study`` removes a previously created study."""
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        study_name = "delete-study-test"
        common_args = ["--storage", str(storage.engine.url), "--study-name", study_name]

        # Create study.
        subprocess.check_call(["optuna", "create-study"] + common_args)
        names_after_create = {s.study_name for s in storage.get_all_studies()}
        assert study_name in names_after_create

        # Delete study.
        subprocess.check_call(["optuna", "delete-study"] + common_args)
        names_after_delete = {s.study_name for s in storage.get_all_studies()}
        assert study_name not in names_after_delete
260

261

262
@pytest.mark.skip_coverage
def test_delete_study_command_without_storage_url() -> None:
    """Check that ``optuna delete-study`` fails when no storage is configured."""
    # Run with a copy of the environment that lacks OPTUNA_STORAGE, so no storage URL is
    # supplied through the environment either.
    env_without_storage = {
        name: value for name, value in os.environ.items() if name != "OPTUNA_STORAGE"
    }
    delete_cmd = ["optuna", "delete-study", "--study-name", "dummy_study"]

    with pytest.raises(subprocess.CalledProcessError):
        subprocess.check_output(delete_cmd, env=env_without_storage)
269

270

271
@pytest.mark.skip_coverage
def test_study_set_user_attr_command() -> None:
    """Check that ``optuna study set-user-attr`` writes user attributes to the storage."""
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)

        # Create study.
        new_study_id = storage.create_new_study(directions=[StudyDirection.MINIMIZE])
        study_name = storage.get_study_name_from_id(new_study_id)

        base_command = [
            "optuna",
            "study",
            "set-user-attr",
            "--study-name",
            study_name,
            "--storage",
            storage_url,
        ]

        # One CLI invocation per attribute.
        example_attrs = {"architecture": "ResNet", "baselen_score": "0.002"}
        for attr_key in example_attrs:
            attr_args = ["--key", attr_key, "--value", example_attrs[attr_key]]
            subprocess.check_call(base_command + attr_args)

        # Attrs should be stored in storage.
        study_id = storage.get_study_id_from_name(study_name)
        study_user_attrs = storage.get_study_user_attrs(study_id)
        assert len(study_user_attrs) == 2
        for attr_key, attr_value in example_attrs.items():
            assert study_user_attrs[attr_key] == attr_value
301

302

303
@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_study_names_command(output_format: Optional[str]) -> None:
    """Check that ``optuna study-names`` lists every created study in each output format.

    Args:
        output_format:
            Value passed to ``--format``. ``None`` omits the option, in which case the
            output is parsed as ``"value"`` format.
    """
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)

        expected_study_names = ["study-names-test1", "study-names-test2"]
        expected_column_name = "name"

        def create_named_study(name: str) -> None:
            # Create a study with the given name through the CLI.
            subprocess.check_output(
                ["optuna", "create-study", "--storage", storage_url, "--study-name", name]
            )

        def list_study_names() -> Any:
            # Run `study-names` and parse its output for the requested format.
            listing_cmd = ["optuna", "study-names", "--storage", storage_url]
            if output_format is not None:
                listing_cmd.extend(["--format", output_format])
            listing = str(subprocess.check_output(listing_cmd).decode().strip())
            return _parse_output(listing, output_format or "value")

        # Create a study.
        create_named_study(expected_study_names[0])

        # With one study stored, exactly that name must be listed.
        study_names = list_study_names()
        assert len(study_names) == 1
        assert study_names[0]["name"] == expected_study_names[0]

        # Create another study.
        create_named_study(expected_study_names[1])

        # Both names must be listed, in creation order, under the single "name" column.
        study_names = list_study_names()
        assert len(study_names) == 2
        for listed, expected_name in zip(study_names, expected_study_names):
            assert list(listed.keys()) == [expected_column_name]
            assert listed["name"] == expected_name
357

358

359
@pytest.mark.skip_coverage
def test_study_names_command_without_storage_url() -> None:
    """Check that ``optuna study-names`` fails when no storage is configured."""
    # Run with a copy of the environment that lacks OPTUNA_STORAGE, so no storage URL is
    # supplied through the environment either.
    env_without_storage = {
        name: value for name, value in os.environ.items() if name != "OPTUNA_STORAGE"
    }
    # NOTE(review): `--study-name` is passed here although the other `study-names`
    # invocations in this module do not use it; confirm that the failure is caused by
    # the missing storage rather than by this extra argument.
    listing_cmd = ["optuna", "study-names", "--study-name", "dummy_study"]

    with pytest.raises(subprocess.CalledProcessError):
        subprocess.check_output(listing_cmd, env=env_without_storage)
366

367

368
@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_studies_command(output_format: Optional[str]) -> None:
    """Check the columns and values printed by ``optuna studies`` in each output format.

    Args:
        output_format:
            Value passed to ``--format``. ``None`` omits the option, in which case the
            output is parsed as a table.
    """
    # Table cells are strings and keep column order; JSON/YAML give typed values.
    table_like = output_format in (None, "table")

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)

        # First study.
        study_1 = optuna.create_study(storage=storage)

        # Run command.
        command = ["optuna", "studies", "--storage", storage_url]
        if output_format is not None:
            command.extend(["--format", output_format])

        output = str(subprocess.check_output(command).decode().strip())
        studies = _parse_output(output, output_format or "table")

        # Check user_attrs are not printed.
        keys_without_attrs = ["name", "direction", "n_trials", "datetime_start"]
        if table_like:
            assert list(studies[0].keys()) == keys_without_attrs
        else:
            assert set(studies[0].keys()) == set(keys_without_attrs)

        # Add a second study.
        study_2 = optuna.create_study(
            storage=storage, study_name="study_2", directions=["minimize", "maximize"]
        )
        study_2.optimize(objective_func_multi_objective, n_trials=10)
        for attr_key, attr_value in (("key_1", "value_1"), ("key_2", "value_2")):
            study_2.set_user_attr(attr_key, attr_value)

        # Run command again to include second study.
        output = str(subprocess.check_output(command).decode().strip())
        studies = _parse_output(output, output_format or "table")

        keys_with_attrs = ["name", "direction", "n_trials", "datetime_start", "user_attrs"]

        assert len(studies) == 2
        for study in studies:
            if table_like:
                assert list(study.keys()) == keys_with_attrs
            else:
                assert set(study.keys()) == set(keys_with_attrs)

        first, second = studies[0], studies[1]
        attrs_of_second = {"key_1": "value_1", "key_2": "value_2"}

        # Check study_name, direction, n_trials and user_attrs for the first study.
        assert first["name"] == study_1.study_name
        if table_like:
            # Table cells hold Python literals as text, hence the eval.
            assert first["n_trials"] == "0"
            assert eval(first["direction"]) == ("MINIMIZE",)
            assert eval(first["user_attrs"]) == {}
        else:
            assert first["n_trials"] == 0
            assert first["direction"] == ["MINIMIZE"]
            assert first["user_attrs"] == {}

        # Check study_name, direction, n_trials and user_attrs for the second study.
        assert second["name"] == study_2.study_name
        if table_like:
            assert second["n_trials"] == "10"
            assert eval(second["direction"]) == ("MINIMIZE", "MAXIMIZE")
            assert eval(second["user_attrs"]) == attrs_of_second
        else:
            assert second["n_trials"] == 10
            assert second["direction"] == ["MINIMIZE", "MAXIMIZE"]
            assert second["user_attrs"] == attrs_of_second
436

437

438
@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_studies_command_flatten(output_format: Optional[str]) -> None:
    """Check the columns and values printed by ``optuna studies --flatten``.

    Args:
        output_format:
            Value passed to ``--format``. ``None`` omits the option, in which case the
            output is parsed as a table.
    """
    # Table cells are strings and keep column order; JSON/YAML give typed values.
    table_like = output_format in (None, "table")

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)

        # First study.
        study_1 = optuna.create_study(storage=storage)

        # Run command.
        command = ["optuna", "studies", "--storage", storage_url, "--flatten"]
        if output_format is not None:
            command.extend(["--format", output_format])

        output = str(subprocess.check_output(command).decode().strip())
        studies = _parse_output(output, output_format or "table")

        # Check user_attrs are not printed.
        keys_single_study = ["name", "direction_0", "n_trials", "datetime_start"]
        if table_like:
            assert list(studies[0].keys()) == keys_single_study
        else:
            assert set(studies[0].keys()) == set(keys_single_study)

        # Add a second study.
        study_2 = optuna.create_study(
            storage=storage, study_name="study_2", directions=["minimize", "maximize"]
        )
        study_2.optimize(objective_func_multi_objective, n_trials=10)
        for attr_key, attr_value in (("key_1", "value_1"), ("key_2", "value_2")):
            study_2.set_user_attr(attr_key, attr_value)

        # Run command again to include second study.
        output = str(subprocess.check_output(command).decode().strip())
        studies = _parse_output(output, output_format or "table")

        keys_two_directions = [
            "name",
            "direction_0",
            "direction_1",
            "n_trials",
            "datetime_start",
            "user_attrs",
        ]
        keys_one_direction = ["name", "direction_0", "n_trials", "datetime_start", "user_attrs"]

        assert len(studies) == 2
        first, second = studies[0], studies[1]
        if table_like:
            # In a table both rows are expected to carry the full column list.
            assert list(first.keys()) == keys_two_directions
            assert list(second.keys()) == keys_two_directions
        else:
            # In JSON/YAML the single-direction study is expected to lack direction_1.
            assert set(first.keys()) == set(keys_one_direction)
            assert set(second.keys()) == set(keys_two_directions)

        # Check study_name, direction, n_trials and user_attrs for the first study.
        assert first["name"] == study_1.study_name
        if table_like:
            assert first["n_trials"] == "0"
            assert first["user_attrs"] == "{}"
        else:
            assert first["n_trials"] == 0
            assert first["user_attrs"] == {}
        assert first["direction_0"] == "MINIMIZE"

        # Check study_name, direction, n_trials and user_attrs for the second study.
        assert second["name"] == study_2.study_name
        if table_like:
            assert second["n_trials"] == "10"
            assert second["user_attrs"] == "{'key_1': 'value_1', 'key_2': 'value_2'}"
        else:
            assert second["n_trials"] == 10
            assert second["user_attrs"] == {"key_1": "value_1", "key_2": "value_2"}
        assert second["direction_0"] == "MINIMIZE"
        assert second["direction_1"] == "MAXIMIZE"
529

530

531
@pytest.mark.skip_coverage
@pytest.mark.parametrize("objective", (objective_func, objective_func_branched_search_space))
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_trials_command(objective: Callable[[Trial], float], output_format: Optional[str]) -> None:
    """Compare the output of ``optuna trials`` with ``Study.trials_dataframe``.

    Args:
        objective:
            Objective function used to populate the study with trials.
        output_format:
            Value passed to ``--format``. ``None`` omits the option, in which case the
            output is parsed as a table.
    """
    # Table cells are strings (nested values are Python literals as text); JSON/YAML
    # give already-structured values.
    table_like = output_format in (None, "table")

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)
        study_name = "test_study"
        n_trials = 10

        study = optuna.create_study(storage=storage, study_name=study_name)
        study.optimize(objective, n_trials=n_trials)
        attrs = (
            "number",
            "value",
            "datetime_start",
            "datetime_complete",
            "duration",
            "params",
            "user_attrs",
            "state",
        )

        # Run command.
        command = ["optuna", "trials", "--storage", storage_url, "--study-name", study_name]
        if output_format is not None:
            command.extend(["--format", output_format])

        output = str(subprocess.check_output(command).decode().strip())
        trials = _parse_output(output, output_format or "table")

        assert len(trials) == n_trials

        # With multi_index=True each column key is indexed as (attribute, nested name).
        df = study.trials_dataframe(attrs, multi_index=True)

        for i, trial in enumerate(trials):
            for key in df.columns:
                group, name = key[0], key[1]
                expected_value = df.loc[i][key]

                # The param may be NaN when the objective function has branched search space.
                unsampled_param = (
                    group == "params"
                    and isinstance(expected_value, float)
                    and np.isnan(expected_value)
                )
                if unsampled_param:
                    shown_params = eval(trial["params"]) if table_like else trial["params"]
                    assert name not in shown_params
                    continue

                if name == "":
                    value = trial[group]
                elif table_like:
                    value = eval(trial[group])[name]
                else:
                    value = trial[group][name]

                if isinstance(value, (int, float)):
                    if np.isnan(expected_value):
                        assert np.isnan(value)
                    else:
                        assert value == expected_value
                elif isinstance(expected_value, Timestamp):
                    assert value == expected_value.strftime("%Y-%m-%d %H:%M:%S")
                elif isinstance(expected_value, Timedelta):
                    assert value == str(expected_value.to_pytimedelta())
                else:
                    assert value == str(expected_value)
609

610

611
@pytest.mark.skip_coverage
@pytest.mark.parametrize("objective", (objective_func, objective_func_branched_search_space))
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_trials_command_flatten(
    objective: Callable[[Trial], float], output_format: Optional[str]
) -> None:
    """Compare the output of ``optuna trials --flatten`` with ``Study.trials_dataframe``.

    Args:
        objective:
            Objective function used to populate the study with trials.
        output_format:
            Value passed to ``--format``. ``None`` omits the option, in which case the
            output is parsed as a table.
    """
    table_like = output_format in (None, "table")

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)
        study_name = "test_study"
        n_trials = 10

        study = optuna.create_study(storage=storage, study_name=study_name)
        study.optimize(objective, n_trials=n_trials)
        attrs = (
            "number",
            "value",
            "datetime_start",
            "datetime_complete",
            "duration",
            "params",
            "user_attrs",
            "state",
        )

        # Run command.
        command = [
            "optuna",
            "trials",
            "--storage",
            storage_url,
            "--study-name",
            study_name,
            "--flatten",
        ]
        if output_format is not None:
            command.extend(["--format", output_format])

        output = str(subprocess.check_output(command).decode().strip())
        trials = _parse_output(output, output_format or "table")

        assert len(trials) == n_trials

        df = study.trials_dataframe(attrs)
        known_columns = set(df.columns)

        for i, trial in enumerate(trials):
            # The CLI must not print a column the dataframe does not have.
            assert set(trial.keys()) <= known_columns
            for key in df.columns:
                expected_value = df.loc[i][key]

                # The param may be NaN when the objective function has branched search space.
                unsampled_param = (
                    key.startswith("params_")
                    and isinstance(expected_value, float)
                    and np.isnan(expected_value)
                )
                if unsampled_param:
                    if table_like:
                        # A table prints an empty cell for the missing parameter.
                        assert trial[key] == ""
                    else:
                        # JSON/YAML omit the key altogether.
                        assert key not in trial
                    continue

                value = trial[key]

                if isinstance(value, (int, float)):
                    if np.isnan(expected_value):
                        assert np.isnan(value)
                    else:
                        assert value == expected_value
                elif isinstance(expected_value, Timestamp):
                    assert value == expected_value.strftime("%Y-%m-%d %H:%M:%S")
                elif isinstance(expected_value, Timedelta):
                    assert value == str(expected_value.to_pytimedelta())
                else:
                    assert value == str(expected_value)
687

688

689
@pytest.mark.skip_coverage
@pytest.mark.parametrize("objective", (objective_func, objective_func_branched_search_space))
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_best_trial_command(
    objective: Callable[[Trial], float], output_format: Optional[str]
) -> None:
    """Compare the output of ``optuna best-trial`` with ``Study.trials_dataframe``.

    Args:
        objective:
            Objective function used to populate the study with trials.
        output_format:
            Value passed to ``--format``. ``None`` omits the option, in which case the
            output is parsed as a table.
    """
    # Table cells are strings (nested values are Python literals as text); JSON/YAML
    # give already-structured values.
    table_like = output_format in (None, "table")

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)
        study_name = "test_study"
        n_trials = 10

        study = optuna.create_study(storage=storage, study_name=study_name)
        study.optimize(objective, n_trials=n_trials)
        attrs = (
            "number",
            "value",
            "datetime_start",
            "datetime_complete",
            "duration",
            "params",
            "user_attrs",
            "state",
        )

        # Run command.
        command = ["optuna", "best-trial", "--storage", storage_url, "--study-name", study_name]
        if output_format is not None:
            command.extend(["--format", output_format])

        output = str(subprocess.check_output(command).decode().strip())
        best_trial = _parse_output(output, output_format or "table")

        if table_like:
            # A table is parsed into a list of rows; the best trial is its only row.
            assert len(best_trial) == 1
            best_trial = best_trial[0]

        # With multi_index=True each column key is indexed as (attribute, nested name).
        df = study.trials_dataframe(attrs, multi_index=True)
        best_number = study.best_trial.number

        for key in df.columns:
            group, name = key[0], key[1]
            expected_value = df.loc[best_number][key]

            # The param may be NaN when the objective function has branched search space.
            unsampled_param = (
                group == "params"
                and isinstance(expected_value, float)
                and np.isnan(expected_value)
            )
            if unsampled_param:
                shown_params = eval(best_trial["params"]) if table_like else best_trial["params"]
                assert name not in shown_params
                continue

            if name == "":
                value = best_trial[group]
            elif table_like:
                value = eval(best_trial[group])[name]
            else:
                value = best_trial[group][name]

            if isinstance(value, (int, float)):
                if np.isnan(expected_value):
                    assert np.isnan(value)
                else:
                    assert value == expected_value
            elif isinstance(expected_value, Timestamp):
                assert value == expected_value.strftime("%Y-%m-%d %H:%M:%S")
            elif isinstance(expected_value, Timedelta):
                assert value == str(expected_value.to_pytimedelta())
            else:
                assert value == str(expected_value)
770

771

772
@pytest.mark.skip_coverage
@pytest.mark.parametrize("objective", (objective_func, objective_func_branched_search_space))
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_best_trial_command_flatten(
    objective: Callable[[Trial], float], output_format: Optional[str]
) -> None:
    """Compare ``optuna best-trial --flatten`` output with the study's dataframe.

    A study is optimized for 10 trials on a SQLite storage, the CLI is run in a
    subprocess, and each column of ``study.trials_dataframe(attrs)`` for the best
    trial is checked against the parsed command output.
    """
    # Without ``--format`` the output is parsed as a table (see ``_parse_output`` call).
    table_like = output_format is None or output_format == "table"
    attrs = (
        "number",
        "value",
        "datetime_start",
        "datetime_complete",
        "duration",
        "params",
        "user_attrs",
        "state",
    )

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        url = str(storage.engine.url)
        name = "test_study"

        study = optuna.create_study(storage=storage, study_name=name)
        study.optimize(objective, n_trials=10)

        # Run command.
        cli_args = [
            "optuna",
            "best-trial",
            "--storage",
            url,
            "--study-name",
            name,
            "--flatten",
        ]
        if output_format is not None:
            cli_args.extend(["--format", output_format])

        raw_output = str(subprocess.check_output(cli_args).decode().strip())
        reported = _parse_output(raw_output, output_format or "table")

        # Table output is parsed into a list of rows; exactly one row is expected.
        if table_like:
            assert len(reported) == 1
            reported = reported[0]

        frame = study.trials_dataframe(attrs)

        assert set(reported.keys()) <= set(frame.columns)
        for column in frame.columns:
            expected = frame.loc[study.best_trial.number][column]

            # The param may be NaN when the objective function has branched search space.
            unsampled_param = (
                column.startswith("params_")
                and isinstance(expected, float)
                and np.isnan(expected)
            )
            if unsampled_param:
                if table_like:
                    assert reported[column] == ""
                else:
                    assert column not in reported
                continue

            actual = reported[column]
            if isinstance(actual, (int, float)):
                if np.isnan(expected):
                    assert np.isnan(actual)
                else:
                    assert actual == expected
            elif isinstance(expected, Timestamp):
                assert actual == expected.strftime("%Y-%m-%d %H:%M:%S")
            elif isinstance(expected, Timedelta):
                assert actual == str(expected.to_pytimedelta())
            else:
                assert actual == str(expected)
848

849

850
@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_best_trials_command(output_format: Optional[str]) -> None:
    """Compare ``optuna best-trials`` output with the study's multi-index dataframe.

    A two-objective study is optimized for 10 trials, the CLI is run in a subprocess,
    and every reported trial is checked column by column against
    ``study.trials_dataframe(attrs, multi_index=True)``.
    """
    # Without ``--format`` the output is parsed as a table (see ``_parse_output`` call).
    table_like = output_format in (None, "table")
    attrs = (
        "number",
        "values",
        "datetime_start",
        "datetime_complete",
        "duration",
        "params",
        "user_attrs",
        "state",
    )

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        url = str(storage.engine.url)
        name = "test_study"

        study = optuna.create_study(
            storage=storage, study_name=name, directions=("minimize", "minimize")
        )
        study.optimize(objective_func_multi_objective, n_trials=10)

        # Run command.
        cli_args = [
            "optuna",
            "best-trials",
            "--storage",
            url,
            "--study-name",
            name,
        ]
        if output_format is not None:
            cli_args.extend(["--format", output_format])

        raw_output = str(subprocess.check_output(cli_args).decode().strip())
        reported_trials = _parse_output(raw_output, output_format or "table")
        pareto_numbers = [t.number for t in study.best_trials]

        assert len(reported_trials) == len(pareto_numbers)

        frame = study.trials_dataframe(attrs, multi_index=True)

        for reported in reported_trials:
            # Table cells are strings, so the trial number needs converting there.
            number = reported["number"]
            if table_like:
                number = int(number)
            assert number in pareto_numbers

            for column in frame.columns:
                group, member = column[0], column[1]
                expected = frame.loc[number][column]

                # The param may be NaN when the objective function has branched search space.
                if group == "params" and isinstance(expected, float) and np.isnan(expected):
                    if table_like:
                        assert member not in eval(reported["params"])
                    else:
                        assert member not in reported["params"]
                    continue

                if member == "":
                    actual = reported[group]
                elif table_like:
                    # Nested values appear as a Python literal string in table output.
                    actual = eval(reported[group])[member]
                else:
                    actual = reported[group][member]

                if isinstance(actual, (int, float)):
                    if np.isnan(expected):
                        assert np.isnan(actual)
                    else:
                        assert actual == expected
                elif isinstance(expected, Timestamp):
                    assert actual == expected.strftime("%Y-%m-%d %H:%M:%S")
                elif isinstance(expected, Timedelta):
                    assert actual == str(expected.to_pytimedelta())
                else:
                    assert actual == str(expected)
932

933

934
@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_best_trials_command_flatten(output_format: Optional[str]) -> None:
    """Compare ``optuna best-trials --flatten`` output with the study's dataframe.

    A two-objective study is optimized for 10 trials, the CLI is run in a subprocess,
    and every reported trial is checked column by column against
    ``study.trials_dataframe(attrs)``.
    """
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)
        study_name = "test_study"
        n_trials = 10

        study = optuna.create_study(
            storage=storage, study_name=study_name, directions=("minimize", "minimize")
        )
        study.optimize(objective_func_multi_objective, n_trials=n_trials)
        attrs = (
            "number",
            "values",
            "datetime_start",
            "datetime_complete",
            "duration",
            "params",
            "user_attrs",
            "state",
        )

        # Run command.
        command = [
            "optuna",
            "best-trials",
            "--storage",
            storage_url,
            "--study-name",
            study_name,
            "--flatten",
        ]

        if output_format is not None:
            command += ["--format", output_format]

        output = str(subprocess.check_output(command).decode().strip())
        # Without ``--format`` the output is parsed as a table.
        trials = _parse_output(output, output_format or "table")
        best_trials = [trial.number for trial in study.best_trials]

        assert len(trials) == len(best_trials)

        df = study.trials_dataframe(attrs)

        for trial in trials:
            assert set(trial.keys()) <= set(df.columns)
            # Table cells are strings, so the trial number needs converting there.
            number = int(trial["number"]) if output_format in (None, "table") else trial["number"]
            # Same membership check as the non-flatten test: a matching row count alone
            # would not catch a row that belongs to a non-Pareto trial.
            assert number in best_trials
            for key in df.columns:
                expected_value = df.loc[number][key]

                # The param may be NaN when the objective function has branched search space.
                if (
                    key.startswith("params_")
                    and isinstance(expected_value, float)
                    and np.isnan(expected_value)
                ):
                    if output_format is None or output_format == "table":
                        assert trial[key] == ""
                    else:
                        assert key not in trial
                    continue

                value = trial[key]
                if isinstance(value, (int, float)):
                    if np.isnan(expected_value):
                        assert np.isnan(value)
                    else:
                        assert value == expected_value
                elif isinstance(expected_value, Timestamp):
                    assert value == expected_value.strftime("%Y-%m-%d %H:%M:%S")
                elif isinstance(expected_value, Timedelta):
                    assert value == str(expected_value.to_pytimedelta())
                else:
                    assert value == str(expected_value)
1010

1011

1012
@pytest.mark.skip_coverage
def test_create_study_command_with_skip_if_exists() -> None:
    """Exercise ``optuna create-study`` with and without ``--skip-if-exists``.

    Re-creating an existing study must fail without the flag and must return the
    existing study when the flag is given.
    """
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        url = str(storage.engine.url)

        def build_command(name: str, *extra: str) -> list:
            return ["optuna", "create-study", "--storage", url, "--study-name", name, *extra]

        # Create study with name.
        first_output = subprocess.check_output(build_command("test_study"))
        created_name = str(first_output.decode().strip())

        # Check if study_name is stored in the storage.
        original_id = storage.get_study_id_from_name(created_name)
        assert storage.get_study_name_from_id(original_id) == created_name

        # Try to create the same name study without `--skip-if-exists` flag (error).
        with pytest.raises(subprocess.CalledProcessError):
            subprocess.check_output(build_command(created_name))

        # Try to create the same name study with `--skip-if-exists` flag (OK).
        second_output = subprocess.check_output(build_command(created_name, "--skip-if-exists"))
        reused_name = str(second_output.decode().strip())
        reused_id = storage.get_study_id_from_name(reused_name)
        assert original_id == reused_id  # The existing study instance is reused.
1045

1046

1047
@pytest.mark.skip_coverage
def test_study_optimize_command() -> None:
    """Run ``optuna study optimize`` on a pre-created study and inspect the result.

    The study is created directly through the storage, then optimized for 10 trials
    with this module's ``objective_func`` via the CLI.
    """
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        url = str(storage.engine.url)

        new_study_id = storage.create_new_study(directions=[StudyDirection.MINIMIZE])
        generated_name = storage.get_study_name_from_id(new_study_id)

        cli_args = [
            "optuna",
            "study",
            "optimize",
            "--study-name",
            generated_name,
            "--n-trials",
            "10",
            __file__,
            "objective_func",
            "--storage",
            url,
        ]
        subprocess.check_call(cli_args)

        loaded = optuna.load_study(storage=url, study_name=generated_name)
        assert len(loaded.trials) == 10
        assert "x" in loaded.best_params

        # Check if a default value of study_name is stored in the storage.
        stored_name = storage.get_study_name_from_id(loaded._study_id)
        assert stored_name.startswith(DEFAULT_STUDY_NAME_PREFIX)
1079

1080

1081
@pytest.mark.skip_coverage
def test_study_optimize_command_inconsistent_args() -> None:
    """``optuna study optimize`` must exit with an error when ``--study-name`` is omitted."""
    with NamedTemporaryFilePool() as tf:
        url = f"sqlite:///{tf.name}"

        # --study-name argument is missing.
        cli_args = [
            "optuna",
            "study",
            "optimize",
            "--storage",
            url,
            "--n-trials",
            "10",
            __file__,
            "objective_func",
        ]
        with pytest.raises(subprocess.CalledProcessError):
            subprocess.check_call(cli_args)
1101

1102

1103
@pytest.mark.skip_coverage
def test_empty_argv() -> None:
    """Running ``optuna`` with no arguments must print the same text as ``optuna help``."""
    bare_output, help_output = (
        str(subprocess.check_output(argv)) for argv in (["optuna"], ["optuna", "help"])
    )

    assert bare_output == help_output
1112

1113

1114
def test_check_storage_url() -> None:
    """Check ``optuna.cli._check_storage_url`` for the three sources of a storage URL.

    * An explicit URL is returned unchanged.
    * With no explicit URL and ``OPTUNA_STORAGE`` set, an ``ExperimentalWarning`` is emitted.
    * With neither, ``CLIUsageError`` is raised.
    """
    storage_in_args = "sqlite:///args.db"
    assert storage_in_args == optuna.cli._check_storage_url(storage_in_args)

    with pytest.warns(ExperimentalWarning):
        with patch.dict("optuna.cli.os.environ", {"OPTUNA_STORAGE": "sqlite:///args.db"}):
            optuna.cli._check_storage_url(None)

    # Remove ``OPTUNA_STORAGE`` for the duration of the check so the expected error does
    # not depend on the environment the test suite is launched from. ``patch.dict``
    # restores the original mapping on exit.
    with patch.dict("optuna.cli.os.environ"):
        os.environ.pop("OPTUNA_STORAGE", None)
        with pytest.raises(CLIUsageError):
            optuna.cli._check_storage_url(None)
1124

1125

1126
@pytest.mark.skipif(platform.system() == "Windows", reason="Skip on Windows")
@patch("optuna.storages._journal.redis.redis")
def test_get_storage_without_storage_class(mock_redis: MagicMock) -> None:
    """Check the storage type ``optuna.cli._get_storage`` picks when no class is given.

    Covers a SQLite URL, an existing ``.log`` file path, a Redis URL (backed by
    ``fakeredis``), and a non-existent file path, which must raise ``CLIUsageError``.
    """
    get_storage = optuna.cli._get_storage

    with tempfile.NamedTemporaryFile(suffix=".db") as db_file:
        resolved = get_storage(f"sqlite:///{db_file.name}", storage_class=None)
        assert isinstance(resolved, RDBStorage)

    with tempfile.NamedTemporaryFile(suffix=".log") as log_file:
        resolved = get_storage(log_file.name, storage_class=None)
        assert isinstance(resolved, JournalStorage)
        assert isinstance(resolved._backend, JournalFileStorage)

    # Swap the patched module's client for the in-memory fake before using Redis.
    mock_redis.Redis = fakeredis.FakeRedis
    resolved = get_storage("redis://localhost:6379", storage_class=None)
    assert isinstance(resolved, JournalStorage)
    assert isinstance(resolved._backend, JournalRedisStorage)

    with pytest.raises(CLIUsageError):
        get_storage("./file-not-found.log", storage_class=None)
1145

1146

1147
@pytest.mark.skipif(platform.system() == "Windows", reason="Skip on Windows")
@patch("optuna.storages._journal.redis.redis")
def test_get_storage_with_storage_class(mock_redis: MagicMock) -> None:
    """Check ``optuna.cli._get_storage`` when a storage class name is supplied.

    Covers the journal file and journal Redis (``fakeredis``) classes, and an
    unsupported class name, which must raise ``CLIUsageError``.
    """
    get_storage = optuna.cli._get_storage

    # NOTE(review): this first case passes ``storage_class=None`` even though the test
    # targets explicit classes -- confirm whether that is intended.
    with tempfile.NamedTemporaryFile(suffix=".db") as db_file:
        resolved = get_storage(f"sqlite:///{db_file.name}", storage_class=None)
        assert isinstance(resolved, RDBStorage)

    with tempfile.NamedTemporaryFile(suffix=".log") as log_file:
        resolved = get_storage(log_file.name, storage_class="JournalFileStorage")
        assert isinstance(resolved, JournalStorage)
        assert isinstance(resolved._backend, JournalFileStorage)

    # Swap the patched module's client for the in-memory fake before using Redis.
    mock_redis.Redis = fakeredis.FakeRedis
    resolved = get_storage("redis:///localhost:6379", storage_class="JournalRedisStorage")
    assert isinstance(resolved, JournalStorage)
    assert isinstance(resolved._backend, JournalRedisStorage)

    with pytest.raises(CLIUsageError):
        with tempfile.NamedTemporaryFile(suffix=".db") as db_file:
            get_storage(f"sqlite:///{db_file.name}", storage_class="InMemoryStorage")
1169

1170

1171
@pytest.mark.skip_coverage
def test_storage_upgrade_command() -> None:
    """``optuna storage upgrade`` must fail without a storage and succeed with ``--storage``."""
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        url = str(storage.engine.url)

        upgrade = ["optuna", "storage", "upgrade"]

        # Drop OPTUNA_STORAGE from the child's environment so no storage is available.
        child_env = dict(os.environ)
        child_env.pop("OPTUNA_STORAGE", None)
        with pytest.raises(CalledProcessError):
            subprocess.check_call(upgrade, env=child_env)

        subprocess.check_call(upgrade + ["--storage", url])
1186

1187

1188
@pytest.mark.skip_coverage
def test_storage_upgrade_command_with_invalid_url() -> None:
    """``optuna storage upgrade`` must exit with an error for an invalid storage URL."""
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)

        cli_args = ["optuna", "storage", "upgrade", "--storage", "invalid-storage-url"]
        with pytest.raises(CalledProcessError):
            subprocess.check_call(cli_args)
1196

1197

1198
@pytest.mark.skip_coverage
@pytest.mark.parametrize(
    "direction,directions,sampler,sampler_kwargs,output_format",
    [
        (None, None, None, None, None),
        ("minimize", None, None, None, None),
        (None, "minimize maximize", None, None, None),
        (None, None, "RandomSampler", None, None),
        (None, None, "TPESampler", '{"multivariate": true}', None),
        (None, None, None, None, "json"),
        # Exercises the table branch below, which no other case reaches.
        (None, None, None, None, "table"),
        (None, None, None, None, "yaml"),
    ],
)
def test_ask(
    direction: Optional[str],
    directions: Optional[str],
    sampler: Optional[str],
    sampler_kwargs: Optional[str],
    output_format: Optional[str],
) -> None:
    """Run ``optuna ask`` with a search space and validate the suggested trial.

    Args:
        direction: Value for ``--direction``, or ``None`` to omit the option.
        directions: Space-separated values for ``--directions``, or ``None`` to omit.
        sampler: Value for ``--sampler``, or ``None`` to omit.
        sampler_kwargs: JSON string for ``--sampler-kwargs``, or ``None`` to omit.
        output_format: Value for ``--format``, or ``None`` to omit; the output is
            then parsed as JSON.
    """
    study_name = "test_study"
    search_space = (
        '{"x": {"name": "FloatDistribution", "attributes": {"low": 0.0, "high": 1.0}}, '
        '"y": {"name": "CategoricalDistribution", "attributes": {"choices": ["foo"]}}}'
    )

    with NamedTemporaryFilePool() as tf:
        db_url = "sqlite:///{}".format(tf.name)

        args = [
            "optuna",
            "ask",
            "--storage",
            db_url,
            "--study-name",
            study_name,
            "--search-space",
            search_space,
        ]

        if direction is not None:
            args += ["--direction", direction]
        if directions is not None:
            args += ["--directions"] + directions.split()
        if sampler is not None:
            args += ["--sampler", sampler]
        if sampler_kwargs is not None:
            args += ["--sampler-kwargs", sampler_kwargs]
        if output_format is not None:
            args += ["--format", output_format]

        # stderr is captured separately because the warning check below reads it.
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output = str(result.stdout.decode().strip())
        trial = _parse_output(output, output_format or "json")

        if output_format == "table":
            # Table cells are strings; ``params`` holds a Python dict literal.
            assert len(trial) == 1
            trial = trial[0]
            assert trial["number"] == "0"
            params = eval(trial["params"])
            assert len(params) == 2
            assert 0 <= params["x"] <= 1
            assert params["y"] == "foo"
        else:
            assert trial["number"] == 0
            assert 0 <= trial["params"]["x"] <= 1
            assert trial["params"]["y"] == "foo"

        if direction is not None or directions is not None:
            warning_message = result.stderr.decode()
            assert "FutureWarning" in warning_message
1269

1270

1271
@pytest.mark.skip_coverage
@pytest.mark.parametrize(
    "direction,directions,sampler,sampler_kwargs,output_format",
    [
        (None, None, None, None, None),
        ("minimize", None, None, None, None),
        (None, "minimize maximize", None, None, None),
        (None, None, "RandomSampler", None, None),
        (None, None, "TPESampler", '{"multivariate": true}', None),
        (None, None, None, None, "json"),
        # Exercises the table branch below, which no other case reaches.
        (None, None, None, None, "table"),
        (None, None, None, None, "yaml"),
    ],
)
def test_ask_flatten(
    direction: Optional[str],
    directions: Optional[str],
    sampler: Optional[str],
    sampler_kwargs: Optional[str],
    output_format: Optional[str],
) -> None:
    """Run ``optuna ask --flatten`` with a search space and validate the suggested trial.

    Args:
        direction: Value for ``--direction``, or ``None`` to omit the option.
        directions: Space-separated values for ``--directions``, or ``None`` to omit.
        sampler: Value for ``--sampler``, or ``None`` to omit.
        sampler_kwargs: JSON string for ``--sampler-kwargs``, or ``None`` to omit.
        output_format: Value for ``--format``, or ``None`` to omit; the output is
            then parsed as JSON.
    """
    study_name = "test_study"
    search_space = (
        '{"x": {"name": "FloatDistribution", "attributes": {"low": 0.0, "high": 1.0}}, '
        '"y": {"name": "CategoricalDistribution", "attributes": {"choices": ["foo"]}}}'
    )

    with NamedTemporaryFilePool() as tf:
        db_url = "sqlite:///{}".format(tf.name)

        args = [
            "optuna",
            "ask",
            "--storage",
            db_url,
            "--study-name",
            study_name,
            "--search-space",
            search_space,
            "--flatten",
        ]

        if direction is not None:
            args += ["--direction", direction]
        if directions is not None:
            args += ["--directions"] + directions.split()
        if sampler is not None:
            args += ["--sampler", sampler]
        if sampler_kwargs is not None:
            args += ["--sampler-kwargs", sampler_kwargs]
        if output_format is not None:
            args += ["--format", output_format]

        # stderr is captured separately because the warning check below reads it.
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output = str(result.stdout.decode().strip())
        trial = _parse_output(output, output_format or "json")

        if output_format == "table":
            # Table cells are strings, so numeric params are converted before comparing.
            assert len(trial) == 1
            trial = trial[0]
            assert trial["number"] == "0"
            assert 0 <= float(trial["params_x"]) <= 1
            assert trial["params_y"] == "foo"
        else:
            assert trial["number"] == 0
            assert 0 <= trial["params_x"] <= 1
            assert trial["params_y"] == "foo"

        if direction is not None or directions is not None:
            warning_message = result.stderr.decode()
            assert "FutureWarning" in warning_message
1341

1342

1343
@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_ask_empty_search_space(output_format: Optional[str]) -> None:
    """Run ``optuna ask`` without ``--search-space`` and expect an empty ``params``.

    Args:
        output_format: Value for ``--format``, or ``None`` to omit the option; the
            output is then parsed as JSON.
    """
    study_name = "test_study"

    with NamedTemporaryFilePool() as tf:
        db_url = "sqlite:///{}".format(tf.name)

        args = [
            "optuna",
            "ask",
            "--storage",
            db_url,
            "--study-name",
            study_name,
        ]

        if output_format is not None:
            args += ["--format", output_format]

        output = str(subprocess.check_output(args).decode().strip())
        trial = _parse_output(output, output_format or "json")

        if output_format == "table":
            # Table cells are strings, so the empty dict shows up as its literal text.
            assert len(trial) == 1
            trial = trial[0]
            assert trial["number"] == "0"
            assert trial["params"] == "{}"
        else:
            assert trial["number"] == 0
            assert trial["params"] == {}
1374

1375

1376
@pytest.mark.skip_coverage
@pytest.mark.parametrize("output_format", (None, "table", "json", "yaml"))
def test_ask_empty_search_space_flatten(output_format: Optional[str]) -> None:
    """Run ``optuna ask --flatten`` without ``--search-space``; no ``params`` key is expected.

    Args:
        output_format: Value for ``--format``, or ``None`` to omit the option; the
            output is then parsed as JSON.
    """
    study_name = "test_study"

    with NamedTemporaryFilePool() as tf:
        db_url = "sqlite:///{}".format(tf.name)

        args = [
            "optuna",
            "ask",
            "--storage",
            db_url,
            "--study-name",
            study_name,
            "--flatten",
        ]

        if output_format is not None:
            args += ["--format", output_format]

        output = str(subprocess.check_output(args).decode().strip())
        trial = _parse_output(output, output_format or "json")

        if output_format == "table":
            # Table output is parsed into a list of rows with string cells.
            assert len(trial) == 1
            trial = trial[0]
            assert trial["number"] == "0"
            assert "params" not in trial
        else:
            assert trial["number"] == 0
            assert "params" not in trial
1408

1409

1410
@pytest.mark.skip_coverage
def test_ask_sampler_kwargs_without_sampler() -> None:
    """``optuna ask`` must report an error when ``--sampler-kwargs`` lacks ``--sampler``."""
    search_space = (
        '{"x": {"name": "FloatDistribution", "attributes": {"low": 0.0, "high": 1.0}}, '
        '"y": {"name": "CategoricalDistribution", "attributes": {"choices": ["foo"]}}}'
    )

    with NamedTemporaryFilePool() as tf:
        cli_args = [
            "optuna",
            "ask",
            "--storage",
            f"sqlite:///{tf.name}",
            "--study-name",
            "test_study",
            "--search-space",
            search_space,
            "--sampler-kwargs",
            '{"multivariate": true}',
        ]

        completed = subprocess.run(cli_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stderr_text = completed.stderr.decode()
        assert "`--sampler_kwargs` is set without `--sampler`." in stderr_text
1437

1438

1439
@pytest.mark.skip_coverage
@pytest.mark.parametrize(
    "direction,directions,sampler,sampler_kwargs",
    [
        (None, None, None, None),
        ("minimize", None, None, None),
        (None, "minimize maximize", None, None),
        (None, None, "RandomSampler", None),
        (None, None, "TPESampler", '{"multivariate": true}'),
    ],
)
def test_create_study_and_ask(
    direction: Optional[str],
    directions: Optional[str],
    sampler: Optional[str],
    sampler_kwargs: Optional[str],
) -> None:
    """Create a study with ``optuna create-study``, then request a trial with ``optuna ask``.

    Directions are passed only to ``create-study``; sampler options only to ``ask``.
    """
    name = "test_study"
    search_space = (
        '{"x": {"name": "FloatDistribution", "attributes": {"low": 0.0, "high": 1.0}}, '
        '"y": {"name": "CategoricalDistribution", "attributes": {"choices": ["foo"]}}}'
    )

    with NamedTemporaryFilePool() as tf:
        url = f"sqlite:///{tf.name}"

        create_cmd = [
            "optuna",
            "create-study",
            "--storage",
            url,
            "--study-name",
            name,
        ]
        if direction is not None:
            create_cmd.extend(["--direction", direction])
        if directions is not None:
            create_cmd.extend(["--directions", *directions.split()])
        subprocess.check_call(create_cmd)

        ask_cmd = [
            "optuna",
            "ask",
            "--storage",
            url,
            "--study-name",
            name,
            "--search-space",
            search_space,
        ]
        if sampler is not None:
            ask_cmd.extend(["--sampler", sampler])
        if sampler_kwargs is not None:
            ask_cmd.extend(["--sampler-kwargs", sampler_kwargs])

        raw_output = str(subprocess.check_output(ask_cmd).decode().strip())
        suggested = _parse_output(raw_output, "json")

        assert suggested["number"] == 0
        assert 0 <= suggested["params"]["x"] <= 1
        assert suggested["params"]["y"] == "foo"
1502

1503

1504
@pytest.mark.skip_coverage
@pytest.mark.parametrize(
    "direction,directions,ask_direction,ask_directions",
    [
        (None, None, "maximize", None),
        ("minimize", None, "maximize", None),
        ("minimize", None, None, "minimize minimize"),
        (None, "minimize maximize", None, "maximize minimize"),
        (None, "minimize maximize", "minimize", None),
    ],
)
def test_create_study_and_ask_with_inconsistent_directions(
    direction: Optional[str],
    directions: Optional[str],
    ask_direction: Optional[str],
    ask_directions: Optional[str],
) -> None:
    """``optuna ask`` must report an error when its directions conflict with the study's.

    ``direction``/``directions`` go to ``create-study``; ``ask_direction``/
    ``ask_directions`` go to the subsequent ``ask`` call.
    """
    name = "test_study"
    search_space = (
        '{"x": {"name": "FloatDistribution", "attributes": {"low": 0.0, "high": 1.0}}, '
        '"y": {"name": "CategoricalDistribution", "attributes": {"choices": ["foo"]}}}'
    )

    with NamedTemporaryFilePool() as tf:
        url = f"sqlite:///{tf.name}"

        create_cmd = [
            "optuna",
            "create-study",
            "--storage",
            url,
            "--study-name",
            name,
        ]
        if direction is not None:
            create_cmd.extend(["--direction", direction])
        if directions is not None:
            create_cmd.extend(["--directions", *directions.split()])
        subprocess.check_call(create_cmd)

        ask_cmd = [
            "optuna",
            "ask",
            "--storage",
            url,
            "--study-name",
            name,
            "--search-space",
            search_space,
        ]
        if ask_direction is not None:
            ask_cmd.extend(["--direction", ask_direction])
        if ask_directions is not None:
            ask_cmd.extend(["--directions", *ask_directions.split()])

        completed = subprocess.run(ask_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stderr_text = completed.stderr.decode()
        assert "Cannot overwrite study direction" in stderr_text
1563

1564

1565
@pytest.mark.skip_coverage
def test_ask_with_both_direction_and_directions() -> None:
    """``optuna ask`` must report an error when both ``--direction`` and ``--directions`` are given."""
    name = "test_study"
    search_space = (
        '{"x": {"name": "FloatDistribution", "attributes": {"low": 0.0, "high": 1.0}}, '
        '"y": {"name": "CategoricalDistribution", "attributes": {"choices": ["foo"]}}}'
    )

    with NamedTemporaryFilePool() as tf:
        url = f"sqlite:///{tf.name}"

        subprocess.check_call(
            [
                "optuna",
                "create-study",
                "--storage",
                url,
                "--study-name",
                name,
            ]
        )

        ask_cmd = [
            "optuna",
            "ask",
            "--storage",
            url,
            "--study-name",
            name,
            "--search-space",
            search_space,
            "--direction",
            "minimize",
            "--directions",
            "minimize",
        ]

        completed = subprocess.run(ask_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stderr_text = completed.stderr.decode()
        assert "Specify only one of `direction` and `directions`." in stderr_text
1604

1605

1606
@pytest.mark.skip_coverage
def test_tell() -> None:
    """Exercise `optuna tell` on a fresh trial, on a finished one, and with `--skip-if-finished`."""
    study_name = "test_study"

    with NamedTemporaryFilePool() as tf:
        db_url = "sqlite:///{}".format(tf.name)

        # Ask for a trial and read its number from the JSON output.
        ask_cmd = [
            "optuna",
            "ask",
            "--storage",
            db_url,
            "--study-name",
            study_name,
            "--format",
            "json",
        ]
        ask_reply: Any = json.loads(subprocess.check_output(ask_cmd).decode("utf-8"))
        trial_number = ask_reply["number"]

        # Every `tell` call below starts with the same arguments and differs only in
        # what follows `--values`.
        tell_prefix = [
            "optuna",
            "tell",
            "--storage",
            db_url,
            "--trial-number",
            str(trial_number),
            "--values",
        ]

        subprocess.check_output([*tell_prefix, "1.2"])

        trials = optuna.load_study(storage=db_url, study_name=study_name).trials
        assert len(trials) == 1
        assert trials[0].state == TrialState.COMPLETE
        assert trials[0].values == [1.2]

        # Error when updating a finished trial.
        second_tell = subprocess.run([*tell_prefix, "1.2"])
        assert second_tell.returncode != 0

        # Passing `--skip-if-finished` to a finished trial for a noop.
        # A different value (1.3) is sent to make sure it is not persisted.
        subprocess.check_output([*tell_prefix, "1.3", "--skip-if-finished"])

        trials = optuna.load_study(storage=db_url, study_name=study_name).trials
        assert len(trials) == 1
        assert trials[0].state == TrialState.COMPLETE
        assert trials[0].values == [1.2]
1681

1682

1683
@pytest.mark.skip_coverage
def test_tell_with_nan() -> None:
    """Check that telling `nan` leaves the trial in the FAIL state with no values."""
    study_name = "test_study"

    with NamedTemporaryFilePool() as tf:
        db_url = "sqlite:///{}".format(tf.name)

        # Ask for a trial and read its number from the JSON output.
        ask_cmd = [
            "optuna",
            "ask",
            "--storage",
            db_url,
            "--study-name",
            study_name,
            "--format",
            "json",
        ]
        ask_reply: Any = json.loads(subprocess.check_output(ask_cmd).decode("utf-8"))
        trial_number = ask_reply["number"]

        # Report `nan` as the objective value of that trial.
        tell_cmd = [
            "optuna",
            "tell",
            "--storage",
            db_url,
            "--trial-number",
            str(trial_number),
            "--values",
            "nan",
        ]
        subprocess.check_output(tell_cmd)

        trials = optuna.load_study(storage=db_url, study_name=study_name).trials
        assert len(trials) == 1
        assert trials[0].state == TrialState.FAIL
        assert trials[0].values is None
1723

1724

1725
@pytest.mark.skip_coverage
@pytest.mark.parametrize(
    "verbosity, expected",
    [
        ("--verbose", True),
        ("--quiet", False),
    ],
)
def test_configure_logging_verbosity(verbosity: str, expected: bool) -> None:
    """Check whether the study-creation log line reaches stderr for each verbosity flag.

    Args:
        verbosity: CLI flag appended to the `create-study` command.
        expected: Whether the creation message should appear in stderr.
    """
    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)

        # Create study.
        # `--verbose` makes the log level DEBUG.
        # `--quiet` makes the log level WARNING.
        create_cmd = ["optuna", "create-study", "--storage", storage_url, verbosity]
        completed = subprocess.run(create_cmd, capture_output=True)

        stderr_text = completed.stderr.decode()
        message_logged = "A new study created in RDB with name" in stderr_text
        assert message_logged == expected
1745

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.