// DFT-FE 1.1.0-pre -- Density Functional Theory With Finite-Elements
// DeviceBlasWrapper.h
// ---------------------------------------------------------------------
//
// Copyright (c) 2017-2025 The Regents of the University of Michigan and DFT-FE
// authors.
//
// This file is part of the DFT-FE code.
//
// The DFT-FE code is free software; you can use it, redistribute
// it, and/or modify it under the terms of the GNU Lesser General
// Public License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
// The full text of the license can be found in the file LICENSE at
// the top level of the DFT-FE distribution.
//
// ---------------------------------------------------------------------
//

#ifdef DFTFE_WITH_DEVICE

#  ifndef dftfeDeviceBlasWrapper_H
#    define dftfeDeviceBlasWrapper_H

#    include <complex>
#    include <TypeConfig.h>
#    include <DeviceTypeConfig.h>

/**
 * @file DeviceBlasWrapper.h
 * @brief Declarations of thin wrappers over the device BLAS library.
 *
 * Every routine below except initialize() returns a deviceBlasStatus_t.
 * In CUDA builds that type is an alias of cublasStatus_t (declared in
 * DeviceTypeConfig). The handle, stream, operation and math-mode types are
 * likewise backend aliases provided by DeviceTypeConfig.h.
 *
 * NOTE(review): argument semantics (column-major storage, meaning of the
 * leading dimensions, whether alpha/beta/result live in host or device
 * memory) are assumed to match the vendor BLAS routine of the same name.
 * Confirm against the implementation file before relying on them.
 */
namespace dftfe
{
  namespace utils
  {
    namespace deviceBlasWrapper
    {
#    ifdef DFTFE_WITH_DEVICE_AMD
      /**
       * @brief Backend initialization hook, declared only in AMD builds
       * (DFTFE_WITH_DEVICE_AMD).
       *
       * NOTE(review): presumably must be called before the other routines
       * on that backend -- confirm in the implementation.
       */
      void
      initialize();
#    endif

      /**
       * @brief Create a device BLAS handle.
       * @param[out] pHandle Location that presumably receives the new handle.
       * @return Backend status code.
       */
      deviceBlasStatus_t
      create(deviceBlasHandle_t *pHandle);

      /**
       * @brief Destroy a handle obtained from create().
       * @param handle Handle to release.
       * @return Backend status code.
       */
      deviceBlasStatus_t
      destroy(deviceBlasHandle_t handle);

      /**
       * @brief Associate a device stream with a BLAS handle.
       * @param handle BLAS handle.
       * @param stream Device stream that subsequent calls on @p handle use
       *               (assumed, per vendor BLAS semantics).
       * @return Backend status code.
       */
      deviceBlasStatus_t
      setStream(deviceBlasHandle_t handle, deviceStream_t stream);

#    ifdef DFTFE_WITH_DEVICE_LANG_CUDA
      /**
       * @brief Set the math mode of a handle (CUDA builds only).
       * @param handle   BLAS handle.
       * @param mathMode Math mode to apply.
       * @return Backend status code.
       */
      deviceBlasStatus_t
      setMathMode(deviceBlasHandle_t handle, deviceBlasMath_t mathMode);
#    endif

      // ------------------------------------------------------------------
      // Level-1 routines (double precision only).
      // Assumed semantics, following BLAS naming:
      //   copy: y = x
      //   nrm2: *result = ||x||_2
      //   dot:  *result = x . y
      //   axpy: y = alpha * x + y
      // ------------------------------------------------------------------

      /// @brief Copy @p n strided elements of @p x into @p y.
      deviceBlasStatus_t
      copy(deviceBlasHandle_t handle,
           int                n,
           const double *     x,
           int                incx,
           double *           y,
           int                incy);

      /// @brief Euclidean norm of @p n strided elements of @p x, written to
      /// @p result.
      deviceBlasStatus_t
      nrm2(deviceBlasHandle_t handle,
           int                n,
           const double *     x,
           int                incx,
           double *           result);

      /// @brief Dot product of @p x and @p y, written to @p result.
      deviceBlasStatus_t
      dot(deviceBlasHandle_t handle,
          int                n,
          const double *     x,
          int                incx,
          const double *     y,
          int                incy,
          double *           result);

      /// @brief y = alpha * x + y over @p n strided elements.
      deviceBlasStatus_t
      axpy(deviceBlasHandle_t handle,
           int                n,
           const double *     alpha,
           const double *     x,
           int                incx,
           double *           y,
           int                incy);

      // ------------------------------------------------------------------
      // GEMM: C = alpha * op(A) * op(B) + beta * C (assumed BLAS semantics),
      // where op(.) is selected by transa / transb. C is m x n and the inner
      // dimension is k.
      // Overloads: double, float, std::complex<double>, std::complex<float>.
      // ------------------------------------------------------------------

      /// @brief General matrix-matrix product, double precision.
      deviceBlasStatus_t
      gemm(deviceBlasHandle_t    handle,
           deviceBlasOperation_t transa,
           deviceBlasOperation_t transb,
           int                   m,
           int                   n,
           int                   k,
           const double *        alpha,
           const double *        A,
           int                   lda,
           const double *        B,
           int                   ldb,
           const double *        beta,
           double *              C,
           int                   ldc);

      /// @brief General matrix-matrix product, single precision.
      deviceBlasStatus_t
      gemm(deviceBlasHandle_t    handle,
           deviceBlasOperation_t transa,
           deviceBlasOperation_t transb,
           int                   m,
           int                   n,
           int                   k,
           const float *         alpha,
           const float *         A,
           int                   lda,
           const float *         B,
           int                   ldb,
           const float *         beta,
           float *               C,
           int                   ldc);

      /// @brief General matrix-matrix product, double-precision complex.
      deviceBlasStatus_t
      gemm(deviceBlasHandle_t          handle,
           deviceBlasOperation_t       transa,
           deviceBlasOperation_t       transb,
           int                         m,
           int                         n,
           int                         k,
           const std::complex<double> *alpha,
           const std::complex<double> *A,
           int                         lda,
           const std::complex<double> *B,
           int                         ldb,
           const std::complex<double> *beta,
           std::complex<double> *      C,
           int                         ldc);

      /// @brief General matrix-matrix product, single-precision complex.
      deviceBlasStatus_t
      gemm(deviceBlasHandle_t         handle,
           deviceBlasOperation_t      transa,
           deviceBlasOperation_t      transb,
           int                        m,
           int                        n,
           int                        k,
           const std::complex<float> *alpha,
           const std::complex<float> *A,
           int                        lda,
           const std::complex<float> *B,
           int                        ldb,
           const std::complex<float> *beta,
           std::complex<float> *      C,
           int                        ldc);

      // ------------------------------------------------------------------
      // Batched GEMM over arrays of matrix pointers: for each of the
      // batchCount entries i, Carray[i] = alpha * op(Aarray[i]) *
      // op(Barray[i]) + beta * Carray[i] (assumed BLAS semantics).
      // Overloads: double, std::complex<double>.
      // ------------------------------------------------------------------

      /// @brief Pointer-array batched GEMM, double precision.
      deviceBlasStatus_t
      gemmBatched(deviceBlasHandle_t    handle,
                  deviceBlasOperation_t transa,
                  deviceBlasOperation_t transb,
                  int                   m,
                  int                   n,
                  int                   k,
                  const double *        alpha,
                  const double *        Aarray[],
                  int                   lda,
                  const double *        Barray[],
                  int                   ldb,
                  const double *        beta,
                  double *              Carray[],
                  int                   ldc,
                  int                   batchCount);

      /// @brief Pointer-array batched GEMM, double-precision complex.
      deviceBlasStatus_t
      gemmBatched(deviceBlasHandle_t          handle,
                  deviceBlasOperation_t       transa,
                  deviceBlasOperation_t       transb,
                  int                         m,
                  int                         n,
                  int                         k,
                  const std::complex<double> *alpha,
                  const std::complex<double> *Aarray[],
                  int                         lda,
                  const std::complex<double> *Barray[],
                  int                         ldb,
                  const std::complex<double> *beta,
                  std::complex<double> *      Carray[],
                  int                         ldc,
                  int                         batchCount);

      // ------------------------------------------------------------------
      // Strided batched GEMM: batchCount products whose i-th operands start
      // at A + i*strideA, B + i*strideB and C + i*strideC (assumed BLAS
      // semantics). Strides are 64-bit (long long int) so large batches
      // do not overflow the element offset.
      // Overloads: double, float, std::complex<double>, std::complex<float>.
      // ------------------------------------------------------------------

      /// @brief Strided batched GEMM, double precision.
      deviceBlasStatus_t
      gemmStridedBatched(deviceBlasHandle_t    handle,
                         deviceBlasOperation_t transa,
                         deviceBlasOperation_t transb,
                         int                   m,
                         int                   n,
                         int                   k,
                         const double *        alpha,
                         const double *        A,
                         int                   lda,
                         long long int         strideA,
                         const double *        B,
                         int                   ldb,
                         long long int         strideB,
                         const double *        beta,
                         double *              C,
                         int                   ldc,
                         long long int         strideC,
                         int                   batchCount);

      /// @brief Strided batched GEMM, single precision.
      deviceBlasStatus_t
      gemmStridedBatched(deviceBlasHandle_t    handle,
                         deviceBlasOperation_t transa,
                         deviceBlasOperation_t transb,
                         int                   m,
                         int                   n,
                         int                   k,
                         const float *         alpha,
                         const float *         A,
                         int                   lda,
                         long long int         strideA,
                         const float *         B,
                         int                   ldb,
                         long long int         strideB,
                         const float *         beta,
                         float *               C,
                         int                   ldc,
                         long long int         strideC,
                         int                   batchCount);

      /// @brief Strided batched GEMM, double-precision complex.
      deviceBlasStatus_t
      gemmStridedBatched(deviceBlasHandle_t          handle,
                         deviceBlasOperation_t       transa,
                         deviceBlasOperation_t       transb,
                         int                         m,
                         int                         n,
                         int                         k,
                         const std::complex<double> *alpha,
                         const std::complex<double> *A,
                         int                         lda,
                         long long int               strideA,
                         const std::complex<double> *B,
                         int                         ldb,
                         long long int               strideB,
                         const std::complex<double> *beta,
                         std::complex<double> *      C,
                         int                         ldc,
                         long long int               strideC,
                         int                         batchCount);

      /// @brief Strided batched GEMM, single-precision complex.
      deviceBlasStatus_t
      gemmStridedBatched(deviceBlasHandle_t         handle,
                         deviceBlasOperation_t      transa,
                         deviceBlasOperation_t      transb,
                         int                        m,
                         int                        n,
                         int                        k,
                         const std::complex<float> *alpha,
                         const std::complex<float> *A,
                         int                        lda,
                         long long int              strideA,
                         const std::complex<float> *B,
                         int                        ldb,
                         long long int              strideB,
                         const std::complex<float> *beta,
                         std::complex<float> *      C,
                         int                        ldc,
                         long long int              strideC,
                         int                        batchCount);

      // ------------------------------------------------------------------
      // GEMV: y = alpha * op(A) * x + beta * y (assumed BLAS semantics),
      // where A is m x n and op(.) is selected by trans.
      // Overloads: double, float, std::complex<double>, std::complex<float>.
      // ------------------------------------------------------------------

      /// @brief General matrix-vector product, double precision.
      deviceBlasStatus_t
      gemv(deviceBlasHandle_t    handle,
           deviceBlasOperation_t trans,
           int                   m,
           int                   n,
           const double *        alpha,
           const double *        A,
           int                   lda,
           const double *        x,
           int                   incx,
           const double *        beta,
           double *              y,
           int                   incy);

      /// @brief General matrix-vector product, single precision.
      deviceBlasStatus_t
      gemv(deviceBlasHandle_t    handle,
           deviceBlasOperation_t trans,
           int                   m,
           int                   n,
           const float *         alpha,
           const float *         A,
           int                   lda,
           const float *         x,
           int                   incx,
           const float *         beta,
           float *               y,
           int                   incy);

      /// @brief General matrix-vector product, double-precision complex.
      deviceBlasStatus_t
      gemv(deviceBlasHandle_t          handle,
           deviceBlasOperation_t       trans,
           int                         m,
           int                         n,
           const std::complex<double> *alpha,
           const std::complex<double> *A,
           int                         lda,
           const std::complex<double> *x,
           int                         incx,
           const std::complex<double> *beta,
           std::complex<double> *      y,
           int                         incy);

      /// @brief General matrix-vector product, single-precision complex.
      deviceBlasStatus_t
      gemv(deviceBlasHandle_t         handle,
           deviceBlasOperation_t      trans,
           int                        m,
           int                        n,
           const std::complex<float> *alpha,
           const std::complex<float> *A,
           int                        lda,
           const std::complex<float> *x,
           int                        incx,
           const std::complex<float> *beta,
           std::complex<float> *      y,
           int                        incy);

    } // namespace deviceBlasWrapper
  }   // namespace utils
} // namespace dftfe

#  endif // dftfeDeviceBlasWrapper_H
#endif   // DFTFE_WITH_DEVICE
// Doxygen cross-reference notes from the rendered page:
//   deviceBlasStatus_t is a typedef of cublasStatus_t
//   (DeviceTypeConfig.cu.h:38).
//   Other definition links: Cell.h:36, pseudoPotentialToDftfeConverter.cc:34.