pytorch
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7from functorch._C import dim
8
9
# Re-export the flattener from the ``functorch._C`` extension's ``dim``
# module under a module-level name. Judging by its use in ``tree_map``, it
# returns a ``(leaves, unflatten)`` pair where ``unflatten`` accepts an
# iterable of leaves -- NOTE(review): confirm against the C implementation.
tree_flatten = dim.tree_flatten
11
12
def tree_map(fn, tree):
    """Apply ``fn`` to every flattened leaf of ``tree`` and rebuild the result.

    ``tree`` is split by ``tree_flatten`` into a sequence of leaves plus a
    rebuild callable. A lazy generator of ``fn(leaf)`` values is then handed
    to that callable.

    NOTE(review): what counts as a "leaf" and which containers are traversed
    is decided by ``tree_flatten`` (a C-extension routine not visible here);
    confirm against that implementation.

    Args:
        fn: Callable applied to each leaf, one at a time.
        tree: The (possibly nested) structure to transform.

    Returns:
        Whatever the rebuild callable produced by ``tree_flatten`` returns --
        presumably a structure shaped like ``tree``.
    """
    leaves, rebuild = tree_flatten(tree)
    # Deliberately a generator expression (not ``map``): a StopIteration
    # escaping ``fn`` becomes a RuntimeError here (PEP 479) instead of
    # silently ending the iteration early.
    mapped_leaves = (fn(leaf) for leaf in leaves)
    return rebuild(mapped_leaves)
16