/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tf_types.h"

#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project

namespace {

// Returns the shape of the given value if it's ranked; returns llvm::None
// otherwise.
llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
  auto shaped_type = value.getType().cast<mlir::ShapedType>();
  if (shaped_type.hasRank()) return shaped_type.getShape();
  return llvm::None;
}

// Merges cast compatible shapes and returns a more refined shape. The two
// shapes are cast compatible if they have the same rank and, at each
// dimension, either both have the same size or one of them is dynamic.
// Returns false if the given shapes are not cast compatible. The refined
// shape is the same as or more precise than the two input shapes.
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
                            llvm::ArrayRef<int64_t> b_shape,
                            llvm::SmallVectorImpl<int64_t>* refined_shape) {
  if (a_shape.size() != b_shape.size()) return false;
  int64_t rank = a_shape.size();
  refined_shape->reserve(rank);
  for (auto dims : llvm::zip(a_shape, b_shape)) {
    int64_t dim1 = std::get<0>(dims);
    int64_t dim2 = std::get<1>(dims);

    // If one of the dimensions is dynamic, keep the other (possibly static)
    // size.
    if (mlir::ShapedType::isDynamic(dim1)) {
      refined_shape->push_back(dim2);
      continue;
    }
    if (mlir::ShapedType::isDynamic(dim2)) {
      refined_shape->push_back(dim1);
      continue;
    }
    // Both dimensions are static; they must agree.
    if (dim1 == dim2) {
      refined_shape->push_back(dim1);
      continue;
    }
    return false;
  }
  return true;
}
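
// For illustration: merging the shapes of tensor<2x?x3xf32> and
// tensor<?x5x3xf32> refines to [2, 5, 3], while shapes [2, 3] and [2, 4] are
// not cast compatible because their static sizes disagree in the last
// dimension.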

}  // namespace

namespace mlir {
namespace TF {

//===----------------------------------------------------------------------===//
// Utility iterators
//===----------------------------------------------------------------------===//

OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}

ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}

//===----------------------------------------------------------------------===//
// TF types helper functions
//===----------------------------------------------------------------------===//

bool TensorFlowType::classof(Type type) {
  return type.getDialect().getNamespace() == "tf";
}
101
bool TensorFlowRefType::classof(Type type)
104
#define HANDLE_TF_TYPE(tftype, enumerant, name)
105
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
106
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
108
#include "tf_types.def"

bool TensorFlowTypeWithSubtype::classof(Type type) {
  return type.isa<ResourceType, VariantType>();
}

TensorFlowType TensorFlowRefType::get(Type type) {
  MLIRContext* ctx = type.getContext();
  type = getElementTypeOrSelf(type);
  if (type.isF16()) {
    return HalfRefType::get(ctx);
  } else if (type.isF32()) {
    return FloatRefType::get(ctx);
  } else if (type.isF64()) {
    return DoubleRefType::get(ctx);
  } else if (type.isBF16()) {
    return Bfloat16RefType::get(ctx);
  } else if (auto complex_type = type.dyn_cast<ComplexType>()) {
    Type etype = complex_type.getElementType();
    if (etype.isF32()) {
      return Complex64RefType::get(ctx);
    } else if (etype.isF64()) {
      return Complex128RefType::get(ctx);
    }
    llvm_unreachable("unexpected complex type");
  } else if (auto itype = type.dyn_cast<IntegerType>()) {
    switch (itype.getWidth()) {
      case 1:
        return BoolRefType::get(ctx);
      case 8:
        return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
                                  : Int8RefType::get(ctx);
      case 16:
        return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
                                  : Int16RefType::get(ctx);
      case 32:
        return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
                                  : Int32RefType::get(ctx);
      case 64:
        return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
                                  : Int64RefType::get(ctx);
      default:
        llvm_unreachable("unexpected integer type");
    }
  }
#define HANDLE_TF_TYPE(tftype, enumerant, name)        \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
    return tftype##RefType::get(ctx);
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
#include "tf_types.def"
  llvm_unreachable("unexpected type kind");
}

Type TensorFlowRefType::RemoveRef() {
  MLIRContext* ctx = getContext();
  if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
  if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
  if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
  if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
  if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
  if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
  if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
  if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
  if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
  if (isa<Uint8RefType>())
    return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
  if (isa<Uint16RefType>())
    return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
  if (isa<Uint32RefType>())
    return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
  if (isa<Uint64RefType>())
    return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
  if (isa<Complex64RefType>())
    return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
  if (isa<Complex128RefType>())
    return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (isa<tftype##RefType>()) return tftype##Type::get(ctx);
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
#include "tf_types.def"
  llvm_unreachable("unexpected tensorflow ref type kind");
}

Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
  MLIRContext* ctx = getContext();
  if (isa<VariantType>()) return VariantType::get(ctx);
  if (isa<ResourceType>()) return ResourceType::get(ctx);
  llvm_unreachable("unexpected tensorflow type with subtypes kind");
}

ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
  if (auto variant_type = dyn_cast<VariantType>())
    return variant_type.getSubtypes();
  if (auto resource_type = dyn_cast<ResourceType>())
    return resource_type.getSubtypes();
  llvm_unreachable("unexpected tensorflow type with subtypes kind");
}

// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
// similar structure that could be extracted into helper method.
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (auto types : llvm::zip(lhs, rhs)) {
    // Drop ref types because they don't affect broadcast compatibility. E.g.,
    // `tensor<!tf.f32ref>` and `tensor<f32>` should be considered broadcast
    // compatible.
    auto lhs_type = DropRefType(std::get<0>(types));
    auto rhs_type = DropRefType(std::get<1>(types));

    // This should be true for all TF ops:
    auto lhs_tt = lhs_type.dyn_cast<TensorType>();
    auto rhs_tt = rhs_type.dyn_cast<TensorType>();
    if (!lhs_tt || !rhs_tt) {
      if (lhs_type != rhs_type) return false;
      continue;
    }

    // Verify matching element types. These should be identical, except for
    // variant type where unknown subtype is considered compatible with all
    // subtypes.
    auto lhs_et = lhs_tt.getElementType();
    auto rhs_et = rhs_tt.getElementType();
    if (lhs_et != rhs_et) {
      // If either does not have subtypes, then the element types don't match.
      auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
      auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
      if (!lhs_wst || !rhs_wst) return false;

      // Consider the subtype of variant types.
      auto lhs_wst_st = lhs_wst.GetSubtypes();
      auto rhs_wst_st = rhs_wst.GetSubtypes();
      if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) {
        for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
          if (!BroadcastCompatible(std::get<0>(subtypes),
                                   std::get<1>(subtypes)))
            return false;
        }
      }
    }

    auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
    auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
    if (!lhs_rt || !rhs_rt) return true;
    SmallVector<int64_t, 4> shape;
    return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
                                              rhs_rt.getShape(), shape);
  }
  return true;
}
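
// For illustration: tensor<2x1xf32> and tensor<2x4xf32> are broadcast
// compatible (the size-1 dimension broadcasts), whereas tensor<3xf32> and
// tensor<4xf32> are not.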

// Given two types `a` and `b`, returns a refined type which is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them. It
// returns empty Type if the input types are not cast compatible.
//
// The two types are considered cast compatible if they have dynamically equal
// shapes and element type. For element types that do not have subtypes, they
// must be equal. However, for TensorFlow types such as Resource and Variant
// that also have subtypes, we recursively check for subtype compatibility for
// Resource types and assume all variant types are cast compatible. If either
// one of `a` or `b` has empty subtypes, they are considered cast compatible.
//
// The returned type is the same as or more precise than the input types. For
// example, if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
//
// Provides an option to ignore ref types on 'a'. This is useful for TF ops
// that might allow operands to either be the same as the result type or be a
// ref type corresponding to it.
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
                                 bool may_ignore_ref_type_a) {
  // Fast path if everything is equal.
  if (a == b) return b;

  auto a_tt = a.dyn_cast<mlir::TensorType>();
  auto b_tt = b.dyn_cast<mlir::TensorType>();

  // If only one of a or b is a tensor type, they are incompatible.
  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;

  // For non-tensor types, we do not need to worry about shape and can return
  // early.
  if (!a_tt && !b_tt) {
    // Remove ref types.
    if (may_ignore_ref_type_a) {
      if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
        a = ref_type.RemoveRef();
        if (a == b) return a;
      }
    }

    if (a.getTypeID() != b.getTypeID()) return nullptr;

    // If either is not a type that contains subtypes then the types are not
    // cast compatible.
    auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
    auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
    if (!a_wst || !b_wst) return nullptr;

    // For Variant types we are more permissive right now and accept all pairs
    // of Variant types. If we are more constrained and check compatibility of
    // subtypes, we might reject valid graphs.
    // TODO(prakalps): Variant doesn't have a subtype, we assign it one, so we
    // should only assign it one when we know the subtype. Then we can be more
    // constrained and check subtypes for cast compatibility as well.
    if (a.isa<mlir::TF::VariantType>()) return a;

    // For Resource types, we recursively check the subtypes for cast
    // compatibility, if possible. Otherwise treat them as compatible.
    auto a_wst_st = a_wst.GetSubtypes();
    auto b_wst_st = b_wst.GetSubtypes();
    if (a_wst_st.empty() || b_wst_st.empty()) return a;
    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
    llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
      mlir::Type refined_st =
          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
                                /*may_ignore_ref_type_a=*/false);
      if (!refined_st) return nullptr;
      refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
    }

    return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
  }

  // For tensor types, check compatibility of both element type and shape.
  mlir::Type refined_element_ty = GetCastCompatibleType(
      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
  if (!refined_element_ty) return nullptr;

  // If neither type is ranked, the refined type is unranked as well.
  if (!a_tt.hasRank() && !b_tt.hasRank()) {
    return mlir::UnrankedTensorType::get(refined_element_ty);
  }
  // If only one of the types is ranked, use its shape.
  if (!a_tt.hasRank()) {
    return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
  }
  if (!b_tt.hasRank()) {
    return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
  }

  // Both types are ranked: merge the shapes dimension by dimension.
  llvm::SmallVector<int64_t, 8> refined_shape;
  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(),
                              &refined_shape))
    return nullptr;

  return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
}
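
// For illustration: refining !tf.resource<tensor<2x?xf32>> against
// !tf.resource<tensor<?x4xf32>> recursively refines the subtypes and yields
// !tf.resource<tensor<2x4xf32>>, while any two variant element types are
// treated as compatible regardless of their subtypes.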

bool HasCompatibleElementTypes(Type lhs, Type rhs,
                               bool may_ignore_ref_type_lhs) {
  return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
}

bool AreCastCompatible(TypeRange types) {
  Type common = types.front();
  for (auto type : types.drop_front()) {
    Type refined_type =
        GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
    if (!refined_type) return false;
    common = refined_type;
  }
  return true;
}
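
// For illustration: AreCastCompatible({tensor<2x?xf32>, tensor<?x4xf32>,
// tensor<2x4xf32>}) is true because a single refined type, tensor<2x4xf32>,
// is cast compatible with all three.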

bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (auto pair : llvm::zip(lhs, rhs)) {
    auto lhs_i = std::get<0>(pair);
    auto rhs_i = std::get<1>(pair);
    if (!AreCastCompatible({lhs_i, rhs_i})) return false;
  }
  return true;
}

// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
// type for a composed type (such as a ref type or a type with subtypes).
template <typename ComposedType>
Type DropTypeHelper(Type ty) {
  Type element_ty = getElementTypeOrSelf(ty);
  auto composed_type = element_ty.dyn_cast<ComposedType>();
  if (!composed_type) return ty;

  Type default_ty = GetDefaultTypeOf(composed_type);
  if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked_ty.getShape(), default_ty);
  } else if (ty.dyn_cast<UnrankedTensorType>()) {
    return UnrankedTensorType::get(default_ty);
  } else {
    return default_ty;
  }
}

Type DropSubTypes(Type ty) {
  return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty);
}

Type DropRefType(Type ty) {
  return DropTypeHelper<TF::TensorFlowRefType>(ty);
}
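
// For illustration: DropRefType(tensor<4x!tf.f32ref>) gives tensor<4xf32>,
// and DropSubTypes(tensor<!tf.resource<tensor<f32>>>) gives
// tensor<!tf.resource>.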

Type DropRefAndSubTypes(Type ty) {
  return DropRefType(DropSubTypes(ty));
}

}  // namespace TF
}  // namespace mlir