ncnn

Форк
0
/
tf_types.cc 
462 строки · 16.4 Кб
1
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2

3
Licensed under the Apache License, Version 2.0 (the "License");
4
you may not use this file except in compliance with the License.
5
You may obtain a copy of the License at
6

7
    http://www.apache.org/licenses/LICENSE-2.0
8

9
Unless required by applicable law or agreed to in writing, software
10
distributed under the License is distributed on an "AS IS" BASIS,
11
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
See the License for the specific language governing permissions and
13
limitations under the License.
14
==============================================================================*/
15

16
#include "tf_types.h"
17

18
#include "llvm/Support/ErrorHandling.h"
19
#include "mlir/Dialect/Traits.h"   // from @llvm-project
20
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
21
#include "mlir/IR/Dialect.h"       // from @llvm-project
22
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
23

24
namespace {
25
// Returns the shape of the given value if it's ranked; returns llvm::None
26
// otherwise.
27
llvm::Optional<llvm::ArrayRef<int64_t> > GetShape(mlir::Value value)
28
{
29
    auto shaped_type = value.getType().cast<mlir::ShapedType>();
30
    if (shaped_type.hasRank()) return shaped_type.getShape();
31
    return llvm::None;
32
}
33

34
// Merges cast compatible shapes and returns a more refined shape. The two
35
// shapes are cast compatible if they have the same rank and at each dimension,
36
// either both have same size or one of them is dynamic. Returns false if the
37
// given shapes are not cast compatible. The refined shape is same or more
38
// precise than the two input shapes.
39
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
40
                            llvm::ArrayRef<int64_t> b_shape,
41
                            llvm::SmallVectorImpl<int64_t>* refined_shape)
42
{
43
    if (a_shape.size() != b_shape.size()) return false;
44
    int64_t rank = a_shape.size();
45
    refined_shape->reserve(rank);
46
    for (auto dims : llvm::zip(a_shape, b_shape))
47
    {
48
        int64_t dim1 = std::get<0>(dims);
49
        int64_t dim2 = std::get<1>(dims);
50

51
        if (mlir::ShapedType::isDynamic(dim1))
52
        {
53
            refined_shape->push_back(dim2);
54
            continue;
55
        }
56
        if (mlir::ShapedType::isDynamic(dim2))
57
        {
58
            refined_shape->push_back(dim1);
59
            continue;
60
        }
61
        if (dim1 == dim2)
62
        {
63
            refined_shape->push_back(dim1);
64
            continue;
65
        }
66
        return false;
67
    }
68
    return true;
69
}
70

71
} // namespace
72

73
namespace mlir {
74
namespace TF {
75
//===----------------------------------------------------------------------===//
76
// Utility iterators
77
//===----------------------------------------------------------------------===//
78

79
// Constructs an iterator that maps each operand to the shape of its value via
// the file-local GetShape helper (llvm::None for unranked values).
OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
      llvm::Optional<ArrayRef<int64_t> > (*)(Value)>(
          it, &GetShape)
{
}
85

86
// Constructs an iterator that maps each result to the shape of its value via
// the file-local GetShape helper (llvm::None for unranked values).
ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
      llvm::Optional<ArrayRef<int64_t> > (*)(Value)>(
          it, &GetShape)
{
}
92

93
//===----------------------------------------------------------------------===//
94
// TF types helper functions
95
//===----------------------------------------------------------------------===//
96

97
// A type is a TensorFlow type iff it was registered by the "tf" dialect.
bool TensorFlowType::classof(Type type)
{
    const auto dialect_ns = type.getDialect().getNamespace();
    return dialect_ns == "tf";
}
101
// Returns true if `type` is one of the TensorFlow ref types. The isa<> type
// list is generated by expanding the HANDLE_TF_REF_TYPE entries of
// tf_types.def (plain HANDLE_TF_TYPE entries expand to nothing here).
bool TensorFlowRefType::classof(Type type)
{
    return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)  tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
// NOLINTNEXTLINE
#include "tf_types.def"
           >();
}
111
// Only resource and variant types carry subtype information.
bool TensorFlowTypeWithSubtype::classof(Type type)
{
    return type.isa<ResourceType>() || type.isa<VariantType>();
}
115

116
// Returns the TensorFlow ref type corresponding to the element type of
// `type` (e.g. f32 -> FloatRefType). For shaped types, only the element type
// is considered. Aborts via llvm_unreachable for element types that have no
// ref counterpart.
TensorFlowType TensorFlowRefType::get(Type type)
{
    MLIRContext* ctx = type.getContext();
    // Map shaped types by their element type.
    type = getElementTypeOrSelf(type);
    if (type.isF16())
    {
        return HalfRefType::get(ctx);
    }
    else if (type.isF32())
    {
        return FloatRefType::get(ctx);
    }
    else if (type.isF64())
    {
        return DoubleRefType::get(ctx);
    }
    else if (type.isBF16())
    {
        return Bfloat16RefType::get(ctx);
    }
    else if (auto complex_type = type.dyn_cast<ComplexType>())
    {
        // Complex types are distinguished by the width of their element type.
        Type etype = complex_type.getElementType();
        if (etype.isF32())
        {
            return Complex64RefType::get(ctx);
        }
        else if (etype.isF64())
        {
            return Complex128RefType::get(ctx);
        }
        llvm_unreachable("unexpected complex type");
    }
    else if (auto itype = type.dyn_cast<IntegerType>())
    {
        // Integer types dispatch on width and signedness; width 1 is bool.
        switch (itype.getWidth())
        {
        case 1:
            return BoolRefType::get(ctx);
        case 8:
            return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
                   : Int8RefType::get(ctx);
        case 16:
            return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
                   : Int16RefType::get(ctx);
        case 32:
            return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
                   : Int32RefType::get(ctx);
        case 64:
            return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
                   : Int64RefType::get(ctx);
        default:
            llvm_unreachable("unexpected integer type");
        }
    }
// Remaining TF dialect types (string, qint8, ...) map tftype##Type to
// tftype##RefType via the non-ref entries of tf_types.def.
#define HANDLE_TF_TYPE(tftype, enumerant, name)          \
    if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
        return tftype##RefType::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"
    llvm_unreachable("unexpected type kind");
}
180

181
// Returns the non-ref builtin/TF type corresponding to this ref type (inverse
// of TensorFlowRefType::get, e.g. FloatRefType -> f32). Aborts via
// llvm_unreachable for unhandled ref kinds.
Type TensorFlowRefType::RemoveRef()
{
    MLIRContext* ctx = getContext();
    if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
    if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
    if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
    if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
    if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
    if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
    if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
    if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
    if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
    if (isa<Uint8RefType>())
        return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
    if (isa<Uint16RefType>())
        return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
    if (isa<Uint32RefType>())
        return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
    if (isa<Uint64RefType>())
        return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
    if (isa<Complex64RefType>())
        return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
    if (isa<Complex128RefType>())
        return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
// Remaining TF-dialect ref types map tftype##RefType back to tftype##Type
// via the non-ref entries of tf_types.def.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
    if (isa<tftype##RefType>()) return tftype##Type::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"
    llvm_unreachable("unexpected tensorflow ref type kind");
}
213

214
// Returns the same variant/resource kind but with an empty subtype list.
Type TensorFlowTypeWithSubtype::RemoveSubtypes()
{
    MLIRContext* ctx = getContext();
    if (isa<ResourceType>()) return ResourceType::get(ctx);
    if (isa<VariantType>()) return VariantType::get(ctx);
    llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
221

222
// Returns the subtype list of this variant/resource type (may be empty).
ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes()
{
    if (auto resource_type = dyn_cast<ResourceType>())
        return resource_type.getSubtypes();
    if (auto variant_type = dyn_cast<VariantType>())
        return variant_type.getSubtypes();
    llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
230

231
// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
232
// similar structure that could be extracted into helper method.
233
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs)
234
{
235
    if (lhs.size() != rhs.size()) return false;
236
    for (auto types : llvm::zip(lhs, rhs))
237
    {
238
        // Drop ref types because they don't affect broadcast compatibility. E.g.,
239
        // `tensor<!tf.f32ref>` and `tensor<f32>` should be considered broadcast
240
        // compatible.
241
        auto lhs_type = DropRefType(std::get<0>(types));
242
        auto rhs_type = DropRefType(std::get<1>(types));
243

244
        // This should be true for all TF ops:
245
        auto lhs_tt = lhs_type.dyn_cast<TensorType>();
246
        auto rhs_tt = rhs_type.dyn_cast<TensorType>();
247
        if (!lhs_tt || !rhs_tt)
248
        {
249
            if (lhs_type != rhs_type) return false;
250
            continue;
251
        }
252

253
        // Verify matching element types. These should be identical, except for
254
        // variant type where unknown subtype is considered compatible with all
255
        // subtypes.
256
        auto lhs_et = lhs_tt.getElementType();
257
        auto rhs_et = rhs_tt.getElementType();
258
        if (lhs_et != rhs_et)
259
        {
260
            // If either does not have subtypes, then the element types don't match.
261
            auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
262
            auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
263
            if (!lhs_wst || !rhs_wst) return false;
264

265
            // Consider the subtype of variant types.
266
            auto lhs_wst_st = lhs_wst.GetSubtypes();
267
            auto rhs_wst_st = rhs_wst.GetSubtypes();
268
            if (!lhs_wst_st.empty() && !rhs_wst_st.empty())
269
            {
270
                for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st))
271
                {
272
                    if (!BroadcastCompatible(std::get<0>(subtypes),
273
                                             std::get<1>(subtypes)))
274
                        return false;
275
                }
276
            }
277
        }
278

279
        auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
280
        auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
281
        if (!lhs_rt || !rhs_rt) return true;
282
        SmallVector<int64_t, 4> shape;
283
        return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
284
                rhs_rt.getShape(), shape);
285
    }
286
    return true;
287
}
288

289
// Given two types `a` and `b`, returns a refined type which is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them. It
// returns empty Type if the input types are not cast compatible.
//
// The two types are considered cast compatible if they have dynamically equal
// shapes and element type. For element types that do not have subtypes, they
// must be equal. However for TensorFlow types such as Resource and Variant,
// that also have subtypes, we recursively check for subtype compatibility for
// Resource types and assume all variant types are cast compatible. If either
// one of `a` or `b` have empty subtypes, they are considered cast compatible.
//
// The returned type is same or more precise than the input types. For example,
// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
//
// Provides option to ignore ref types on 'a'. This is useful for TF ops that
// might allow operands to either be same as result type or be a ref type
// corresponding to it.
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
                                 bool may_ignore_ref_type_a)
{
    // Fast path if everything is equal.
    if (a == b) return b;

    auto a_tt = a.dyn_cast<mlir::TensorType>();
    auto b_tt = b.dyn_cast<mlir::TensorType>();

    // If only one of a or b is a tensor type, they are incompatible.
    if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;

    // For non-tensor types, we do not need to worry about shape and can return
    // early.
    if (!a_tt && !b_tt)
    {
        // Remove ref types.
        if (may_ignore_ref_type_a)
        {
            if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>())
            {
                a = ref_type.RemoveRef();
                if (a == b) return a;
            }
        }
        // Different type kinds (e.g. resource vs variant) are never compatible.
        if (a.getTypeID() != b.getTypeID()) return nullptr;

        // If either is not a type that contains subtypes then the types are not
        // cast compatible.
        auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
        auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
        if (!a_wst || !b_wst) return nullptr;

        // For Variant types we are more permissive right now and accept all pairs
        // of Variant types. If we are more constrained and check compatibility of
        // subtypes, we might reject valid graphs.
        // TODO(prakalps): Variant doesn't have a subtype, we assign it
        // one, so we should only assign it one when we know the subtype. Then we
        // can be more constrained and check subtypes for cast compatibility as
        // well.
        if (a.isa<mlir::TF::VariantType>()) return a;

        // For Resource types, we recursively check the subtypes for cast
        // compatibility, if possible. Otherwise treat them as compatible.
        auto a_wst_st = a_wst.GetSubtypes();
        auto b_wst_st = b_wst.GetSubtypes();
        if (a_wst_st.empty() || b_wst_st.empty()) return a;
        if (a_wst_st.size() != b_wst_st.size()) return nullptr;
        llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
        for (auto subtypes : llvm::zip(a_wst_st, b_wst_st))
        {
            mlir::Type refined_st = GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
                                    /*may_ignore_ref_type_a=*/false);
            if (!refined_st) return nullptr;
            refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
        }

        return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
    }

    // For tensor types, check compatibility of both element type and shape.
    mlir::Type refined_element_ty = GetCastCompatibleType(
                                        a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
    if (!refined_element_ty) return nullptr;

    // Unranked is compatible with anything; prefer whichever rank is known.
    if (!a_tt.hasRank() && !b_tt.hasRank())
    {
        return mlir::UnrankedTensorType::get(refined_element_ty);
    }
    if (!a_tt.hasRank())
    {
        return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
    }
    if (!b_tt.hasRank())
    {
        return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
    }

    llvm::SmallVector<int64_t, 8> refined_shape;
    if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
        return nullptr;

    return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
}
391

392
bool HasCompatibleElementTypes(Type lhs, Type rhs,
393
                               bool may_ignore_ref_type_lhs)
394
{
395
    return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
396
}
397

398
// Returns true if all types in `types` are mutually cast compatible, i.e. a
// single refined type exists that is compatible with every element.
bool AreCastCompatible(TypeRange types)
{
    // An empty (or single-element) range is trivially compatible. The guard
    // also avoids calling front() on an empty range, which is undefined.
    if (types.empty()) return true;
    Type common = types.front();
    for (auto type : types.drop_front())
    {
        // Fold each type into the running refined type; incompatibility at any
        // step means the whole range is incompatible.
        Type refined_type = GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
        if (!refined_type) return false;
        common = refined_type;
    }
    return true;
}
409

410
// Returns true if `lhs` and `rhs` have the same length and each corresponding
// pair of types is cast compatible.
bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs)
{
    if (lhs.size() != rhs.size()) return false;
    for (size_t i = 0, e = lhs.size(); i < e; ++i)
    {
        if (!AreCastCompatible({lhs[i], rhs[i]})) return false;
    }
    return true;
}
421

422
// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
423
// type for a composed type (such as a ref type or a type with subtypes).
424
template<typename ComposedType>
425
Type DropTypeHelper(Type ty)
426
{
427
    Type element_ty = getElementTypeOrSelf(ty);
428
    auto composed_type = element_ty.dyn_cast<ComposedType>();
429
    if (!composed_type) return ty;
430

431
    Type default_ty = GetDefaultTypeOf(composed_type);
432
    if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
433
    {
434
        return RankedTensorType::get(ranked_ty.getShape(), default_ty);
435
    }
436
    else if (ty.dyn_cast<UnrankedTensorType>())
437
    {
438
        return UnrankedTensorType::get(default_ty);
439
    }
440
    else
441
    {
442
        return default_ty;
443
    }
444
}
445

446
// Returns `ty` with any variant/resource subtype information removed (shape
// information is preserved); see DropTypeHelper.
Type DropSubTypes(Type ty)
{
    return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty);
}
450

451
// Returns `ty` with any ref element type replaced by its non-ref counterpart
// (shape information is preserved); see DropTypeHelper.
Type DropRefType(Type ty)
{
    return DropTypeHelper<TF::TensorFlowRefType>(ty);
}
455

456
// Returns `ty` with both subtype information and ref-ness removed. Subtypes
// are dropped first so a ref type exposed by that step is also stripped.
Type DropRefAndSubTypes(Type ty)
{
    return DropRefType(DropSubTypes(ty));
}
460

461
} // namespace TF
462
} // namespace mlir
463

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.