DFT-FE 1.1.0-pre
Density Functional Theory With Finite-Elements
Loading...
Searching...
No Matches
BLASWrapper.h
Go to the documentation of this file.
1// ---------------------------------------------------------------------
2//
3// Copyright (c) 2017-2025 The Regents of the University of Michigan and DFT-FE
4// authors.
5//
6// This file is part of the DFT-FE code.
7//
8// The DFT-FE code is free software; you can use it, redistribute
9// it, and/or modify it under the terms of the GNU Lesser General
10// Public License as published by the Free Software Foundation; either
11// version 2.1 of the License, or (at your option) any later version.
12// The full text of the license can be found in the file LICENSE at
13// the top level of the DFT-FE distribution.
14//
15// ---------------------------------------------------------------------
16//
17
18#ifndef BLASWrapper_h
19#define BLASWrapper_h
20
21#include <dftfeDataTypes.h>
22#include <MemorySpaceType.h>
23#include <complex>
24#include <TypeConfig.h>
25#include <DeviceTypeConfig.h>
26#include <cmath>
27#if defined(DFTFE_WITH_DEVICE)
28# include "Exceptions.h"
29#endif
30namespace dftfe
31{
32 namespace linearAlgebra
33 {
34 template <dftfe::utils::MemorySpace memorySpace>
36
37 template <>
39 {
40 public:
42
43 template <typename ValueType>
44 void
46 const ValueType *X,
47 const ValueType *Y,
48 ValueType *output) const;
49
50 template <typename ValueType>
51 void
53 const ValueType *X,
54 const ValueType *Y,
55 ValueType *output) const;
56
57 // Real-Single Precision GEMM
58 void
59 xgemm(const char transA,
60 const char transB,
61 const dftfe::uInt m,
62 const dftfe::uInt n,
63 const dftfe::uInt k,
64 const float *alpha,
65 const float *A,
66 const dftfe::uInt lda,
67 const float *B,
68 const dftfe::uInt ldb,
69 const float *beta,
70 float *C,
71 const dftfe::uInt ldc) const;
72 // Complex-Single Precision GEMM
73 void
74 xgemm(const char transA,
75 const char transB,
76 const dftfe::uInt m,
77 const dftfe::uInt n,
78 const dftfe::uInt k,
79 const std::complex<float> *alpha,
80 const std::complex<float> *A,
81 const dftfe::uInt lda,
82 const std::complex<float> *B,
83 const dftfe::uInt ldb,
84 const std::complex<float> *beta,
85 std::complex<float> *C,
86 const dftfe::uInt ldc) const;
87
88 // Real-double precision GEMM
89 void
90 xgemm(const char transA,
91 const char transB,
92 const dftfe::uInt m,
93 const dftfe::uInt n,
94 const dftfe::uInt k,
95 const double *alpha,
96 const double *A,
97 const dftfe::uInt lda,
98 const double *B,
99 const dftfe::uInt ldb,
100 const double *beta,
101 double *C,
102 const dftfe::uInt ldc) const;
103
104
105 // Complex-double precision GEMM
106 void
107 xgemm(const char transA,
108 const char transB,
109 const dftfe::uInt m,
110 const dftfe::uInt n,
111 const dftfe::uInt k,
112 const std::complex<double> *alpha,
113 const std::complex<double> *A,
114 const dftfe::uInt lda,
115 const std::complex<double> *B,
116 const dftfe::uInt ldb,
117 const std::complex<double> *beta,
118 std::complex<double> *C,
119 const dftfe::uInt ldc) const;
120
121 void
122 xgemv(const char transA,
123 const dftfe::uInt m,
124 const dftfe::uInt n,
125 const double *alpha,
126 const double *A,
127 const dftfe::uInt lda,
128 const double *x,
129 const dftfe::uInt incx,
130 const double *beta,
131 double *y,
132 const dftfe::uInt incy) const;
133
134 void
135 xgemv(const char transA,
136 const dftfe::uInt m,
137 const dftfe::uInt n,
138 const float *alpha,
139 const float *A,
140 const dftfe::uInt lda,
141 const float *x,
142 const dftfe::uInt incx,
143 const float *beta,
144 float *y,
145 const dftfe::uInt incy) const;
146
147 void
148 xgemv(const char transA,
149 const dftfe::uInt m,
150 const dftfe::uInt n,
151 const std::complex<double> *alpha,
152 const std::complex<double> *A,
153 const dftfe::uInt lda,
154 const std::complex<double> *x,
155 const dftfe::uInt incx,
156 const std::complex<double> *beta,
157 std::complex<double> *y,
158 const dftfe::uInt incy) const;
159
160 void
161 xgemv(const char transA,
162 const dftfe::uInt m,
163 const dftfe::uInt n,
164 const std::complex<float> *alpha,
165 const std::complex<float> *A,
166 const dftfe::uInt lda,
167 const std::complex<float> *x,
168 const dftfe::uInt incx,
169 const std::complex<float> *beta,
170 std::complex<float> *y,
171 const dftfe::uInt incy) const;
172
173
// Scale a vector by a scalar (named after the BLAS xSCAL family).
// NOTE(review): presumably performs x[i] = alpha * x[i] for i in [0, n);
// the definition is not in this header -- confirm against the implementation.
//   x     - vector operated on; it is the only non-const pointer argument.
//   alpha - scalar factor, passed by value; its type (ValueType2) may differ
//           from the element type of x (ValueType1).
//   n     - presumably the number of entries of x to scale -- TODO confirm.
174 template <typename ValueType1, typename ValueType2>
175 void
176 xscal(ValueType1 *x, const ValueType2 alpha, const dftfe::uInt n) const;
177
178 // Brief: accumulates the element-wise product of input1 and input2,
// summed over all blocks, into output. In pseudo-code:
179 // for ( i = 0; i < numContiguousBlocks; i ++)
180 // {
181 //   for( j = 0 ; j < contiguousBlockSize; j++)
182 //   {
183 //     output[j] += input1[i*contiguousBlockSize+j] *
184 //                  input2[i*contiguousBlockSize+j];
185 //   }
186 // }
// Per the pseudo-code above, output is indexed by j only (so it holds
// contiguousBlockSize entries) and is accumulated into with +=, not
// overwritten; input1 and input2 are each read at
// numContiguousBlocks * contiguousBlockSize entries.
187 template <typename ValueType>
188 void
189 addVecOverContinuousIndex(const dftfe::uInt numContiguousBlocks,
190 const dftfe::uInt contiguousBlockSize,
191 const ValueType *input1,
192 const ValueType *input2,
193 ValueType *output);
194
195 // Real-Float scaling of Real-vector
196
197
198 // Real double Norm2
199 void
201 const double *x,
202 const dftfe::uInt incx,
203 const MPI_Comm &mpi_communicator,
204 double *result) const;
205
206
207 // Complex double Norm2
208 void
210 const std::complex<double> *x,
211 const dftfe::uInt incx,
212 const MPI_Comm &mpi_communicator,
213 double *result) const;
214 // Real dot product
215 void
217 const double *X,
218 const dftfe::uInt INCX,
219 const double *Y,
220 const dftfe::uInt INCY,
221 double *result) const;
222 // Real dot product with all Reduce call
223 void
225 const double *X,
226 const dftfe::uInt INCX,
227 const double *Y,
228 const dftfe::uInt INCY,
229 const MPI_Comm &mpi_communicator,
230 double *result) const;
231
232 // Complex dot product
233 void
235 const std::complex<double> *X,
236 const dftfe::uInt INCX,
237 const std::complex<double> *Y,
238 const dftfe::uInt INCY,
239 std::complex<double> *result) const;
240
241 // Complex dot product with all Reduce call
242 void
244 const std::complex<double> *X,
245 const dftfe::uInt INCX,
246 const std::complex<double> *Y,
247 const dftfe::uInt INCY,
248 const MPI_Comm &mpi_communicator,
249 std::complex<double> *result) const;
250
251
252 // MultiVector Real dot product
253 template <typename ValueType>
254 void
255 MultiVectorXDot(const dftfe::uInt contiguousBlockSize,
256 const dftfe::uInt numContiguousBlocks,
257 const ValueType *X,
258 const ValueType *Y,
259 const ValueType *onesVec,
260 ValueType *tempVector,
261 ValueType *tempResults,
262 ValueType *result) const;
263
264 // MultiVector Real dot product with all Reduce call
265 template <typename ValueType>
266 void
267 MultiVectorXDot(const dftfe::uInt contiguousBlockSize,
268 const dftfe::uInt numContiguousBlocks,
269 const ValueType *X,
270 const ValueType *Y,
271 const ValueType *onesVec,
272 ValueType *tempVector,
273 ValueType *tempResults,
274 const MPI_Comm &mpi_communicator,
275 ValueType *result) const;
276
277
278 // Real double Ax+y
279 void
281 const double *alpha,
282 const double *x,
283 const dftfe::uInt incx,
284 double *y,
285 const dftfe::uInt incy) const;
286
287 // Complex double Ax+y
288 void
290 const std::complex<double> *alpha,
291 const std::complex<double> *x,
292 const dftfe::uInt incx,
293 std::complex<double> *y,
294 const dftfe::uInt incy) const;
295
296 // Real float Ax+y
297 void
299 const float *alpha,
300 const float *x,
301 const dftfe::uInt incx,
302 float *y,
303 const dftfe::uInt incy) const;
304
305 // Complex float Ax+y
306 void
308 const std::complex<float> *alpha,
309 const std::complex<float> *x,
310 const dftfe::uInt incx,
311 std::complex<float> *y,
312 const dftfe::uInt incy) const;
313
314 // Real copy of double data
315 void
317 const double *x,
318 const dftfe::uInt incx,
319 double *y,
320 const dftfe::uInt incy) const;
321
322 // Complex double copy of data
323 void
325 const std::complex<double> *x,
326 const dftfe::uInt incx,
327 std::complex<double> *y,
328 const dftfe::uInt incy) const;
329
330 // Real copy of float data
331 void
333 const float *x,
334 const dftfe::uInt incx,
335 float *y,
336 const dftfe::uInt incy) const;
337
338 // Complex float copy of data
339 void
341 const std::complex<float> *x,
342 const dftfe::uInt incx,
343 std::complex<float> *y,
344 const dftfe::uInt incy) const;
345
346 // Real double symmetric matrix-vector product
347 void
348 xsymv(const char UPLO,
349 const dftfe::uInt N,
350 const double *alpha,
351 const double *A,
352 const dftfe::uInt LDA,
353 const double *X,
354 const dftfe::uInt INCX,
355 const double *beta,
356 double *C,
357 const dftfe::uInt INCY) const;
358
359 void
360 xgemmBatched(const char transA,
361 const char transB,
362 const dftfe::uInt m,
363 const dftfe::uInt n,
364 const dftfe::uInt k,
365 const double *alpha,
366 const double *A[],
367 const dftfe::uInt lda,
368 const double *B[],
369 const dftfe::uInt ldb,
370 const double *beta,
371 double *C[],
372 const dftfe::uInt ldc,
373 const dftfe::Int batchCount) const;
374
375 void
376 xgemmBatched(const char transA,
377 const char transB,
378 const dftfe::uInt m,
379 const dftfe::uInt n,
380 const dftfe::uInt k,
381 const std::complex<double> *alpha,
382 const std::complex<double> *A[],
383 const dftfe::uInt lda,
384 const std::complex<double> *B[],
385 const dftfe::uInt ldb,
386 const std::complex<double> *beta,
387 std::complex<double> *C[],
388 const dftfe::uInt ldc,
389 const dftfe::Int batchCount) const;
390
391
392 void
393 xgemmBatched(const char transA,
394 const char transB,
395 const dftfe::uInt m,
396 const dftfe::uInt n,
397 const dftfe::uInt k,
398 const float *alpha,
399 const float *A[],
400 const dftfe::uInt lda,
401 const float *B[],
402 const dftfe::uInt ldb,
403 const float *beta,
404 float *C[],
405 const dftfe::uInt ldc,
406 const dftfe::Int batchCount) const;
407
408 void
409 xgemmBatched(const char transA,
410 const char transB,
411 const dftfe::uInt m,
412 const dftfe::uInt n,
413 const dftfe::uInt k,
414 const std::complex<float> *alpha,
415 const std::complex<float> *A[],
416 const dftfe::uInt lda,
417 const std::complex<float> *B[],
418 const dftfe::uInt ldb,
419 const std::complex<float> *beta,
420 std::complex<float> *C[],
421 const dftfe::uInt ldc,
422 const dftfe::Int batchCount) const;
423
424
425 void
426 xgemmStridedBatched(const char transA,
427 const char transB,
428 const dftfe::uInt m,
429 const dftfe::uInt n,
430 const dftfe::uInt k,
431 const double *alpha,
432 const double *A,
433 const dftfe::uInt lda,
434 long long int strideA,
435 const double *B,
436 const dftfe::uInt ldb,
437 long long int strideB,
438 const double *beta,
439 double *C,
440 const dftfe::uInt ldc,
441 long long int strideC,
442 const dftfe::Int batchCount) const;
443
444 void
445 xgemmStridedBatched(const char transA,
446 const char transB,
447 const dftfe::uInt m,
448 const dftfe::uInt n,
449 const dftfe::uInt k,
450 const std::complex<double> *alpha,
451 const std::complex<double> *A,
452 const dftfe::uInt lda,
453 long long int strideA,
454 const std::complex<double> *B,
455 const dftfe::uInt ldb,
456 long long int strideB,
457 const std::complex<double> *beta,
458 std::complex<double> *C,
459 const dftfe::uInt ldc,
460 long long int strideC,
461 const dftfe::Int batchCount) const;
462
463 void
464 xgemmStridedBatched(const char transA,
465 const char transB,
466 const dftfe::uInt m,
467 const dftfe::uInt n,
468 const dftfe::uInt k,
469 const std::complex<float> *alpha,
470 const std::complex<float> *A,
471 const dftfe::uInt lda,
472 long long int strideA,
473 const std::complex<float> *B,
474 const dftfe::uInt ldb,
475 long long int strideB,
476 const std::complex<float> *beta,
477 std::complex<float> *C,
478 const dftfe::uInt ldc,
479 long long int strideC,
480 const dftfe::Int batchCount) const;
481
482 void
483 xgemmStridedBatched(const char transA,
484 const char transB,
485 const dftfe::uInt m,
486 const dftfe::uInt n,
487 const dftfe::uInt k,
488 const float *alpha,
489 const float *A,
490 const dftfe::uInt lda,
491 long long int strideA,
492 const float *B,
493 const dftfe::uInt ldb,
494 long long int strideB,
495 const float *beta,
496 float *C,
497 const dftfe::uInt ldc,
498 long long int strideC,
499 const dftfe::Int batchCount) const;
500
501 template <typename ValueTypeComplex, typename ValueTypeReal>
502 void
504 const ValueTypeComplex *complexArr,
505 ValueTypeReal *realArr,
506 ValueTypeReal *imagArr);
507
508
509 template <typename ValueTypeComplex, typename ValueTypeReal>
510 void
512 const ValueTypeReal *realArr,
513 const ValueTypeReal *imagArr,
514 ValueTypeComplex *complexArr);
515
516 template <typename ValueType1, typename ValueType2>
517 void
519 const ValueType1 *valueType1Arr,
520 ValueType2 *valueType2Arr);
521
522
523 template <typename ValueType1, typename ValueType2>
524 void
526 const dftfe::uInt contiguousBlockSize,
527 const dftfe::uInt numContiguousBlocks,
528 const ValueType1 *copyFromVec,
529 ValueType2 *copyToVecBlock,
530 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
531
532
533 template <typename ValueType1, typename ValueType2>
534 void
536 const dftfe::uInt contiguousBlockSize,
537 const dftfe::uInt numContiguousBlocks,
538 const dftfe::uInt startingVecId,
539 const ValueType1 *copyFromVec,
540 ValueType2 *copyToVecBlock,
541 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
542
543 template <typename ValueType1, typename ValueType2>
544 void
546 const dftfe::uInt contiguousBlockSize,
547 const dftfe::uInt numContiguousBlocks,
548 const ValueType1 *copyFromVecBlock,
549 ValueType2 *copyToVec,
550 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
551
552 template <typename ValueType1, typename ValueType2>
553 void
555 const dftfe::uInt blockSizeFrom,
556 const dftfe::uInt numBlocks,
557 const dftfe::uInt startingId,
558 const ValueType1 *copyFromVec,
559 ValueType2 *copyToVec) const;
560
561
562 template <typename ValueType1, typename ValueType2>
563 void
565 const dftfe::uInt strideTo,
566 const dftfe::uInt strideFrom,
567 const dftfe::uInt numBlocks,
568 const dftfe::uInt startingToId,
569 const dftfe::uInt startingFromId,
570 const ValueType1 *copyFromVec,
571 ValueType2 *copyToVec);
572
573
574 template <typename ValueType1, typename ValueType2>
575 void
577 const dftfe::uInt blockSizeFrom,
578 const dftfe::uInt numBlocks,
579 const dftfe::uInt startingId,
580 const ValueType1 *copyFromVec,
581 ValueType2 *copyToVec);
582
583 template <typename ValueType1, typename ValueType2>
584 void
585 stridedBlockAxpy(const dftfe::uInt contiguousBlockSize,
586 const dftfe::uInt numContiguousBlocks,
587 const ValueType1 *addFromVec,
588 const ValueType2 *scalingVector,
589 const ValueType2 a,
590 ValueType1 *addToVec) const;
591
592
593 template <typename ValueType1, typename ValueType2>
594 void
595 stridedBlockAxpBy(const dftfe::uInt contiguousBlockSize,
596 const dftfe::uInt numContiguousBlocks,
597 const ValueType1 *addFromVec,
598 const ValueType2 *scalingVector,
599 const ValueType2 a,
600 const ValueType2 b,
601 ValueType1 *addToVec) const;
602 template <typename ValueType1, typename ValueType2>
603 void
605 const ValueType2 alpha,
606 const ValueType1 *x,
607 const ValueType2 beta,
608 ValueType1 *y) const;
609 template <typename ValueType0,
610 typename ValueType1,
611 typename ValueType2,
612 typename ValueType3,
613 typename ValueType4>
614 void
616 const dftfe::uInt n,
617 const ValueType0 alpha,
618 const ValueType1 *A,
619 const ValueType2 *B,
620 const ValueType3 *D,
621 ValueType4 *C) const;
622
623 template <typename ValueType>
624 void
626 const dftfe::uInt contiguousBlockSize,
627 const dftfe::uInt numContiguousBlocks,
628 const ValueType *addFromVec,
629 ValueType *addToVec,
630 const dftfe::uInt *addToVecStartingContiguousBlockIds) const;
631
632 template <typename ValueType1, typename ValueType2, typename ValueType3>
633 void
635 const dftfe::uInt contiguousBlockSize,
636 const dftfe::uInt numContiguousBlocks,
637 const ValueType1 a,
638 const ValueType1 *s,
639 const ValueType2 *addFromVec,
640 ValueType3 *addToVec,
641 const dftfe::uInt *addToVecStartingContiguousBlockIds) const;
642 template <typename ValueType1, typename ValueType2, typename ValueType3>
643 void
645 const dftfe::uInt contiguousBlockSize,
646 const dftfe::uInt numContiguousBlocks,
647 const ValueType1 a,
648 const ValueType2 *addFromVec,
649 ValueType3 *addToVec,
650 const dftfe::uInt *addToVecStartingContiguousBlockIds) const;
651
652 template <typename ValueType1, typename ValueType2>
653 void
654 stridedBlockScale(const dftfe::uInt contiguousBlockSize,
655 const dftfe::uInt numContiguousBlocks,
656 const ValueType1 a,
657 const ValueType1 *s,
658 ValueType2 *x);
659
660 template <typename ValueType1, typename ValueType2>
661 void
663 const dftfe::uInt contiguousBlockSize,
664 const dftfe::uInt numContiguousBlocks,
665 const ValueType1 a,
666 const ValueType1 *s,
667 const ValueType2 *copyFromVec,
668 ValueType2 *copyToVecBlock,
669 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
670
671 template <typename ValueType>
672 void
673 stridedBlockScaleColumnWise(const dftfe::uInt contiguousBlockSize,
674 const dftfe::uInt numContiguousBlocks,
675 const ValueType *beta,
676 ValueType *x);
677
678 template <typename ValueType>
679 void
681 const dftfe::uInt numContiguousBlocks,
682 const ValueType *x,
683 const ValueType *beta,
684 ValueType *y);
685
686 template <typename ValueType>
687 void
689 const dftfe::uInt contiguousBlockSize,
690 const dftfe::uInt numContiguousBlocks,
691 const ValueType *x,
692 const ValueType *alpha,
693 const ValueType *y,
694 const ValueType *beta,
695 ValueType *z);
696
697 template <typename ValueType1, typename ValueType2>
698 void
699 rightDiagonalScale(const dftfe::uInt numberofVectors,
700 const dftfe::uInt sizeOfVector,
701 ValueType1 *X,
702 ValueType2 *D);
703
704 private:
705 };
706#if defined(DFTFE_WITH_DEVICE)
707# include "Exceptions.h"
// Scoped enumeration of data-type tags: fp32, tf32, bf16 and fp16.
// Declared inside the `#if defined(DFTFE_WITH_DEVICE)` region.
// NOTE(review): judging by the names, these presumably select the
// floating-point format used for device tensor operations; no use of this
// enum is visible in this part of the file -- confirm against the callers.
708 enum class tensorOpDataType
709 {
710 fp32,
711 tf32,
712 bf16,
713 fp16
714 };
715
716 template <>
717 class BLASWrapper<dftfe::utils::MemorySpace::DEVICE>
718 {
719 public:
720 BLASWrapper();
721
// Static member: callable without a BLASWrapper instance.
// NOTE(review): presumably copies `size` entries from valueType1Arr to
// valueType2Arr, converting each element from ValueType1 to ValueType2; the
// definition is not in this header -- confirm against the implementation.
//   size          - presumably the number of entries to copy -- TODO confirm.
//   valueType1Arr - source array (pointer to const).
//   valueType2Arr - destination array (the only non-const pointer argument).
//   streamId      - device stream handle; defaults to 0 when omitted.
722 template <typename ValueType1, typename ValueType2>
723 static void
724 copyValueType1ArrToValueType2ArrDeviceCall(
725 const dftfe::uInt size,
726 const ValueType1 *valueType1Arr,
727 ValueType2 *valueType2Arr,
728 const dftfe::utils::deviceStream_t streamId = 0);
729
730 template <typename ValueType>
731 void
732 hadamardProduct(const dftfe::uInt m,
733 const ValueType *X,
734 const ValueType *Y,
735 ValueType *output) const;
736
737 template <typename ValueType>
738 void
739 hadamardProductWithConj(const dftfe::uInt m,
740 const ValueType *X,
741 const ValueType *Y,
742 ValueType *output) const;
743
744 // Real-Single Precision GEMM
745 void
746 xgemm(const char transA,
747 const char transB,
748 const dftfe::uInt m,
749 const dftfe::uInt n,
750 const dftfe::uInt k,
751 const float *alpha,
752 const float *A,
753 const dftfe::uInt lda,
754 const float *B,
755 const dftfe::uInt ldb,
756 const float *beta,
757 float *C,
758 const dftfe::uInt ldc) const;
759 // Complex-Single Precision GEMM
760 void
761 xgemm(const char transA,
762 const char transB,
763 const dftfe::uInt m,
764 const dftfe::uInt n,
765 const dftfe::uInt k,
766 const std::complex<float> *alpha,
767 const std::complex<float> *A,
768 const dftfe::uInt lda,
769 const std::complex<float> *B,
770 const dftfe::uInt ldb,
771 const std::complex<float> *beta,
772 std::complex<float> *C,
773 const dftfe::uInt ldc) const;
774
775 // Real-double precision GEMM
776 void
777 xgemm(const char transA,
778 const char transB,
779 const dftfe::uInt m,
780 const dftfe::uInt n,
781 const dftfe::uInt k,
782 const double *alpha,
783 const double *A,
784 const dftfe::uInt lda,
785 const double *B,
786 const dftfe::uInt ldb,
787 const double *beta,
788 double *C,
789 const dftfe::uInt ldc) const;
790
791
792 // Complex-double precision GEMM
793 void
794 xgemm(const char transA,
795 const char transB,
796 const dftfe::uInt m,
797 const dftfe::uInt n,
798 const dftfe::uInt k,
799 const std::complex<double> *alpha,
800 const std::complex<double> *A,
801 const dftfe::uInt lda,
802 const std::complex<double> *B,
803 const dftfe::uInt ldb,
804 const std::complex<double> *beta,
805 std::complex<double> *C,
806 const dftfe::uInt ldc) const;
807
808
809 void
810 xgemv(const char transA,
811 const dftfe::uInt m,
812 const dftfe::uInt n,
813 const double *alpha,
814 const double *A,
815 const dftfe::uInt lda,
816 const double *x,
817 const dftfe::uInt incx,
818 const double *beta,
819 double *y,
820 const dftfe::uInt incy) const;
821
822 void
823 xgemv(const char transA,
824 const dftfe::uInt m,
825 const dftfe::uInt n,
826 const float *alpha,
827 const float *A,
828 const dftfe::uInt lda,
829 const float *x,
830 const dftfe::uInt incx,
831 const float *beta,
832 float *y,
833 const dftfe::uInt incy) const;
834
835 void
836 xgemv(const char transA,
837 const dftfe::uInt m,
838 const dftfe::uInt n,
839 const std::complex<double> *alpha,
840 const std::complex<double> *A,
841 const dftfe::uInt lda,
842 const std::complex<double> *x,
843 const dftfe::uInt incx,
844 const std::complex<double> *beta,
845 std::complex<double> *y,
846 const dftfe::uInt incy) const;
847
848 void
849 xgemv(const char transA,
850 const dftfe::uInt m,
851 const dftfe::uInt n,
852 const std::complex<float> *alpha,
853 const std::complex<float> *A,
854 const dftfe::uInt lda,
855 const std::complex<float> *x,
856 const dftfe::uInt incx,
857 const std::complex<float> *beta,
858 std::complex<float> *y,
859 const dftfe::uInt incy) const;
860
861 template <typename ValueType>
862 void
863 addVecOverContinuousIndex(const dftfe::uInt numContiguousBlocks,
864 const dftfe::uInt contiguousBlockSize,
865 const ValueType *input1,
866 const ValueType *input2,
867 ValueType *output);
868
869
870
871 template <typename ValueType1, typename ValueType2>
872 void
873 xscal(ValueType1 *x, const ValueType2 alpha, const dftfe::uInt n) const;
874
875
876
877 // Real double Norm2
878 void
879 xnrm2(const dftfe::uInt n,
880 const double *x,
881 const dftfe::uInt incx,
882 const MPI_Comm &mpi_communicator,
883 double *result) const;
884
885
886 // Complex double Norm2
887 void
888 xnrm2(const dftfe::uInt n,
889 const std::complex<double> *x,
890 const dftfe::uInt incx,
891 const MPI_Comm &mpi_communicator,
892 double *result) const;
893
894 // Real dot product
895 void
896 xdot(const dftfe::uInt N,
897 const double *X,
898 const dftfe::uInt INCX,
899 const double *Y,
900 const dftfe::uInt INCY,
901 double *result) const;
902
903 //
904 // Real dot product with all Reduce call
905 void
906 xdot(const dftfe::uInt N,
907 const double *X,
908 const dftfe::uInt INCX,
909 const double *Y,
910 const dftfe::uInt INCY,
911 const MPI_Comm &mpi_communicator,
912 double *result) const;
913
914 // Complex dot product
915 void
916 xdot(const dftfe::uInt N,
917 const std::complex<double> *X,
918 const dftfe::uInt INCX,
919 const std::complex<double> *Y,
920 const dftfe::uInt INCY,
921 std::complex<double> *result) const;
922
923 // Complex dot product with all Reduce call
924 void
925 xdot(const dftfe::uInt N,
926 const std::complex<double> *X,
927 const dftfe::uInt INCX,
928 const std::complex<double> *Y,
929 const dftfe::uInt INCY,
930 const MPI_Comm &mpi_communicator,
931 std::complex<double> *result) const;
932
933
934 template <typename ValueType>
935 void
936 MultiVectorXDot(const dftfe::uInt contiguousBlockSize,
937 const dftfe::uInt numContiguousBlocks,
938 const ValueType *X,
939 const ValueType *Y,
940 const ValueType *onesVec,
941 ValueType *tempVector,
942 ValueType *tempResults,
943 ValueType *result) const;
944
945 template <typename ValueType>
946 void
947 MultiVectorXDot(const dftfe::uInt contiguousBlockSize,
948 const dftfe::uInt numContiguousBlocks,
949 const ValueType *X,
950 const ValueType *Y,
951 const ValueType *onesVec,
952 ValueType *tempVector,
953 ValueType *tempResults,
954 const MPI_Comm &mpi_communicator,
955 ValueType *result) const;
956
957 // Real double Ax+y
958 void
959 xaxpy(const dftfe::uInt n,
960 const double *alpha,
961 const double *x,
962 const dftfe::uInt incx,
963 double *y,
964 const dftfe::uInt incy) const;
965
966 // Complex double Ax+y
967 void
968 xaxpy(const dftfe::uInt n,
969 const std::complex<double> *alpha,
970 const std::complex<double> *x,
971 const dftfe::uInt incx,
972 std::complex<double> *y,
973 const dftfe::uInt incy) const;
974
975 // Real copy of double data
976 void
977 xcopy(const dftfe::uInt n,
978 const double *x,
979 const dftfe::uInt incx,
980 double *y,
981 const dftfe::uInt incy) const;
982
983 // Complex double copy of data
984 void
985 xcopy(const dftfe::uInt n,
986 const std::complex<double> *x,
987 const dftfe::uInt incx,
988 std::complex<double> *y,
989 const dftfe::uInt incy) const;
990
991 // Real copy of float data
992 void
993 xcopy(const dftfe::uInt n,
994 const float *x,
995 const dftfe::uInt incx,
996 float *y,
997 const dftfe::uInt incy) const;
998
999 // Complex float copy of data
1000 void
1001 xcopy(const dftfe::uInt n,
1002 const std::complex<float> *x,
1003 const dftfe::uInt incx,
1004 std::complex<float> *y,
1005 const dftfe::uInt incy) const;
1006
1007 // Real double symmetric matrix-vector product
1008 void
1009 xsymv(const char UPLO,
1010 const dftfe::uInt N,
1011 const double *alpha,
1012 const double *A,
1013 const dftfe::uInt LDA,
1014 const double *X,
1015 const dftfe::uInt INCX,
1016 const double *beta,
1017 double *C,
1018 const dftfe::uInt INCY) const;
1019
1020 void
1021 xgemmBatched(const char transA,
1022 const char transB,
1023 const dftfe::uInt m,
1024 const dftfe::uInt n,
1025 const dftfe::uInt k,
1026 const double *alpha,
1027 const double *A[],
1028 const dftfe::uInt lda,
1029 const double *B[],
1030 const dftfe::uInt ldb,
1031 const double *beta,
1032 double *C[],
1033 const dftfe::uInt ldc,
1034 const dftfe::Int batchCount) const;
1035
1036 void
1037 xgemmBatched(const char transA,
1038 const char transB,
1039 const dftfe::uInt m,
1040 const dftfe::uInt n,
1041 const dftfe::uInt k,
1042 const std::complex<double> *alpha,
1043 const std::complex<double> *A[],
1044 const dftfe::uInt lda,
1045 const std::complex<double> *B[],
1046 const dftfe::uInt ldb,
1047 const std::complex<double> *beta,
1048 std::complex<double> *C[],
1049 const dftfe::uInt ldc,
1050 const dftfe::Int batchCount) const;
1051
1052 void
1053 xgemmBatched(const char transA,
1054 const char transB,
1055 const dftfe::uInt m,
1056 const dftfe::uInt n,
1057 const dftfe::uInt k,
1058 const float *alpha,
1059 const float *A[],
1060 const dftfe::uInt lda,
1061 const float *B[],
1062 const dftfe::uInt ldb,
1063 const float *beta,
1064 float *C[],
1065 const dftfe::uInt ldc,
1066 const dftfe::Int batchCount) const;
1067
1068 void
1069 xgemmBatched(const char transA,
1070 const char transB,
1071 const dftfe::uInt m,
1072 const dftfe::uInt n,
1073 const dftfe::uInt k,
1074 const std::complex<float> *alpha,
1075 const std::complex<float> *A[],
1076 const dftfe::uInt lda,
1077 const std::complex<float> *B[],
1078 const dftfe::uInt ldb,
1079 const std::complex<float> *beta,
1080 std::complex<float> *C[],
1081 const dftfe::uInt ldc,
1082 const dftfe::Int batchCount) const;
1083
1084 void
1085 xgemmStridedBatched(const char transA,
1086 const char transB,
1087 const dftfe::uInt m,
1088 const dftfe::uInt n,
1089 const dftfe::uInt k,
1090 const double *alpha,
1091 const double *A,
1092 const dftfe::uInt lda,
1093 long long int strideA,
1094 const double *B,
1095 const dftfe::uInt ldb,
1096 long long int strideB,
1097 const double *beta,
1098 double *C,
1099 const dftfe::uInt ldc,
1100 long long int strideC,
1101 const dftfe::Int batchCount) const;
1102
1103 void
1104 xgemmStridedBatched(const char transA,
1105 const char transB,
1106 const dftfe::uInt m,
1107 const dftfe::uInt n,
1108 const dftfe::uInt k,
1109 const std::complex<double> *alpha,
1110 const std::complex<double> *A,
1111 const dftfe::uInt lda,
1112 long long int strideA,
1113 const std::complex<double> *B,
1114 const dftfe::uInt ldb,
1115 long long int strideB,
1116 const std::complex<double> *beta,
1117 std::complex<double> *C,
1118 const dftfe::uInt ldc,
1119 long long int strideC,
1120 const dftfe::Int batchCount) const;
1121
1122 void
1123 xgemmStridedBatched(const char transA,
1124 const char transB,
1125 const dftfe::uInt m,
1126 const dftfe::uInt n,
1127 const dftfe::uInt k,
1128 const std::complex<float> *alpha,
1129 const std::complex<float> *A,
1130 const dftfe::uInt lda,
1131 long long int strideA,
1132 const std::complex<float> *B,
1133 const dftfe::uInt ldb,
1134 long long int strideB,
1135 const std::complex<float> *beta,
1136 std::complex<float> *C,
1137 const dftfe::uInt ldc,
1138 long long int strideC,
1139 const dftfe::Int batchCount) const;
1140
1141 void
1142 xgemmStridedBatched(const char transA,
1143 const char transB,
1144 const dftfe::uInt m,
1145 const dftfe::uInt n,
1146 const dftfe::uInt k,
1147 const float *alpha,
1148 const float *A,
1149 const dftfe::uInt lda,
1150 long long int strideA,
1151 const float *B,
1152 const dftfe::uInt ldb,
1153 long long int strideB,
1154 const float *beta,
1155 float *C,
1156 const dftfe::uInt ldc,
1157 long long int strideC,
1158 const dftfe::Int batchCount) const;
1159
1160 template <typename ValueTypeComplex, typename ValueTypeReal>
1161 void
1162 copyComplexArrToRealArrs(const dftfe::uInt size,
1163 const ValueTypeComplex *complexArr,
1164 ValueTypeReal *realArr,
1165 ValueTypeReal *imagArr);
1166
1167
1168 template <typename ValueTypeComplex, typename ValueTypeReal>
1169 void
1170 copyRealArrsToComplexArr(const dftfe::uInt size,
1171 const ValueTypeReal *realArr,
1172 const ValueTypeReal *imagArr,
1173 ValueTypeComplex *complexArr);
1174
1175 template <typename ValueType1, typename ValueType2>
1176 void
1177 copyValueType1ArrToValueType2Arr(const dftfe::uInt size,
1178 const ValueType1 *valueType1Arr,
1179 ValueType2 *valueType2Arr);
1180
1181
1182 template <typename ValueType1, typename ValueType2>
1183 void
1184 stridedCopyToBlock(
1185 const dftfe::uInt contiguousBlockSize,
1186 const dftfe::uInt numContiguousBlocks,
1187 const ValueType1 *copyFromVec,
1188 ValueType2 *copyToVecBlock,
1189 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
1190
1191 template <typename ValueType1, typename ValueType2>
1192 void
1193 stridedCopyToBlock(
1194 const dftfe::uInt contiguousBlockSize,
1195 const dftfe::uInt numContiguousBlocks,
1196 const dftfe::uInt startingVecId,
1197 const ValueType1 *copyFromVec,
1198 ValueType2 *copyToVecBlock,
1199 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
1200
1201
1202 template <typename ValueType1, typename ValueType2>
1203 void
1204 stridedCopyFromBlock(
1205 const dftfe::uInt contiguousBlockSize,
1206 const dftfe::uInt numContiguousBlocks,
1207 const ValueType1 *copyFromVecBlock,
1208 ValueType2 *copyToVec,
1209 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
1210
1211 template <typename ValueType1, typename ValueType2>
1212 void
1213 stridedCopyToBlockConstantStride(const dftfe::uInt blockSizeTo,
1214 const dftfe::uInt blockSizeFrom,
1215 const dftfe::uInt numBlocks,
1216 const dftfe::uInt startingId,
1217 const ValueType1 *copyFromVec,
1218 ValueType2 *copyToVec) const;
1219
1220
1221 template <typename ValueType1, typename ValueType2>
1222 void
1223 stridedCopyConstantStride(const dftfe::uInt blockSize,
1224 const dftfe::uInt strideTo,
1225 const dftfe::uInt strideFrom,
1226 const dftfe::uInt numBlocks,
1227 const dftfe::uInt startingToId,
1228 const dftfe::uInt startingFromId,
1229 const ValueType1 *copyFromVec,
1230 ValueType2 *copyToVec);
1231
1232
// NOTE(review): presumably the counterpart of stridedCopyToBlockConstantStride
// with the roles of the packed and the strided side swapped -- confirm against
// the definition.
/// Constant-stride copy from copyFromVec into copyToVec. Unlike
/// stridedCopyToBlockConstantStride, this member is not declared const.
/// @param blockSizeTo    block size on the destination side
/// @param blockSizeFrom  block size on the source side
/// @param numBlocks      number of blocks
/// @param startingId     starting id
/// @param copyFromVec    read-only source array
/// @param copyToVec      destination array (written)
1233 template <typename ValueType1, typename ValueType2>
1234 void
1235 stridedCopyFromBlockConstantStride(const dftfe::uInt blockSizeTo,
1236 const dftfe::uInt blockSizeFrom,
1237 const dftfe::uInt numBlocks,
1238 const dftfe::uInt startingId,
1239 const ValueType1 *copyFromVec,
1240 ValueType2 *copyToVec);
// NOTE(review): by name presumably y = alpha * x + beta * y over n entries --
// confirm against the definition.
/// Scaled vector update with two scalars. The vectors use ValueType1 and the
/// scalars use ValueType2. Declared const.
/// @param n      number of entries
/// @param alpha  scalar passed by value
/// @param x      read-only input array
/// @param beta   scalar passed by value
/// @param y      array that is written
1241 template <typename ValueType1, typename ValueType2>
1242 void
1243 axpby(const dftfe::uInt n,
1244 const ValueType2 alpha,
1245 const ValueType1 *x,
1246 const ValueType2 beta,
1247 ValueType1 *y) const;
1248
// NOTE(review): presumably addToVec += a * scalingVector[block] * addFromVec,
// applied block-wise; how scalingVector is indexed is not visible here --
// confirm against the definition.
/// Block-wise scaled accumulation of addFromVec into addToVec. Declared const.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param addFromVec           read-only input array
/// @param scalingVector        read-only array of scaling factors
/// @param a                    scalar passed by value
/// @param addToVec             array that is written
1249 template <typename ValueType1, typename ValueType2>
1250 void
1251 stridedBlockAxpy(const dftfe::uInt contiguousBlockSize,
1252 const dftfe::uInt numContiguousBlocks,
1253 const ValueType1 *addFromVec,
1254 const ValueType2 *scalingVector,
1255 const ValueType2 a,
1256 ValueType1 *addToVec) const;
// NOTE(review): presumably addToVec = a * scalingVector[...] * addFromVec
// + b * addToVec, applied block-wise; the exact formula is not visible here --
// confirm against the definition.
/// Variant of stridedBlockAxpy that takes a second scalar, b. Declared const.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param addFromVec           read-only input array
/// @param scalingVector        read-only array of scaling factors
/// @param a                    scalar passed by value
/// @param b                    scalar passed by value
/// @param addToVec             array that is written
1257 template <typename ValueType1, typename ValueType2>
1258 void
1259 stridedBlockAxpBy(const dftfe::uInt contiguousBlockSize,
1260 const dftfe::uInt numContiguousBlocks,
1261 const ValueType1 *addFromVec,
1262 const ValueType2 *scalingVector,
1263 const ValueType2 a,
1264 const ValueType2 b,
1265 ValueType1 *addToVec) const;
1266
// NOTE(review): the name suggests C = A + alpha * B * D with D acting as a
// diagonal scaling on an m-by-n layout, but neither the formula nor the
// layout is visible from the declaration -- confirm against the definition.
/// Combine the arrays A, B and D with the scalar alpha and write the result to
/// C. Each operand has its own template value type. Declared const.
/// @param m      first size argument
/// @param n      second size argument
/// @param alpha  scalar passed by value
/// @param A      read-only input array
/// @param B      read-only input array
/// @param D      read-only input array
/// @param C      output array (written)
1267 template <typename ValueType0,
1268 typename ValueType1,
1269 typename ValueType2,
1270 typename ValueType3,
1271 typename ValueType4>
1272 void
1273 ApaBD(const dftfe::uInt m,
1274 const dftfe::uInt n,
1275 const ValueType0 alpha,
1276 const ValueType1 *A,
1277 const ValueType2 *B,
1278 const ValueType3 *D,
1279 ValueType4 *C) const;
1280
1281
// NOTE(review): by name presumably accumulates the blocks of addFromVec into
// addToVec at the positions given by addToVecStartingContiguousBlockIds using
// atomic additions; the atomicity is implied by the name only -- confirm
// against the definition.
/// Single-value-type overload: add addFromVec into addToVec, with the
/// destination addressed through an array of starting block ids. Declared
/// const.
/// @param contiguousBlockSize                 size of one contiguous block
/// @param numContiguousBlocks                 number of contiguous blocks
/// @param addFromVec                          read-only input array
/// @param addToVec                            array that is written
/// @param addToVecStartingContiguousBlockIds  read-only array of starting
///                                            block ids
1282 template <typename ValueType>
1283 void
1284 axpyStridedBlockAtomicAdd(
1285 const dftfe::uInt contiguousBlockSize,
1286 const dftfe::uInt numContiguousBlocks,
1287 const ValueType *addFromVec,
1288 ValueType *addToVec,
1289 const dftfe::uInt *addToVecStartingContiguousBlockIds) const;
1290
// NOTE(review): presumably scales addFromVec by a and by the entries of s
// before accumulating into addToVec; how s is indexed is not visible here --
// confirm against the definition.
/// Three-value-type overload that takes a scalar a and a scaling array s in
/// addition to the source, destination and starting-block-id arrays. Declared
/// const.
/// @param contiguousBlockSize                 size of one contiguous block
/// @param numContiguousBlocks                 number of contiguous blocks
/// @param a                                   scalar passed by value
/// @param s                                   read-only array of ValueType1
/// @param addFromVec                          read-only input array
/// @param addToVec                            array that is written
/// @param addToVecStartingContiguousBlockIds  read-only array of starting
///                                            block ids
1291 template <typename ValueType1, typename ValueType2, typename ValueType3>
1292 void
1293 axpyStridedBlockAtomicAdd(
1294 const dftfe::uInt contiguousBlockSize,
1295 const dftfe::uInt numContiguousBlocks,
1296 const ValueType1 a,
1297 const ValueType1 *s,
1298 const ValueType2 *addFromVec,
1299 ValueType3 *addToVec,
1300 const dftfe::uInt *addToVecStartingContiguousBlockIds) const;
// NOTE(review): presumably scales addFromVec by a before accumulating into
// addToVec -- confirm against the definition.
/// Three-value-type overload that takes only the scalar a (no scaling array)
/// in addition to the source, destination and starting-block-id arrays.
/// Declared const.
/// @param contiguousBlockSize                 size of one contiguous block
/// @param numContiguousBlocks                 number of contiguous blocks
/// @param a                                   scalar passed by value
/// @param addFromVec                          read-only input array
/// @param addToVec                            array that is written
/// @param addToVecStartingContiguousBlockIds  read-only array of starting
///                                            block ids
1301 template <typename ValueType1, typename ValueType2, typename ValueType3>
1302 void
1303 axpyStridedBlockAtomicAdd(
1304 const dftfe::uInt contiguousBlockSize,
1305 const dftfe::uInt numContiguousBlocks,
1306 const ValueType1 a,
1307 const ValueType2 *addFromVec,
1308 ValueType3 *addToVec,
1309 const dftfe::uInt *addToVecStartingContiguousBlockIds) const;
1310
// NOTE(review): presumably scales x in place by a and by the entries of s;
// whether s is indexed per block or per entry within a block is not visible
// here -- confirm against the definition.
/// In-place scaling of the array x using a scalar a and a scaling array s.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param a                    scalar passed by value
/// @param s                    read-only array of ValueType1
/// @param x                    array that is written
1311 template <typename ValueType1, typename ValueType2>
1312 void
1313 stridedBlockScale(const dftfe::uInt contiguousBlockSize,
1314 const dftfe::uInt numContiguousBlocks,
1315 const ValueType1 a,
1316 const ValueType1 *s,
1317 ValueType2 *x);
// NOTE(review): by name presumably combines stridedCopyToBlock with a scaling
// by a and by the entries of s -- confirm against the definition.
/// Copy from copyFromVec into copyToVecBlock with scaling. Source and
/// destination share ValueType2; the scalar and the scaling array use
/// ValueType1.
/// @param contiguousBlockSize                    size of one contiguous block
/// @param numContiguousBlocks                    number of contiguous blocks
/// @param a                                      scalar passed by value
/// @param s                                      read-only array of ValueType1
/// @param copyFromVec                            read-only source array
/// @param copyToVecBlock                         destination array (written)
/// @param copyFromVecStartingContiguousBlockIds  read-only array of starting
///                                               block ids
1318 template <typename ValueType1, typename ValueType2>
1319 void
1320 stridedBlockScaleCopy(
1321 const dftfe::uInt contiguousBlockSize,
1322 const dftfe::uInt numContiguousBlocks,
1323 const ValueType1 a,
1324 const ValueType1 *s,
1325 const ValueType2 *copyFromVec,
1326 ValueType2 *copyToVecBlock,
1327 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
1328
// NOTE(review): presumably multiplies each column of x by the matching entry
// of beta; which index counts as the "column" is not visible here -- confirm
// against the definition.
/// In-place column-wise scaling of x by the entries of beta.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param beta                 read-only array of scaling factors
/// @param x                    array that is written
1329 template <typename ValueType>
1330 void
1331 stridedBlockScaleColumnWise(const dftfe::uInt contiguousBlockSize,
1332 const dftfe::uInt numContiguousBlocks,
1333 const ValueType *beta,
1334 ValueType *x);
1335
// NOTE(review): presumably y += beta[column] * x, applied column-wise --
// confirm the formula and the indexing against the definition.
/// Column-wise scale-and-add of x into y using the entries of beta.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param x                    read-only input array
/// @param beta                 read-only array of scaling factors
/// @param y                    array that is written
1336 template <typename ValueType>
1337 void
1338 stridedBlockScaleAndAddColumnWise(const dftfe::uInt contiguousBlockSize,
1339 const dftfe::uInt numContiguousBlocks,
1340 const ValueType *x,
1341 const ValueType *beta,
1342 ValueType *y);
1343
// NOTE(review): presumably z = alpha[column] * x + beta[column] * y, applied
// column-wise -- confirm the formula and the indexing against the definition.
/// Column-wise combination of two input arrays, x and y, each with its own
/// array of scaling factors (alpha and beta), written to z.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param x                    read-only input array
/// @param alpha                read-only array of scaling factors
/// @param y                    read-only input array
/// @param beta                 read-only array of scaling factors
/// @param z                    output array (written)
1344 template <typename ValueType>
1345 void
1346 stridedBlockScaleAndAddTwoVecColumnWise(
1347 const dftfe::uInt contiguousBlockSize,
1348 const dftfe::uInt numContiguousBlocks,
1349 const ValueType *x,
1350 const ValueType *alpha,
1351 const ValueType *y,
1352 const ValueType *beta,
1353 ValueType *z);
1354
// NOTE(review): by name presumably computes X := X * diag(D), i.e. scales
// each of the numberofVectors vectors stored in X by the matching entry of D.
// D is taken through a non-const pointer; whether it is actually modified
// cannot be told from the declaration -- confirm against the definition.
/// Diagonal scaling of X using D.
/// @param numberofVectors  number of vectors
/// @param sizeOfVector     size of each vector
/// @param X                array passed through a non-const pointer
/// @param D                array passed through a non-const pointer
1355 template <typename ValueType1, typename ValueType2>
1356 void
1357 rightDiagonalScale(const dftfe::uInt numberofVectors,
1358 const dftfe::uInt sizeOfVector,
1359 ValueType1 *X,
1360 ValueType2 *D);
1361
// NOTE(review): the line carrying this declaration's return type is absent
// from this listing (the embedded numbering skips 1362). Presumably the getter
// exposes the d_deviceBlasHandle member, which is declared as
// dftfe::utils::deviceBlasHandle_t -- restore the return type from the
// original header before compiling.
/// Accessor for the device BLAS handle; takes no arguments.
1363 getDeviceBlasHandle();
1364
1365
// NOTE(review): the name suggests that diagonal blocks of the source are
// copied unchanged (ValueType1) into valueType1DstArray while off-diagonal
// blocks are converted to ValueType2 into valueType2DstArray. The meaning of
// B, DRem and D (block size / remainder / dimension?) is not visible from the
// declaration -- confirm against the definition.
/// Split a ValueType1 source array into a ValueType1 destination and a
/// ValueType2 destination.
/// @param B                   size argument
/// @param DRem                size argument
/// @param D                   size argument
/// @param valueType1SrcArray  read-only source array
/// @param valueType1DstArray  destination array of ValueType1 (written)
/// @param valueType2DstArray  destination array of ValueType2 (written)
1366 template <typename ValueType1, typename ValueType2>
1367 void
1368 copyBlockDiagonalValueType1OffDiagonalValueType2FromValueType1Arr(
1369 const dftfe::uInt B,
1370 const dftfe::uInt DRem,
1371 const dftfe::uInt D,
1372 const ValueType1 *valueType1SrcArray,
1373 ValueType1 *valueType1DstArray,
1374 ValueType2 *valueType2DstArray);
1375
/// Store the requested tensor-op data type in the d_opType member.
/// @param opType  value assigned to d_opType
// NOTE(review): how d_opType is consumed is not visible in this part of the
// file; presumably it selects the compute precision used by the device BLAS
// calls -- confirm.
1376 void
1377 setTensorOpDataType(tensorOpDataType opType)
1378 {
1379 d_opType = opType;
1380 }
1381
// NOTE(review): the line carrying this declaration's return type is absent
// from this listing (the embedded numbering skips 1382); possibly
// dftfe::utils::deviceBlasStatus_t -- restore it from the original header
// before compiling. Presumably binds the device BLAS handle to the given
// stream; confirm against the definition.
/// Set the device stream.
/// @param streamId  device stream, passed by value
1383 setStream(dftfe::utils::deviceStream_t streamId);
1384
1385 private:
// Private helper that is only declared when DFTFE_WITH_DEVICE_AMD is defined.
// NOTE(review): what it initializes is not visible from the declaration --
// confirm against the definition.
1386# ifdef DFTFE_WITH_DEVICE_AMD
1387 void
1388 initialize();
1389# endif
1390
1391 /// storage for deviceblas handle
1392 dftfe::utils::deviceBlasHandle_t d_deviceBlasHandle;
// NOTE(review): listing line 1393 is absent here, so a member declared
// between the handle and d_opType may be missing -- check the original header.
/// tensor-op data type; assigned by setTensorOpDataType()
1394 tensorOpDataType d_opType;
1395
// NOTE(review): the line carrying this declaration's return type is absent
// from this listing (the embedded numbering skips 1396); possibly
// dftfe::utils::deviceBlasStatus_t -- restore it from the original header.
// Presumably creates the handle stored in d_deviceBlasHandle; confirm against
// the definition.
/// Private helper taking no arguments.
1397 create();
1398
// NOTE(review): the line carrying this declaration's return type is absent
// from this listing (the embedded numbering skips 1399); possibly
// dftfe::utils::deviceBlasStatus_t -- restore it from the original header.
// Presumably releases the handle stored in d_deviceBlasHandle; confirm against
// the definition.
/// Private helper taking no arguments.
1400 destroy();
1401 };
1402#endif
1403
1404 } // end of namespace linearAlgebra
1405
1406} // end of namespace dftfe
1407
1408
1409#endif // BLASWrapper_h
void xcopy(const dftfe::uInt n, const float *x, const dftfe::uInt incx, float *y, const dftfe::uInt incy) const
void xaxpy(const dftfe::uInt n, const std::complex< float > *alpha, const std::complex< float > *x, const dftfe::uInt incx, std::complex< float > *y, const dftfe::uInt incy) const
void stridedCopyToBlock(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const dftfe::uInt startingVecId, const ValueType1 *copyFromVec, ValueType2 *copyToVecBlock, const dftfe::uInt *copyFromVecStartingContiguousBlockIds)
void stridedBlockAxpy(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 *addFromVec, const ValueType2 *scalingVector, const ValueType2 a, ValueType1 *addToVec) const
void axpyStridedBlockAtomicAdd(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 a, const ValueType2 *addFromVec, ValueType3 *addToVec, const dftfe::uInt *addToVecStartingContiguousBlockIds) const
void hadamardProduct(const dftfe::uInt m, const ValueType *X, const ValueType *Y, ValueType *output) const
void xgemv(const char transA, const dftfe::uInt m, const dftfe::uInt n, const std::complex< float > *alpha, const std::complex< float > *A, const dftfe::uInt lda, const std::complex< float > *x, const dftfe::uInt incx, const std::complex< float > *beta, std::complex< float > *y, const dftfe::uInt incy) const
void xgemmStridedBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< double > *alpha, const std::complex< double > *A, const dftfe::uInt lda, long long int strideA, const std::complex< double > *B, const dftfe::uInt ldb, long long int strideB, const std::complex< double > *beta, std::complex< double > *C, const dftfe::uInt ldc, long long int strideC, const dftfe::Int batchCount) const
void stridedBlockScaleCopy(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 a, const ValueType1 *s, const ValueType2 *copyFromVec, ValueType2 *copyToVecBlock, const dftfe::uInt *copyFromVecStartingContiguousBlockIds)
void xnrm2(const dftfe::uInt n, const double *x, const dftfe::uInt incx, const MPI_Comm &mpi_communicator, double *result) const
void xgemmBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const double *alpha, const double *A[], const dftfe::uInt lda, const double *B[], const dftfe::uInt ldb, const double *beta, double *C[], const dftfe::uInt ldc, const dftfe::Int batchCount) const
void xscal(ValueType1 *x, const ValueType2 alpha, const dftfe::uInt n) const
void copyValueType1ArrToValueType2Arr(const dftfe::uInt size, const ValueType1 *valueType1Arr, ValueType2 *valueType2Arr)
void xgemm(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const double *alpha, const double *A, const dftfe::uInt lda, const double *B, const dftfe::uInt ldb, const double *beta, double *C, const dftfe::uInt ldc) const
void xdot(const dftfe::uInt N, const double *X, const dftfe::uInt INCX, const double *Y, const dftfe::uInt INCY, const MPI_Comm &mpi_communicator, double *result) const
void hadamardProductWithConj(const dftfe::uInt m, const ValueType *X, const ValueType *Y, ValueType *output) const
void xgemmBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< float > *alpha, const std::complex< float > *A[], const dftfe::uInt lda, const std::complex< float > *B[], const dftfe::uInt ldb, const std::complex< float > *beta, std::complex< float > *C[], const dftfe::uInt ldc, const dftfe::Int batchCount) const
void xgemv(const char transA, const dftfe::uInt m, const dftfe::uInt n, const float *alpha, const float *A, const dftfe::uInt lda, const float *x, const dftfe::uInt incx, const float *beta, float *y, const dftfe::uInt incy) const
void MultiVectorXDot(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *X, const ValueType *Y, const ValueType *onesVec, ValueType *tempVector, ValueType *tempResults, ValueType *result) const
void MultiVectorXDot(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *X, const ValueType *Y, const ValueType *onesVec, ValueType *tempVector, ValueType *tempResults, const MPI_Comm &mpi_communicator, ValueType *result) const
void xsymv(const char UPLO, const dftfe::uInt N, const double *alpha, const double *A, const dftfe::uInt LDA, const double *X, const dftfe::uInt INCX, const double *beta, double *C, const dftfe::uInt INCY) const
void copyComplexArrToRealArrs(const dftfe::uInt size, const ValueTypeComplex *complexArr, ValueTypeReal *realArr, ValueTypeReal *imagArr)
void stridedBlockScale(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 a, const ValueType1 *s, ValueType2 *x)
void xdot(const dftfe::uInt N, const std::complex< double > *X, const dftfe::uInt INCX, const std::complex< double > *Y, const dftfe::uInt INCY, const MPI_Comm &mpi_communicator, std::complex< double > *result) const
void xgemmStridedBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< float > *alpha, const std::complex< float > *A, const dftfe::uInt lda, long long int strideA, const std::complex< float > *B, const dftfe::uInt ldb, long long int strideB, const std::complex< float > *beta, std::complex< float > *C, const dftfe::uInt ldc, long long int strideC, const dftfe::Int batchCount) const
void xgemm(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< double > *alpha, const std::complex< double > *A, const dftfe::uInt lda, const std::complex< double > *B, const dftfe::uInt ldb, const std::complex< double > *beta, std::complex< double > *C, const dftfe::uInt ldc) const
void xcopy(const dftfe::uInt n, const double *x, const dftfe::uInt incx, double *y, const dftfe::uInt incy) const
void xdot(const dftfe::uInt N, const std::complex< double > *X, const dftfe::uInt INCX, const std::complex< double > *Y, const dftfe::uInt INCY, std::complex< double > *result) const
void xgemmStridedBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const float *alpha, const float *A, const dftfe::uInt lda, long long int strideA, const float *B, const dftfe::uInt ldb, long long int strideB, const float *beta, float *C, const dftfe::uInt ldc, long long int strideC, const dftfe::Int batchCount) const
void axpby(const dftfe::uInt n, const ValueType2 alpha, const ValueType1 *x, const ValueType2 beta, ValueType1 *y) const
void axpyStridedBlockAtomicAdd(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 a, const ValueType1 *s, const ValueType2 *addFromVec, ValueType3 *addToVec, const dftfe::uInt *addToVecStartingContiguousBlockIds) const
void xaxpy(const dftfe::uInt n, const float *alpha, const float *x, const dftfe::uInt incx, float *y, const dftfe::uInt incy) const
void xgemm(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const float *alpha, const float *A, const dftfe::uInt lda, const float *B, const dftfe::uInt ldb, const float *beta, float *C, const dftfe::uInt ldc) const
void stridedCopyToBlock(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 *copyFromVec, ValueType2 *copyToVecBlock, const dftfe::uInt *copyFromVecStartingContiguousBlockIds)
void stridedCopyFromBlock(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 *copyFromVecBlock, ValueType2 *copyToVec, const dftfe::uInt *copyFromVecStartingContiguousBlockIds)
void stridedCopyToBlockConstantStride(const dftfe::uInt blockSizeTo, const dftfe::uInt blockSizeFrom, const dftfe::uInt numBlocks, const dftfe::uInt startingId, const ValueType1 *copyFromVec, ValueType2 *copyToVec) const
void stridedBlockScaleAndAddTwoVecColumnWise(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *x, const ValueType *alpha, const ValueType *y, const ValueType *beta, ValueType *z)
void stridedBlockAxpBy(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 *addFromVec, const ValueType2 *scalingVector, const ValueType2 a, const ValueType2 b, ValueType1 *addToVec) const
void xnrm2(const dftfe::uInt n, const std::complex< double > *x, const dftfe::uInt incx, const MPI_Comm &mpi_communicator, double *result) const
void stridedBlockScaleColumnWise(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *beta, ValueType *x)
void xgemmStridedBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const double *alpha, const double *A, const dftfe::uInt lda, long long int strideA, const double *B, const dftfe::uInt ldb, long long int strideB, const double *beta, double *C, const dftfe::uInt ldc, long long int strideC, const dftfe::Int batchCount) const
void xgemmBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< double > *alpha, const std::complex< double > *A[], const dftfe::uInt lda, const std::complex< double > *B[], const dftfe::uInt ldb, const std::complex< double > *beta, std::complex< double > *C[], const dftfe::uInt ldc, const dftfe::Int batchCount) const
void stridedCopyFromBlockConstantStride(const dftfe::uInt blockSizeTo, const dftfe::uInt blockSizeFrom, const dftfe::uInt numBlocks, const dftfe::uInt startingId, const ValueType1 *copyFromVec, ValueType2 *copyToVec)
void stridedBlockScaleAndAddColumnWise(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *x, const ValueType *beta, ValueType *y)
void ApaBD(const dftfe::uInt m, const dftfe::uInt n, const ValueType0 alpha, const ValueType1 *A, const ValueType2 *B, const ValueType3 *D, ValueType4 *C) const
void xaxpy(const dftfe::uInt n, const std::complex< double > *alpha, const std::complex< double > *x, const dftfe::uInt incx, std::complex< double > *y, const dftfe::uInt incy) const
void xgemv(const char transA, const dftfe::uInt m, const dftfe::uInt n, const double *alpha, const double *A, const dftfe::uInt lda, const double *x, const dftfe::uInt incx, const double *beta, double *y, const dftfe::uInt incy) const
void stridedCopyConstantStride(const dftfe::uInt blockSize, const dftfe::uInt strideTo, const dftfe::uInt strideFrom, const dftfe::uInt numBlocks, const dftfe::uInt startingToId, const dftfe::uInt startingFromId, const ValueType1 *copyFromVec, ValueType2 *copyToVec)
void xgemmBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const float *alpha, const float *A[], const dftfe::uInt lda, const float *B[], const dftfe::uInt ldb, const float *beta, float *C[], const dftfe::uInt ldc, const dftfe::Int batchCount) const
void addVecOverContinuousIndex(const dftfe::uInt numContiguousBlocks, const dftfe::uInt contiguousBlockSize, const ValueType *input1, const ValueType *input2, ValueType *output)
void xaxpy(const dftfe::uInt n, const double *alpha, const double *x, const dftfe::uInt incx, double *y, const dftfe::uInt incy) const
void axpyStridedBlockAtomicAdd(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *addFromVec, ValueType *addToVec, const dftfe::uInt *addToVecStartingContiguousBlockIds) const
void rightDiagonalScale(const dftfe::uInt numberofVectors, const dftfe::uInt sizeOfVector, ValueType1 *X, ValueType2 *D)
void xgemv(const char transA, const dftfe::uInt m, const dftfe::uInt n, const std::complex< double > *alpha, const std::complex< double > *A, const dftfe::uInt lda, const std::complex< double > *x, const dftfe::uInt incx, const std::complex< double > *beta, std::complex< double > *y, const dftfe::uInt incy) const
void xcopy(const dftfe::uInt n, const std::complex< float > *x, const dftfe::uInt incx, std::complex< float > *y, const dftfe::uInt incy) const
void copyRealArrsToComplexArr(const dftfe::uInt size, const ValueTypeReal *realArr, const ValueTypeReal *imagArr, ValueTypeComplex *complexArr)
void xgemm(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< float > *alpha, const std::complex< float > *A, const dftfe::uInt lda, const std::complex< float > *B, const dftfe::uInt ldb, const std::complex< float > *beta, std::complex< float > *C, const dftfe::uInt ldc) const
void xdot(const dftfe::uInt N, const double *X, const dftfe::uInt INCX, const double *Y, const dftfe::uInt INCY, double *result) const
void xcopy(const dftfe::uInt n, const std::complex< double > *x, const dftfe::uInt incx, std::complex< double > *y, const dftfe::uInt incy) const
Definition BLASWrapper.h:35
Definition BLASWrapper.h:33
cudaStream_t deviceStream_t
Definition DeviceTypeConfig.cu.h:27
cublasStatus_t deviceBlasStatus_t
Definition DeviceTypeConfig.cu.h:38
@ HOST
Definition MemorySpaceType.h:34
@ DEVICE
Definition MemorySpaceType.h:36
cublasHandle_t deviceBlasHandle_t
Definition DeviceTypeConfig.cu.h:36
Definition pseudoPotentialToDftfeConverter.cc:34
std::uint32_t uInt
Definition TypeConfig.h:10
@ LDA
Definition ExcSSDFunctionalBaseClass.h:35
std::int32_t Int
Definition TypeConfig.h:11