pytorch

Форк
0
/
tree_map.py 
15 строк · 375.0 Байт
1
# Copyright (c) Facebook, Inc. and its affiliates.
2
# All rights reserved.
3
#
4
# This source code is licensed under the BSD-style license found in the
5
# LICENSE file in the root directory of this source tree.
6

7
from functorch._C import dim
8

9

10
tree_flatten = dim.tree_flatten
11

12

13
def tree_map(fn, tree):
14
    vs, unflatten = tree_flatten(tree)
15
    return unflatten(fn(v) for v in vs)
16

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.