DFT-EFE
 
Loading...
Searching...
No Matches
BlasAPIWrapper.h
Go to the documentation of this file.
1/******************************************************************************
2 * Copyright (c) 2021. *
3 * The Regents of the University of Michigan and DFT-EFE developers. *
4 * *
5 * This file is part of the DFT-EFE code. *
6 * *
7 * DFT-EFE is free software: you can redistribute it and/or modify *
8 * it under the terms of the Lesser GNU General Public License as *
9 * published by the Free Software Foundation, either version 3 of *
10 * the License, or (at your option) any later version. *
11 * *
12 * DFT-EFE is distributed in the hope that it will be useful, but *
13 * WITHOUT ANY WARRANTY; without even the implied warranty of *
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *
15 * See the Lesser GNU General Public License for more details. *
16 * *
17 * You should have received a copy of the GNU Lesser General Public *
18 * License at the top level of DFT-EFE distribution. If not, see *
19 * <https://www.gnu.org/licenses/>. *
20 ******************************************************************************/
21
22/*
23 * @author Avirup Sircar
24 */
25
26#ifndef BLASWrapper_h
27#define BLASWrapper_h
28
29#include <cmath>
31#include <utils/TypeConfig.h>
33namespace dftefe
34{
35 namespace linearAlgebra
36 {
37 namespace blasLapack
38 {
39 namespace blasWrapper
40 {
41 template <typename ValueType1,
42 typename ValueType2,
43 typename utils::MemorySpace memorySpace>
44 void
45 gemm(const char transA,
46 const char transB,
47 const size_type m,
48 const size_type n,
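49 const size_type k,
50 const scalar_type<ValueType1, ValueType2> alpha,
51 ValueType1 const * A,
52 const size_type lda,
53 ValueType2 const * B,
54 const size_type ldb,
55 const scalar_type<ValueType1, ValueType2> beta,
56 scalar_type<ValueType1, ValueType2> * C,
57 const size_type ldc,
58 LinAlgOpContext<memorySpace> & context);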
59
60 template <typename ValueType, typename utils::MemorySpace memorySpace>
61 real_type<ValueType>
62 asum(const size_type n,
63 ValueType const * x,
64 const size_type incx,
65 LinAlgOpContext<memorySpace> & context);
66
67 template <typename ValueType, typename utils::MemorySpace memorySpace>
68 size_type
69 iamax(const size_type n,
70 ValueType const * x,
71 const size_type incx,
72 LinAlgOpContext<memorySpace> & context);
73
74 template <typename ValueType1,
75 typename ValueType2,
76 typename utils::MemorySpace memorySpace>
77 void
78 axpy(const size_type n,
79 const scalar_type<ValueType1, ValueType2> alpha,
80 ValueType1 const * x,
81 const size_type incx,
82 ValueType2 * y,
83 const size_type incy,
84 LinAlgOpContext<memorySpace> & context);
85
86#if defined(DFTEFE_WITH_DEVICE)
87
88 enum class tensorOpDataType
89 {
90 fp32,
91 tf32,
92 bf16,
93 fp16
94 };
95
96 template <typename ValueType1, typename ValueType2>
97 static void
98 copyValueType1ArrToValueType2ArrDeviceCall(
99 const size_type size,
100 const ValueType1 * valueType1Arr,
101 ValueType2 * valueType2Arr,
102 utils::deviceStream_t streamId = utils::defaultStream);
103
104 utils::deviceBlasHandle_t &
105 getDeviceBlasHandle();
106
107 void
108 setTensorOpDataType(tensorOpDataType opType)
109 {
110 d_opType = opType;
111 }
112
113 static utils::deviceBlasStatus_t
114 setStream(utils::deviceStream_t streamId);
115
116 inline static utils::deviceBlasHandle_t d_deviceBlasHandle;
117 inline static utils::deviceStream_t d_streamId;
118
119# ifdef DFTEFE_WITH_DEVICE_AMD
120 void
121 initialize();
122# endif
123
125 tensorOpDataType d_opType;
126
127 utils::deviceBlasStatus_t
128 create();
129
130 utils::deviceBlasStatus_t
131 destroy();
132
133#endif
134
135 } // namespace blasWrapper
136 } // namespace blasLapack
137 } // end of namespace linearAlgebra
138
139} // end of namespace dftefe
140
141
142#endif // BLASWrapper_h
Definition: LinAlgOpContext.h:38
size_type iamax(const size_type n, ValueType const *x, const size_type incx, LinAlgOpContext< memorySpace > &context)
void gemm(const char transA, const char transB, const size_type m, const size_type n, const size_type k, const scalar_type< ValueType1, ValueType2 > alpha, ValueType1 const *A, const size_type lda, ValueType2 const *B, const size_type ldb, const scalar_type< ValueType1, ValueType2 > beta, scalar_type< ValueType1, ValueType2 > *C, const size_type ldc, LinAlgOpContext< memorySpace > &context)
real_type< ValueType > asum(const size_type n, ValueType const *x, const size_type incx, LinAlgOpContext< memorySpace > &context)
void axpy(const size_type n, const scalar_type< ValueType1, ValueType2 > alpha, ValueType1 const *x, const size_type incx, ValueType2 *y, const size_type incy, LinAlgOpContext< memorySpace > &context)
typeInternal::real_type< ValueType > real_type
Definition: BlasLapackTypedef.h:177
typeInternal::scalar_type< ValueType1, ValueType2 > scalar_type
Definition: BlasLapackTypedef.h:183
MemorySpace
Definition: MemorySpaceType.h:37
dealii includes
Definition: AtomFieldDataSpherical.cpp:31
unsigned int size_type
Definition: TypeConfig.h:8