//===- StandardAttributes.cpp - C Interface to MLIR Standard Attributes ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/StandardAttributes.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/StandardTypes.h" using namespace mlir; /*============================================================================*/ /* Affine map attribute. */ /*============================================================================*/ int mlirAttributeIsAAffineMap(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { return wrap(AffineMapAttr::get(unwrap(map))); } MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValue()); } /*============================================================================*/ /* Array attribute. */ /*============================================================================*/ int mlirAttributeIsAArray(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute *elements) { SmallVector attrs; return wrap(ArrayAttr::get( unwrapList(static_cast(numElements), elements, attrs), unwrap(ctx))); } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { return static_cast(unwrap(attr).cast().size()); } MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { return wrap(unwrap(attr).cast().getValue()[pos]); } /*============================================================================*/ /* Dictionary attribute. */ /*============================================================================*/ int mlirAttributeIsADictionary(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, MlirNamedAttribute *elements) { SmallVector attributes; attributes.reserve(numElements); for (intptr_t i = 0; i < numElements; ++i) attributes.emplace_back(Identifier::get(elements[i].name, unwrap(ctx)), unwrap(elements[i].attribute)); return wrap(DictionaryAttr::get(attributes, unwrap(ctx))); } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { return static_cast(unwrap(attr).cast().size()); } MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos) { NamedAttribute attribute = unwrap(attr).cast().getValue()[pos]; return {attribute.first.c_str(), wrap(attribute.second)}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, const char *name) { return wrap(unwrap(attr).cast().get(name)); } /*============================================================================*/ /* Floating point attribute. */ /*============================================================================*/ int mlirAttributeIsAFloat(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, double value) { return wrap(FloatAttr::get(unwrap(type), value)); } double mlirFloatAttrGetValueDouble(MlirAttribute attr) { return unwrap(attr).cast().getValueAsDouble(); } /*============================================================================*/ /* Integer attribute. */ /*============================================================================*/ int mlirAttributeIsAInteger(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { return wrap(IntegerAttr::get(unwrap(type), value)); } int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { return unwrap(attr).cast().getInt(); } /*============================================================================*/ /* Bool attribute. */ /*============================================================================*/ int mlirAttributeIsABool(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { return wrap(BoolAttr::get(value, unwrap(ctx))); } int mlirBoolAttrGetValue(MlirAttribute attr) { return unwrap(attr).cast().getValue(); } /*============================================================================*/ /* Integer set attribute. */ /*============================================================================*/ int mlirAttributeIsAIntegerSet(MlirAttribute attr) { return unwrap(attr).isa(); } /*============================================================================*/ /* Opaque attribute. */ /*============================================================================*/ int mlirAttributeIsAOpaque(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { return wrap(OpaqueAttr::get(Identifier::get(dialectNamespace, unwrap(ctx)), StringRef(data, dataLength), unwrap(type), unwrap(ctx))); } const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { return unwrap(attr).cast().getDialectNamespace().c_str(); } void mlirOpaqueAttrGetData(MlirAttribute attr, MlirStringCallback callback, void *userData) { StringRef data = unwrap(attr).cast().getAttrData(); callback(data.data(), static_cast(data.size()), userData); } /*============================================================================*/ /* String attribute. */ /*============================================================================*/ int mlirAttributeIsAString(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length, const char *data) { return wrap(StringAttr::get(StringRef(data, length), unwrap(ctx))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length, const char *data) { return wrap(StringAttr::get(StringRef(data, length), unwrap(type))); } void mlirStringAttrGetValue(MlirAttribute attr, MlirStringCallback callback, void *userData) { StringRef data = unwrap(attr).cast().getValue(); callback(data.data(), static_cast(data.size()), userData); } /*============================================================================*/ /* SymbolRef attribute. */ /*============================================================================*/ int mlirAttributeIsASymbolRef(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length, const char *symbol, intptr_t numReferences, MlirAttribute *references) { SmallVector refs; refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) refs.push_back(unwrap(references[i]).cast()); return wrap(SymbolRefAttr::get(StringRef(symbol, length), refs, unwrap(ctx))); } void mlirSymbolRefAttrGetRootReference(MlirAttribute attr, MlirStringCallback callback, void *userData) { StringRef ref = unwrap(attr).cast().getRootReference(); callback(ref.data(), ref.size(), userData); } void mlirSymbolRefAttrGetLeafReference(MlirAttribute attr, MlirStringCallback callback, void *userData) { StringRef ref = unwrap(attr).cast().getLeafReference(); callback(ref.data(), ref.size(), userData); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { return static_cast( unwrap(attr).cast().getNestedReferences().size()); } MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos) { return wrap(unwrap(attr).cast().getNestedReferences()[pos]); } /*============================================================================*/ /* Flat SymbolRef attribute. */ /*============================================================================*/ int mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length, const char *symbol) { return wrap(FlatSymbolRefAttr::get(StringRef(symbol, length), unwrap(ctx))); } void mlirFloatSymbolRefAttrGetValue(MlirAttribute attr, MlirStringCallback callback, void *userData) { StringRef symbol = unwrap(attr).cast().getValue(); callback(symbol.data(), symbol.size(), userData); } /*============================================================================*/ /* Type attribute. */ /*============================================================================*/ int mlirAttributeIsAType(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirTypeAttrGet(MlirType type) { return wrap(TypeAttr::get(unwrap(type))); } MlirType mlirTypeAttrGetValue(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValue()); } /*============================================================================*/ /* Unit attribute. */ /*============================================================================*/ int mlirAttributeIsAUnit(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirUnitAttrGet(MlirContext ctx) { return wrap(UnitAttr::get(unwrap(ctx))); } /*============================================================================*/ /* Elements attributes. */ /*============================================================================*/ int mlirAttributeIsAElements(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { return wrap(unwrap(attr).cast().getValue( llvm::makeArrayRef(idxs, rank))); } int mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { return unwrap(attr).cast().isValidIndex( llvm::makeArrayRef(idxs, rank)); } int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { return unwrap(attr).cast().getNumElements(); } /*============================================================================*/ /* Dense elements attribute. */ /*============================================================================*/ //===----------------------------------------------------------------------===// // IsA support. int mlirAttributeIsADenseElements(MlirAttribute attr) { return unwrap(attr).isa(); } int mlirAttributeIsADenseIntElements(MlirAttribute attr) { return unwrap(attr).isa(); } int mlirAttributeIsADenseFPElements(MlirAttribute attr) { return unwrap(attr).isa(); } //===----------------------------------------------------------------------===// // Constructors. MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, MlirAttribute *elements) { SmallVector attributes; return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), unwrapList(numElements, elements, attributes))); } MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element) { return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), unwrap(element))); } MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, int element) { return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), static_cast(element))); } MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element) { return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), element)); } MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element) { return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), element)); } MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, uint64_t element) { return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), element)); } MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, int64_t element) { return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), element)); } MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, float element) { return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), element)); } MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element) { return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), element)); } MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, intptr_t numElements, int *elements) { SmallVector values(elements, elements + numElements); return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), values)); } /// Creates a dense attribute with elements of the type deduced by templates. template static MlirAttribute getDenseAttribute(MlirType shapedType, intptr_t numElements, T *elements) { return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), llvm::makeArrayRef(elements, numElements))); } MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, intptr_t numElements, uint32_t *elements) { return getDenseAttribute(shapedType, numElements, elements); } MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType, intptr_t numElements, int32_t *elements) { return getDenseAttribute(shapedType, numElements, elements); } MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType, intptr_t numElements, uint64_t *elements) { return getDenseAttribute(shapedType, numElements, elements); } MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType, intptr_t numElements, int64_t *elements) { return getDenseAttribute(shapedType, numElements, elements); } MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType, intptr_t numElements, float *elements) { return getDenseAttribute(shapedType, numElements, elements); } MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, intptr_t numElements, double *elements) { return getDenseAttribute(shapedType, numElements, elements); } MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, intptr_t numElements, intptr_t *strLengths, const char **strs) { SmallVector values; values.reserve(numElements); for (intptr_t i = 0; i < numElements; ++i) values.push_back(StringRef(strs[i], strLengths[i])); return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), values)); } MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, MlirType shapedType) { return wrap(unwrap(attr).cast().reshape( unwrap(shapedType).cast())); } //===----------------------------------------------------------------------===// // Splat accessors. int mlirDenseElementsAttrIsSplat(MlirAttribute attr) { return unwrap(attr).cast().isSplat(); } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { return wrap(unwrap(attr).cast().getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } void mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr, MlirStringCallback callback, void *userData) { StringRef str = unwrap(attr).cast().getSplatValue(); callback(str.data(), str.size(), userData); } //===----------------------------------------------------------------------===// // Indexed accessors. int mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { return *( unwrap(attr).cast().getValues().begin() + pos); } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { return *( unwrap(attr).cast().getValues().begin() + pos); } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } void mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos, MlirStringCallback callback, void *userData) { StringRef str = *(unwrap(attr).cast().getValues().begin() + pos); callback(str.data(), str.size(), userData); } /*============================================================================*/ /* Opaque elements attribute. */ /*============================================================================*/ int mlirAttributeIsAOpaqueElements(MlirAttribute attr) { return unwrap(attr).isa(); } /*============================================================================*/ /* Sparse elements attribute. */ /*============================================================================*/ int mlirAttributeIsASparseElements(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, MlirAttribute denseIndices, MlirAttribute denseValues) { return wrap( SparseElementsAttr::get(unwrap(shapedType).cast(), unwrap(denseIndices).cast(), unwrap(denseValues).cast())); } MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { return wrap(unwrap(attr).cast().getIndices()); } MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValues()); }