DFT-FE 1.3.0-pre
Density Functional Theory With Finite-Elements
Loading...
Searching...
No Matches
BLASWrapper.h
Go to the documentation of this file.
1// ---------------------------------------------------------------------
2//
3// Copyright (c) 2017-2025 The Regents of the University of Michigan and DFT-FE
4// authors.
5//
6// This file is part of the DFT-FE code.
7//
8// The DFT-FE code is free software; you can use it, redistribute
9// it, and/or modify it under the terms of the GNU Lesser General
10// Public License as published by the Free Software Foundation; either
11// version 2.1 of the License, or (at your option) any later version.
12// The full text of the license can be found in the file LICENSE at
13// the top level of the DFT-FE distribution.
14//
15// ---------------------------------------------------------------------
16//
17
18#ifndef BLASWrapper_h
19#define BLASWrapper_h
20
21#include <dftfeDataTypes.h>
22#include <MemorySpaceType.h>
23#include <complex>
24#include <TypeConfig.h>
25#include <DeviceTypeConfig.h>
26#include <cmath>
27#if defined(DFTFE_WITH_DEVICE)
28# include "Exceptions.h"
29#endif
30namespace dftfe
31{
32 namespace linearAlgebra
33 {
34 template <dftfe::utils::MemorySpace memorySpace>
36
37 template <>
39 {
40 public:
42
43 template <typename ValueType>
44 void
46 const ValueType *X,
47 const ValueType *Y,
48 ValueType *output);
49
50 template <typename ValueType>
51 void
53 const ValueType *X,
54 const ValueType *Y,
55 ValueType *output);
56
57 // Real-Single Precision GEMM
58 void
59 xgemm(const char transA,
60 const char transB,
61 const dftfe::uInt m,
62 const dftfe::uInt n,
63 const dftfe::uInt k,
64 const float *alpha,
65 const float *A,
66 const dftfe::uInt lda,
67 const float *B,
68 const dftfe::uInt ldb,
69 const float *beta,
70 float *C,
71 const dftfe::uInt ldc);
72 // Complex-Single Precision GEMM
73 void
74 xgemm(const char transA,
75 const char transB,
76 const dftfe::uInt m,
77 const dftfe::uInt n,
78 const dftfe::uInt k,
79 const std::complex<float> *alpha,
80 const std::complex<float> *A,
81 const dftfe::uInt lda,
82 const std::complex<float> *B,
83 const dftfe::uInt ldb,
84 const std::complex<float> *beta,
85 std::complex<float> *C,
86 const dftfe::uInt ldc);
87
88 // Real-double precision GEMM
89 void
90 xgemm(const char transA,
91 const char transB,
92 const dftfe::uInt m,
93 const dftfe::uInt n,
94 const dftfe::uInt k,
95 const double *alpha,
96 const double *A,
97 const dftfe::uInt lda,
98 const double *B,
99 const dftfe::uInt ldb,
100 const double *beta,
101 double *C,
102 const dftfe::uInt ldc);
103
104
105 // Complex-double precision GEMM
106 void
107 xgemm(const char transA,
108 const char transB,
109 const dftfe::uInt m,
110 const dftfe::uInt n,
111 const dftfe::uInt k,
112 const std::complex<double> *alpha,
113 const std::complex<double> *A,
114 const dftfe::uInt lda,
115 const std::complex<double> *B,
116 const dftfe::uInt ldb,
117 const std::complex<double> *beta,
118 std::complex<double> *C,
119 const dftfe::uInt ldc);
120
121 void
122 xgemv(const char transA,
123 const dftfe::uInt m,
124 const dftfe::uInt n,
125 const double *alpha,
126 const double *A,
127 const dftfe::uInt lda,
128 const double *x,
129 const dftfe::uInt incx,
130 const double *beta,
131 double *y,
132 const dftfe::uInt incy);
133
134 void
135 xgemv(const char transA,
136 const dftfe::uInt m,
137 const dftfe::uInt n,
138 const float *alpha,
139 const float *A,
140 const dftfe::uInt lda,
141 const float *x,
142 const dftfe::uInt incx,
143 const float *beta,
144 float *y,
145 const dftfe::uInt incy);
146
147 void
148 xgemv(const char transA,
149 const dftfe::uInt m,
150 const dftfe::uInt n,
151 const std::complex<double> *alpha,
152 const std::complex<double> *A,
153 const dftfe::uInt lda,
154 const std::complex<double> *x,
155 const dftfe::uInt incx,
156 const std::complex<double> *beta,
157 std::complex<double> *y,
158 const dftfe::uInt incy);
159
160 void
161 xgemv(const char transA,
162 const dftfe::uInt m,
163 const dftfe::uInt n,
164 const std::complex<float> *alpha,
165 const std::complex<float> *A,
166 const dftfe::uInt lda,
167 const std::complex<float> *x,
168 const dftfe::uInt incx,
169 const std::complex<float> *beta,
170 std::complex<float> *y,
171 const dftfe::uInt incy);
172
173
174 template <typename ValueType1, typename ValueType2>
175 void
176 xscal(ValueType1 *x, const ValueType2 alpha, const dftfe::uInt n);
177
178 // Brief
179 // for ( i = 0; i < numContiguousBlocks; i ++)
180 // {
181 // for( j = 0 ; j < contiguousBlockSize; j++)
182 // {
183 // output[j] += input1[i*contiguousBlockSize+j] *
184 // input2[i*contiguousBlockSize+j];
185 // }
186 // }
187 template <typename ValueType>
188 void
189 addVecOverContinuousIndex(const dftfe::uInt numContiguousBlocks,
190 const dftfe::uInt contiguousBlockSize,
191 const ValueType *input1,
192 const ValueType *input2,
193 ValueType *output);
194
195 // Real-Float scaling of Real-vector
196
197
198 // Real double Norm2
199 void
201 const double *x,
202 const dftfe::uInt incx,
203 const MPI_Comm &mpi_communicator,
204 double *result);
205
206
207 // Complex double Norm2
208 void
210 const std::complex<double> *x,
211 const dftfe::uInt incx,
212 const MPI_Comm &mpi_communicator,
213 double *result);
214 // Real dot product
215 void
217 const double *X,
218 const dftfe::uInt INCX,
219 const double *Y,
220 const dftfe::uInt INCY,
221 double *result);
222 // Real dot product
223 void
225 const float *X,
226 const dftfe::uInt INCX,
227 const float *Y,
228 const dftfe::uInt INCY,
229 float *result);
230 // Real dot product with all-reduce call
231 void
233 const double *X,
234 const dftfe::uInt INCX,
235 const double *Y,
236 const dftfe::uInt INCY,
237 const MPI_Comm &mpi_communicator,
238 double *result);
239
240 // Complex dot product
241 void
243 const std::complex<double> *X,
244 const dftfe::uInt INCX,
245 const std::complex<double> *Y,
246 const dftfe::uInt INCY,
247 std::complex<double> *result);
248 // Complex dot product
249 void
251 const std::complex<float> *X,
252 const dftfe::uInt INCX,
253 const std::complex<float> *Y,
254 const dftfe::uInt INCY,
255 std::complex<float> *result);
256 // Complex dot product with all-reduce call
257 void
259 const std::complex<double> *X,
260 const dftfe::uInt INCX,
261 const std::complex<double> *Y,
262 const dftfe::uInt INCY,
263 const MPI_Comm &mpi_communicator,
264 std::complex<double> *result);
265
266
267 // MultiVector Real dot product
268 template <typename ValueType>
269 void
270 MultiVectorXDot(const dftfe::uInt contiguousBlockSize,
271 const dftfe::uInt numContiguousBlocks,
272 const ValueType *X,
273 const ValueType *Y,
274 const ValueType *onesVec,
275 ValueType *tempVector,
276 ValueType *tempResults,
277 ValueType *result);
278
279 // MultiVector Real dot product with all Reduce call
280 template <typename ValueType>
281 void
282 MultiVectorXDot(const dftfe::uInt contiguousBlockSize,
283 const dftfe::uInt numContiguousBlocks,
284 const ValueType *X,
285 const ValueType *Y,
286 const ValueType *onesVec,
287 ValueType *tempVector,
288 ValueType *tempResults,
289 const MPI_Comm &mpi_communicator,
290 ValueType *result);
291
292
293 // Real double Ax+y
294 void
296 const double *alpha,
297 const double *x,
298 const dftfe::uInt incx,
299 double *y,
300 const dftfe::uInt incy);
301
302 // Complex double Ax+y
303 void
305 const std::complex<double> *alpha,
306 const std::complex<double> *x,
307 const dftfe::uInt incx,
308 std::complex<double> *y,
309 const dftfe::uInt incy);
310
311 // Real float Ax+y
312 void
314 const float *alpha,
315 const float *x,
316 const dftfe::uInt incx,
317 float *y,
318 const dftfe::uInt incy);
319
320 // Complex float Ax+y
321 void
323 const std::complex<float> *alpha,
324 const std::complex<float> *x,
325 const dftfe::uInt incx,
326 std::complex<float> *y,
327 const dftfe::uInt incy);
328
329 // Real copy of double data
330 void
332 const double *x,
333 const dftfe::uInt incx,
334 double *y,
335 const dftfe::uInt incy);
336
337 // Complex double copy of data
338 void
340 const std::complex<double> *x,
341 const dftfe::uInt incx,
342 std::complex<double> *y,
343 const dftfe::uInt incy);
344
345 // Real copy of float data
346 void
348 const float *x,
349 const dftfe::uInt incx,
350 float *y,
351 const dftfe::uInt incy);
352
353 // Complex float copy of data
354 void
356 const std::complex<float> *x,
357 const dftfe::uInt incx,
358 std::complex<float> *y,
359 const dftfe::uInt incy);
360
361 // Real double symmetric matrix-vector product
362 void
363 xsymv(const char UPLO,
364 const dftfe::uInt N,
365 const double *alpha,
366 const double *A,
367 const dftfe::uInt LDA,
368 const double *X,
369 const dftfe::uInt INCX,
370 const double *beta,
371 double *C,
372 const dftfe::uInt INCY);
373
374 void
375 xgemmBatched(const char transA,
376 const char transB,
377 const dftfe::uInt m,
378 const dftfe::uInt n,
379 const dftfe::uInt k,
380 const double *alpha,
381 const double *A[],
382 const dftfe::uInt lda,
383 const double *B[],
384 const dftfe::uInt ldb,
385 const double *beta,
386 double *C[],
387 const dftfe::uInt ldc,
388 const dftfe::Int batchCount);
389
390 void
391 xgemmBatched(const char transA,
392 const char transB,
393 const dftfe::uInt m,
394 const dftfe::uInt n,
395 const dftfe::uInt k,
396 const std::complex<double> *alpha,
397 const std::complex<double> *A[],
398 const dftfe::uInt lda,
399 const std::complex<double> *B[],
400 const dftfe::uInt ldb,
401 const std::complex<double> *beta,
402 std::complex<double> *C[],
403 const dftfe::uInt ldc,
404 const dftfe::Int batchCount);
405
406
407 void
408 xgemmBatched(const char transA,
409 const char transB,
410 const dftfe::uInt m,
411 const dftfe::uInt n,
412 const dftfe::uInt k,
413 const float *alpha,
414 const float *A[],
415 const dftfe::uInt lda,
416 const float *B[],
417 const dftfe::uInt ldb,
418 const float *beta,
419 float *C[],
420 const dftfe::uInt ldc,
421 const dftfe::Int batchCount);
422
423 void
424 xgemmBatched(const char transA,
425 const char transB,
426 const dftfe::uInt m,
427 const dftfe::uInt n,
428 const dftfe::uInt k,
429 const std::complex<float> *alpha,
430 const std::complex<float> *A[],
431 const dftfe::uInt lda,
432 const std::complex<float> *B[],
433 const dftfe::uInt ldb,
434 const std::complex<float> *beta,
435 std::complex<float> *C[],
436 const dftfe::uInt ldc,
437 const dftfe::Int batchCount);
438
439
440 void
441 xgemmStridedBatched(const char transA,
442 const char transB,
443 const dftfe::uInt m,
444 const dftfe::uInt n,
445 const dftfe::uInt k,
446 const double *alpha,
447 const double *A,
448 const dftfe::uInt lda,
449 long long int strideA,
450 const double *B,
451 const dftfe::uInt ldb,
452 long long int strideB,
453 const double *beta,
454 double *C,
455 const dftfe::uInt ldc,
456 long long int strideC,
457 const dftfe::Int batchCount);
458
459 void
460 xgemmStridedBatched(const char transA,
461 const char transB,
462 const dftfe::uInt m,
463 const dftfe::uInt n,
464 const dftfe::uInt k,
465 const std::complex<double> *alpha,
466 const std::complex<double> *A,
467 const dftfe::uInt lda,
468 long long int strideA,
469 const std::complex<double> *B,
470 const dftfe::uInt ldb,
471 long long int strideB,
472 const std::complex<double> *beta,
473 std::complex<double> *C,
474 const dftfe::uInt ldc,
475 long long int strideC,
476 const dftfe::Int batchCount);
477
478 void
479 xgemmStridedBatched(const char transA,
480 const char transB,
481 const dftfe::uInt m,
482 const dftfe::uInt n,
483 const dftfe::uInt k,
484 const std::complex<float> *alpha,
485 const std::complex<float> *A,
486 const dftfe::uInt lda,
487 long long int strideA,
488 const std::complex<float> *B,
489 const dftfe::uInt ldb,
490 long long int strideB,
491 const std::complex<float> *beta,
492 std::complex<float> *C,
493 const dftfe::uInt ldc,
494 long long int strideC,
495 const dftfe::Int batchCount);
496
497 void
498 xgemmStridedBatched(const char transA,
499 const char transB,
500 const dftfe::uInt m,
501 const dftfe::uInt n,
502 const dftfe::uInt k,
503 const float *alpha,
504 const float *A,
505 const dftfe::uInt lda,
506 long long int strideA,
507 const float *B,
508 const dftfe::uInt ldb,
509 long long int strideB,
510 const float *beta,
511 float *C,
512 const dftfe::uInt ldc,
513 long long int strideC,
514 const dftfe::Int batchCount);
515
516 template <typename ValueTypeComplex, typename ValueTypeReal>
517 void
519 const ValueTypeComplex *complexArr,
520 ValueTypeReal *realArr,
521 ValueTypeReal *imagArr);
522
523
524 template <typename ValueTypeComplex, typename ValueTypeReal>
525 void
527 const ValueTypeReal *realArr,
528 const ValueTypeReal *imagArr,
529 ValueTypeComplex *complexArr);
530
531 template <typename ValueType1, typename ValueType2>
532 void
534 const ValueType1 *valueType1Arr,
535 ValueType2 *valueType2Arr);
536
537
538 template <typename ValueType1, typename ValueType2>
539 void
541 const dftfe::uInt contiguousBlockSize,
542 const dftfe::uInt numContiguousBlocks,
543 const ValueType1 *copyFromVec,
544 ValueType2 *copyToVecBlock,
545 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
546
547
548 template <typename ValueType1, typename ValueType2>
549 void
551 const dftfe::uInt contiguousBlockSize,
552 const dftfe::uInt numContiguousBlocks,
553 const dftfe::uInt startingVecId,
554 const ValueType1 *copyFromVec,
555 ValueType2 *copyToVecBlock,
556 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
557
558 template <typename ValueType1, typename ValueType2>
559 void
561 const dftfe::uInt contiguousBlockSize,
562 const dftfe::uInt numContiguousBlocks,
563 const ValueType1 *copyFromVecBlock,
564 ValueType2 *copyToVec,
565 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
566
567 template <typename ValueType1, typename ValueType2>
568 void
570 const dftfe::uInt blockSizeFrom,
571 const dftfe::uInt numBlocks,
572 const dftfe::uInt startingId,
573 const ValueType1 *copyFromVec,
574 ValueType2 *copyToVec);
575
576
577 template <typename ValueType1, typename ValueType2>
578 void
580 const dftfe::uInt strideTo,
581 const dftfe::uInt strideFrom,
582 const dftfe::uInt numBlocks,
583 const dftfe::uInt startingToId,
584 const dftfe::uInt startingFromId,
585 const ValueType1 *copyFromVec,
586 ValueType2 *copyToVec);
587
588
589 template <typename ValueType1, typename ValueType2>
590 void
592 const dftfe::uInt blockSizeFrom,
593 const dftfe::uInt numBlocks,
594 const dftfe::uInt startingId,
595 const ValueType1 *copyFromVec,
596 ValueType2 *copyToVec);
597
598 template <typename ValueType1, typename ValueType2>
599 void
600 stridedBlockAxpy(const dftfe::uInt contiguousBlockSize,
601 const dftfe::uInt numContiguousBlocks,
602 const ValueType1 *addFromVec,
603 const ValueType2 *scalingVector,
604 const ValueType2 a,
605 ValueType1 *addToVec);
606
607
608 template <typename ValueType1, typename ValueType2>
609 void
610 stridedBlockAxpBy(const dftfe::uInt contiguousBlockSize,
611 const dftfe::uInt numContiguousBlocks,
612 const ValueType1 *addFromVec,
613 const ValueType2 *scalingVector,
614 const ValueType2 a,
615 const ValueType2 b,
616 ValueType1 *addToVec);
617 template <typename ValueType1, typename ValueType2>
618 void
620 const ValueType2 alpha,
621 const ValueType1 *x,
622 const ValueType2 beta,
623 ValueType1 *y);
624 template <typename ValueType0,
625 typename ValueType1,
626 typename ValueType2,
627 typename ValueType3,
628 typename ValueType4>
629 void
631 const dftfe::uInt n,
632 const ValueType0 alpha,
633 const ValueType1 *A,
634 const ValueType2 *B,
635 const ValueType3 *D,
636 ValueType4 *C);
637
638 template <typename ValueType>
639 void
641 const dftfe::uInt contiguousBlockSize,
642 const dftfe::uInt numContiguousBlocks,
643 const ValueType *addFromVec,
644 ValueType *addToVec,
645 const dftfe::uInt *addToVecStartingContiguousBlockIds);
646
647 template <typename ValueType1, typename ValueType2, typename ValueType3>
648 void
650 const dftfe::uInt contiguousBlockSize,
651 const dftfe::uInt numContiguousBlocks,
652 const ValueType1 a,
653 const ValueType1 *s,
654 const ValueType2 *addFromVec,
655 ValueType3 *addToVec,
656 const dftfe::uInt *addToVecStartingContiguousBlockIds);
657 template <typename ValueType1, typename ValueType2, typename ValueType3>
658 void
660 const dftfe::uInt contiguousBlockSize,
661 const dftfe::uInt numContiguousBlocks,
662 const ValueType1 a,
663 const ValueType2 *addFromVec,
664 ValueType3 *addToVec,
665 const dftfe::uInt *addToVecStartingContiguousBlockIds);
666
667 template <typename ValueType1, typename ValueType2>
668 void
669 stridedBlockScale(const dftfe::uInt contiguousBlockSize,
670 const dftfe::uInt numContiguousBlocks,
671 const ValueType1 a,
672 const ValueType1 *s,
673 ValueType2 *x);
674
675 template <typename ValueType1, typename ValueType2>
676 void
678 const dftfe::uInt contiguousBlockSize,
679 const dftfe::uInt numContiguousBlocks,
680 const ValueType1 a,
681 const ValueType1 *s,
682 const ValueType2 *copyFromVec,
683 ValueType2 *copyToVecBlock,
684 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
685
686 template <typename ValueType>
687 void
688 stridedBlockScaleColumnWise(const dftfe::uInt contiguousBlockSize,
689 const dftfe::uInt numContiguousBlocks,
690 const ValueType *beta,
691 ValueType *x);
692
693 template <typename ValueType>
694 void
696 const dftfe::uInt numContiguousBlocks,
697 const ValueType *x,
698 const ValueType *beta,
699 ValueType *y);
700
701 template <typename ValueType>
702 void
704 const dftfe::uInt contiguousBlockSize,
705 const dftfe::uInt numContiguousBlocks,
706 const ValueType *x,
707 const ValueType *alpha,
708 const ValueType *y,
709 const ValueType *beta,
710 ValueType *z);
711
712 template <typename ValueType1, typename ValueType2>
713 void
714 rightDiagonalScale(const dftfe::uInt numberofVectors,
715 const dftfe::uInt sizeOfVector,
716 ValueType1 *X,
717 ValueType2 *D);
718
719 private:
720 };
721#if defined(DFTFE_WITH_DEVICE)
722# include "Exceptions.h"
// Scoped enumeration of data-type tags: fp32, tf32, bf16 and fp16.
// NOTE(review): presumably selects the precision used for device tensor
// operations -- confirm against the code that consumes this enum.
723 enum class tensorOpDataType
724 {
725 fp32,
726 tf32,
727 bf16,
728 fp16
729 };
730
731 template <>
732 class BLASWrapper<dftfe::utils::MemorySpace::DEVICE>
733 {
734 public:
735 BLASWrapper();
736
737 template <typename ValueType1, typename ValueType2>
738 static void
739 copyValueType1ArrToValueType2ArrDeviceCall(
740 const dftfe::uInt size,
741 const ValueType1 *valueType1Arr,
742 ValueType2 *valueType2Arr,
744
745 template <typename ValueType>
746 void
747 hadamardProduct(const dftfe::uInt m,
748 const ValueType *X,
749 const ValueType *Y,
750 ValueType *output);
751
752 template <typename ValueType>
753 void
754 hadamardProductWithConj(const dftfe::uInt m,
755 const ValueType *X,
756 const ValueType *Y,
757 ValueType *output);
758
759 // Real-Single Precision GEMM
760 void
761 xgemm(const char transA,
762 const char transB,
763 const dftfe::uInt m,
764 const dftfe::uInt n,
765 const dftfe::uInt k,
766 const float *alpha,
767 const float *A,
768 const dftfe::uInt lda,
769 const float *B,
770 const dftfe::uInt ldb,
771 const float *beta,
772 float *C,
773 const dftfe::uInt ldc);
774 // Complex-Single Precision GEMM
775 void
776 xgemm(const char transA,
777 const char transB,
778 const dftfe::uInt m,
779 const dftfe::uInt n,
780 const dftfe::uInt k,
781 const std::complex<float> *alpha,
782 const std::complex<float> *A,
783 const dftfe::uInt lda,
784 const std::complex<float> *B,
785 const dftfe::uInt ldb,
786 const std::complex<float> *beta,
787 std::complex<float> *C,
788 const dftfe::uInt ldc);
789
790 // Real-double precision GEMM
791 void
792 xgemm(const char transA,
793 const char transB,
794 const dftfe::uInt m,
795 const dftfe::uInt n,
796 const dftfe::uInt k,
797 const double *alpha,
798 const double *A,
799 const dftfe::uInt lda,
800 const double *B,
801 const dftfe::uInt ldb,
802 const double *beta,
803 double *C,
804 const dftfe::uInt ldc);
805
806
807 // Complex-double precision GEMM
808 void
809 xgemm(const char transA,
810 const char transB,
811 const dftfe::uInt m,
812 const dftfe::uInt n,
813 const dftfe::uInt k,
814 const std::complex<double> *alpha,
815 const std::complex<double> *A,
816 const dftfe::uInt lda,
817 const std::complex<double> *B,
818 const dftfe::uInt ldb,
819 const std::complex<double> *beta,
820 std::complex<double> *C,
821 const dftfe::uInt ldc);
822
823
824 void
825 xgemv(const char transA,
826 const dftfe::uInt m,
827 const dftfe::uInt n,
828 const double *alpha,
829 const double *A,
830 const dftfe::uInt lda,
831 const double *x,
832 const dftfe::uInt incx,
833 const double *beta,
834 double *y,
835 const dftfe::uInt incy);
836
837 void
838 xgemv(const char transA,
839 const dftfe::uInt m,
840 const dftfe::uInt n,
841 const float *alpha,
842 const float *A,
843 const dftfe::uInt lda,
844 const float *x,
845 const dftfe::uInt incx,
846 const float *beta,
847 float *y,
848 const dftfe::uInt incy);
849
850 void
851 xgemv(const char transA,
852 const dftfe::uInt m,
853 const dftfe::uInt n,
854 const std::complex<double> *alpha,
855 const std::complex<double> *A,
856 const dftfe::uInt lda,
857 const std::complex<double> *x,
858 const dftfe::uInt incx,
859 const std::complex<double> *beta,
860 std::complex<double> *y,
861 const dftfe::uInt incy);
862
863 void
864 xgemv(const char transA,
865 const dftfe::uInt m,
866 const dftfe::uInt n,
867 const std::complex<float> *alpha,
868 const std::complex<float> *A,
869 const dftfe::uInt lda,
870 const std::complex<float> *x,
871 const dftfe::uInt incx,
872 const std::complex<float> *beta,
873 std::complex<float> *y,
874 const dftfe::uInt incy);
875
876 template <typename ValueType>
877 void
878 addVecOverContinuousIndex(const dftfe::uInt numContiguousBlocks,
879 const dftfe::uInt contiguousBlockSize,
880 const ValueType *input1,
881 const ValueType *input2,
882 ValueType *output);
883
884
885
886 template <typename ValueType1, typename ValueType2>
887 void
888 xscal(ValueType1 *x, const ValueType2 alpha, const dftfe::uInt n);
889
890
891
892 // Real double Norm2
893 void
894 xnrm2(const dftfe::uInt n,
895 const double *x,
896 const dftfe::uInt incx,
897 const MPI_Comm &mpi_communicator,
898 double *result);
899
900
901 // Complex double Norm2
902 void
903 xnrm2(const dftfe::uInt n,
904 const std::complex<double> *x,
905 const dftfe::uInt incx,
906 const MPI_Comm &mpi_communicator,
907 double *result);
908
909 // Real dot product
910 void
911 xdot(const dftfe::uInt N,
912 const double *X,
913 const dftfe::uInt INCX,
914 const double *Y,
915 const dftfe::uInt INCY,
916 double *result);
917 // Real dot product
918 void
919 xdot(const dftfe::uInt N,
920 const float *X,
921 const dftfe::uInt INCX,
922 const float *Y,
923 const dftfe::uInt INCY,
924 float *result);
925 //
926 // Real dot product
927 void
928 xdot(const dftfe::uInt N,
929 const double *X,
930 const dftfe::uInt INCX,
931 const double *Y,
932 const dftfe::uInt INCY,
933 const MPI_Comm &mpi_communicator,
934 double *result);
935
936 // Complex dot product
937 void
938 xdot(const dftfe::uInt N,
939 const std::complex<double> *X,
940 const dftfe::uInt INCX,
941 const std::complex<double> *Y,
942 const dftfe::uInt INCY,
943 std::complex<double> *result);
944 // Complex dot product
945 void
946 xdot(const dftfe::uInt N,
947 const std::complex<float> *X,
948 const dftfe::uInt INCX,
949 const std::complex<float> *Y,
950 const dftfe::uInt INCY,
951 std::complex<float> *result);
952 // Complex dot product
953 void
954 xdot(const dftfe::uInt N,
955 const std::complex<double> *X,
956 const dftfe::uInt INCX,
957 const std::complex<double> *Y,
958 const dftfe::uInt INCY,
959 const MPI_Comm &mpi_communicator,
960 std::complex<double> *result);
961
962
963 template <typename ValueType>
964 void
965 MultiVectorXDot(const dftfe::uInt contiguousBlockSize,
966 const dftfe::uInt numContiguousBlocks,
967 const ValueType *X,
968 const ValueType *Y,
969 const ValueType *onesVec,
970 ValueType *tempVector,
971 ValueType *tempResults,
972 ValueType *result);
973
974 template <typename ValueType>
975 void
976 MultiVectorXDot(const dftfe::uInt contiguousBlockSize,
977 const dftfe::uInt numContiguousBlocks,
978 const ValueType *X,
979 const ValueType *Y,
980 const ValueType *onesVec,
981 ValueType *tempVector,
982 ValueType *tempResults,
983 const MPI_Comm &mpi_communicator,
984 ValueType *result);
985
986 // Real double Ax+y
987 void
988 xaxpy(const dftfe::uInt n,
989 const double *alpha,
990 const double *x,
991 const dftfe::uInt incx,
992 double *y,
993 const dftfe::uInt incy);
994
995 // Complex double Ax+y
996 void
997 xaxpy(const dftfe::uInt n,
998 const std::complex<double> *alpha,
999 const std::complex<double> *x,
1000 const dftfe::uInt incx,
1001 std::complex<double> *y,
1002 const dftfe::uInt incy);
1003
1004 // Real copy of double data
1005 void
1006 xcopy(const dftfe::uInt n,
1007 const double *x,
1008 const dftfe::uInt incx,
1009 double *y,
1010 const dftfe::uInt incy);
1011
1012 // Complex double copy of data
1013 void
1014 xcopy(const dftfe::uInt n,
1015 const std::complex<double> *x,
1016 const dftfe::uInt incx,
1017 std::complex<double> *y,
1018 const dftfe::uInt incy);
1019
1020 // Real copy of float data
1021 void
1022 xcopy(const dftfe::uInt n,
1023 const float *x,
1024 const dftfe::uInt incx,
1025 float *y,
1026 const dftfe::uInt incy);
1027
1028 // Complex float copy of data
1029 void
1030 xcopy(const dftfe::uInt n,
1031 const std::complex<float> *x,
1032 const dftfe::uInt incx,
1033 std::complex<float> *y,
1034 const dftfe::uInt incy);
1035
1036 // Real double symmetric matrix-vector product
1037 void
1038 xsymv(const char UPLO,
1039 const dftfe::uInt N,
1040 const double *alpha,
1041 const double *A,
1042 const dftfe::uInt LDA,
1043 const double *X,
1044 const dftfe::uInt INCX,
1045 const double *beta,
1046 double *C,
1047 const dftfe::uInt INCY);
1048
1049 void
1050 xgemmBatched(const char transA,
1051 const char transB,
1052 const dftfe::uInt m,
1053 const dftfe::uInt n,
1054 const dftfe::uInt k,
1055 const double *alpha,
1056 const double *A[],
1057 const dftfe::uInt lda,
1058 const double *B[],
1059 const dftfe::uInt ldb,
1060 const double *beta,
1061 double *C[],
1062 const dftfe::uInt ldc,
1063 const dftfe::Int batchCount);
1064
1065 void
1066 xgemmBatched(const char transA,
1067 const char transB,
1068 const dftfe::uInt m,
1069 const dftfe::uInt n,
1070 const dftfe::uInt k,
1071 const std::complex<double> *alpha,
1072 const std::complex<double> *A[],
1073 const dftfe::uInt lda,
1074 const std::complex<double> *B[],
1075 const dftfe::uInt ldb,
1076 const std::complex<double> *beta,
1077 std::complex<double> *C[],
1078 const dftfe::uInt ldc,
1079 const dftfe::Int batchCount);
1080
1081 void
1082 xgemmBatched(const char transA,
1083 const char transB,
1084 const dftfe::uInt m,
1085 const dftfe::uInt n,
1086 const dftfe::uInt k,
1087 const float *alpha,
1088 const float *A[],
1089 const dftfe::uInt lda,
1090 const float *B[],
1091 const dftfe::uInt ldb,
1092 const float *beta,
1093 float *C[],
1094 const dftfe::uInt ldc,
1095 const dftfe::Int batchCount);
1096
1097 void
1098 xgemmBatched(const char transA,
1099 const char transB,
1100 const dftfe::uInt m,
1101 const dftfe::uInt n,
1102 const dftfe::uInt k,
1103 const std::complex<float> *alpha,
1104 const std::complex<float> *A[],
1105 const dftfe::uInt lda,
1106 const std::complex<float> *B[],
1107 const dftfe::uInt ldb,
1108 const std::complex<float> *beta,
1109 std::complex<float> *C[],
1110 const dftfe::uInt ldc,
1111 const dftfe::Int batchCount);
1112
1113 void
1114 xgemmStridedBatched(const char transA,
1115 const char transB,
1116 const dftfe::uInt m,
1117 const dftfe::uInt n,
1118 const dftfe::uInt k,
1119 const double *alpha,
1120 const double *A,
1121 const dftfe::uInt lda,
1122 long long int strideA,
1123 const double *B,
1124 const dftfe::uInt ldb,
1125 long long int strideB,
1126 const double *beta,
1127 double *C,
1128 const dftfe::uInt ldc,
1129 long long int strideC,
1130 const dftfe::Int batchCount);
1131
1132 void
1133 xgemmStridedBatched(const char transA,
1134 const char transB,
1135 const dftfe::uInt m,
1136 const dftfe::uInt n,
1137 const dftfe::uInt k,
1138 const std::complex<double> *alpha,
1139 const std::complex<double> *A,
1140 const dftfe::uInt lda,
1141 long long int strideA,
1142 const std::complex<double> *B,
1143 const dftfe::uInt ldb,
1144 long long int strideB,
1145 const std::complex<double> *beta,
1146 std::complex<double> *C,
1147 const dftfe::uInt ldc,
1148 long long int strideC,
1149 const dftfe::Int batchCount);
1150
1151 void
1152 xgemmStridedBatched(const char transA,
1153 const char transB,
1154 const dftfe::uInt m,
1155 const dftfe::uInt n,
1156 const dftfe::uInt k,
1157 const std::complex<float> *alpha,
1158 const std::complex<float> *A,
1159 const dftfe::uInt lda,
1160 long long int strideA,
1161 const std::complex<float> *B,
1162 const dftfe::uInt ldb,
1163 long long int strideB,
1164 const std::complex<float> *beta,
1165 std::complex<float> *C,
1166 const dftfe::uInt ldc,
1167 long long int strideC,
1168 const dftfe::Int batchCount);
1169
1170 void
1171 xgemmStridedBatched(const char transA,
1172 const char transB,
1173 const dftfe::uInt m,
1174 const dftfe::uInt n,
1175 const dftfe::uInt k,
1176 const float *alpha,
1177 const float *A,
1178 const dftfe::uInt lda,
1179 long long int strideA,
1180 const float *B,
1181 const dftfe::uInt ldb,
1182 long long int strideB,
1183 const float *beta,
1184 float *C,
1185 const dftfe::uInt ldc,
1186 long long int strideC,
1187 const dftfe::Int batchCount);
1188
1189 template <typename ValueTypeComplex, typename ValueTypeReal>
1190 void
1191 copyComplexArrToRealArrs(const dftfe::uInt size,
1192 const ValueTypeComplex *complexArr,
1193 ValueTypeReal *realArr,
1194 ValueTypeReal *imagArr);
1195
1196
1197 template <typename ValueTypeComplex, typename ValueTypeReal>
1198 void
1199 copyRealArrsToComplexArr(const dftfe::uInt size,
1200 const ValueTypeReal *realArr,
1201 const ValueTypeReal *imagArr,
1202 ValueTypeComplex *complexArr);
1203
1204 template <typename ValueType1, typename ValueType2>
1205 void
1206 copyValueType1ArrToValueType2Arr(const dftfe::uInt size,
1207 const ValueType1 *valueType1Arr,
1208 ValueType2 *valueType2Arr);
1209
1210
1211 template <typename ValueType1, typename ValueType2>
1212 void
1213 stridedCopyToBlock(
1214 const dftfe::uInt contiguousBlockSize,
1215 const dftfe::uInt numContiguousBlocks,
1216 const ValueType1 *copyFromVec,
1217 ValueType2 *copyToVecBlock,
1218 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
1219
1220 template <typename ValueType1, typename ValueType2>
1221 void
1222 stridedCopyToBlock(
1223 const dftfe::uInt contiguousBlockSize,
1224 const dftfe::uInt numContiguousBlocks,
1225 const dftfe::uInt startingVecId,
1226 const ValueType1 *copyFromVec,
1227 ValueType2 *copyToVecBlock,
1228 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
1229
1230
1231 template <typename ValueType1, typename ValueType2>
1232 void
1233 stridedCopyFromBlock(
1234 const dftfe::uInt contiguousBlockSize,
1235 const dftfe::uInt numContiguousBlocks,
1236 const ValueType1 *copyFromVecBlock,
1237 ValueType2 *copyToVec,
1238 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
1239
1240 template <typename ValueType1, typename ValueType2>
1241 void
1242 stridedCopyToBlockConstantStride(const dftfe::uInt blockSizeTo,
1243 const dftfe::uInt blockSizeFrom,
1244 const dftfe::uInt numBlocks,
1245 const dftfe::uInt startingId,
1246 const ValueType1 *copyFromVec,
1247 ValueType2 *copyToVec);
1248
1249
1250 template <typename ValueType1, typename ValueType2>
1251 void
1252 stridedCopyConstantStride(const dftfe::uInt blockSize,
1253 const dftfe::uInt strideTo,
1254 const dftfe::uInt strideFrom,
1255 const dftfe::uInt numBlocks,
1256 const dftfe::uInt startingToId,
1257 const dftfe::uInt startingFromId,
1258 const ValueType1 *copyFromVec,
1259 ValueType2 *copyToVec);
1260
1261
/// Declaration only: by its name, the counterpart of
/// stridedCopyToBlockConstantStride; the parameter list is identical.
/// @param blockSizeTo    block size on the destination side
/// @param blockSizeFrom  block size on the source side
/// @param numBlocks      number of blocks
/// @param startingId     NOTE(review): presumably the offset inside each
///                       destination block at which writing starts - confirm
/// @param copyFromVec    source vector (read-only)
/// @param copyToVec      destination vector (written)
1262 template <typename ValueType1, typename ValueType2>
1263 void
1264 stridedCopyFromBlockConstantStride(const dftfe::uInt blockSizeTo,
1265 const dftfe::uInt blockSizeFrom,
1266 const dftfe::uInt numBlocks,
1267 const dftfe::uInt startingId,
1268 const ValueType1 *copyFromVec,
1269 ValueType2 *copyToVec);
/// Declaration only: vector update on y using x and two scalars.
/// NOTE(review): the name follows the BLAS-style y = alpha*x + beta*y
/// convention - confirm against the definition.
/// Vectors are ValueType1; the scalars are ValueType2 and passed by value.
/// @param n      number of entries
/// @param alpha  scalar associated with x
/// @param x      input vector (read-only)
/// @param beta   scalar associated with y
/// @param y      vector updated in place
1270 template <typename ValueType1, typename ValueType2>
1271 void
1272 axpby(const dftfe::uInt n,
1273 const ValueType2 alpha,
1274 const ValueType1 *x,
1275 const ValueType2 beta,
1276 ValueType1 *y);
1277
/// Declaration only: block-wise update of addToVec from addFromVec using a
/// scalar a and a read-only scalingVector (both ValueType2).
/// NOTE(review): how a and scalingVector combine, and which index of
/// scalingVector applies to which entry, is defined in the implementation.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param addFromVec           input vector (read-only)
/// @param scalingVector        per-entry scaling factors (read-only)
/// @param a                    scalar factor, passed by value
/// @param addToVec             vector updated in place
1278 template <typename ValueType1, typename ValueType2>
1279 void
1280 stridedBlockAxpy(const dftfe::uInt contiguousBlockSize,
1281 const dftfe::uInt numContiguousBlocks,
1282 const ValueType1 *addFromVec,
1283 const ValueType2 *scalingVector,
1284 const ValueType2 a,
1285 ValueType1 *addToVec);
/// Declaration only: same parameter list as stridedBlockAxpy plus a second
/// scalar b.
/// NOTE(review): presumably b scales the existing contents of addToVec
/// before the scaled addFromVec is added - confirm in the definition.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param addFromVec           input vector (read-only)
/// @param scalingVector        per-entry scaling factors (read-only)
/// @param a                    first scalar, passed by value
/// @param b                    second scalar, passed by value
/// @param addToVec             vector updated in place
1286 template <typename ValueType1, typename ValueType2>
1287 void
1288 stridedBlockAxpBy(const dftfe::uInt contiguousBlockSize,
1289 const dftfe::uInt numContiguousBlocks,
1290 const ValueType1 *addFromVec,
1291 const ValueType2 *scalingVector,
1292 const ValueType2 a,
1293 const ValueType2 b,
1294 ValueType1 *addToVec);
1295
/// Declaration only: combines three read-only arrays A, B, D and a scalar
/// alpha into the output array C. Each operand has its own template type,
/// so mixed-precision combinations can be instantiated.
/// NOTE(review): the name suggests C = A + alpha * B * D with D acting as a
/// diagonal scaling - unverified, confirm against the definition.
/// @param m      first dimension (TODO confirm which array extent it is)
/// @param n      second dimension (TODO confirm which array extent it is)
/// @param alpha  scalar, passed by value
/// @param A      input (read-only)
/// @param B      input (read-only)
/// @param D      input (read-only)
/// @param C      output (written)
1296 template <typename ValueType0,
1297 typename ValueType1,
1298 typename ValueType2,
1299 typename ValueType3,
1300 typename ValueType4>
1301 void
1302 ApaBD(const dftfe::uInt m,
1303 const dftfe::uInt n,
1304 const ValueType0 alpha,
1305 const ValueType1 *A,
1306 const ValueType2 *B,
1307 const ValueType3 *D,
1308 ValueType4 *C);
1309
1310
/// Declaration only: adds the block data in addFromVec into addToVec at
/// positions selected by addToVecStartingContiguousBlockIds. Single-type
/// overload with no scaling arguments.
/// NOTE(review): the name indicates that the accumulation uses atomic adds;
/// that is a property of the definition and is not visible here.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param addFromVec           input block data (read-only)
/// @param addToVec             vector accumulated into
/// @param addToVecStartingContiguousBlockIds
///        read-only index array; presumably one starting offset into
///        addToVec per block - TODO confirm
1311 template <typename ValueType>
1312 void
1313 axpyStridedBlockAtomicAdd(
1314 const dftfe::uInt contiguousBlockSize,
1315 const dftfe::uInt numContiguousBlocks,
1316 const ValueType *addFromVec,
1317 ValueType *addToVec,
1318 const dftfe::uInt *addToVecStartingContiguousBlockIds);
1319
/// Overload of axpyStridedBlockAtomicAdd with a scalar a and a scaling
/// array s (both ValueType1), and separate element types for the input
/// (ValueType2) and the accumulated output (ValueType3).
/// NOTE(review): how a and s are applied to addFromVec is defined in the
/// implementation - confirm there.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param a                    scalar factor, passed by value
/// @param s                    scaling array (read-only)
/// @param addFromVec           input block data (read-only)
/// @param addToVec             vector accumulated into
/// @param addToVecStartingContiguousBlockIds  read-only index array
1320 template <typename ValueType1, typename ValueType2, typename ValueType3>
1321 void
1322 axpyStridedBlockAtomicAdd(
1323 const dftfe::uInt contiguousBlockSize,
1324 const dftfe::uInt numContiguousBlocks,
1325 const ValueType1 a,
1326 const ValueType1 *s,
1327 const ValueType2 *addFromVec,
1328 ValueType3 *addToVec,
1329 const dftfe::uInt *addToVecStartingContiguousBlockIds);
/// Overload of axpyStridedBlockAtomicAdd with a scalar a only (no scaling
/// array), and separate element types for the scalar (ValueType1), the
/// input (ValueType2) and the accumulated output (ValueType3).
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param a                    scalar factor, passed by value
/// @param addFromVec           input block data (read-only)
/// @param addToVec             vector accumulated into
/// @param addToVecStartingContiguousBlockIds  read-only index array
1330 template <typename ValueType1, typename ValueType2, typename ValueType3>
1331 void
1332 axpyStridedBlockAtomicAdd(
1333 const dftfe::uInt contiguousBlockSize,
1334 const dftfe::uInt numContiguousBlocks,
1335 const ValueType1 a,
1336 const ValueType2 *addFromVec,
1337 ValueType3 *addToVec,
1338 const dftfe::uInt *addToVecStartingContiguousBlockIds);
1339
/// Declaration only: in-place scaling of x using a scalar a and a read-only
/// scaling array s (both ValueType1); x may have a different element type
/// (ValueType2).
/// NOTE(review): which index of s applies to which entry of x is defined in
/// the implementation - confirm there.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param a                    scalar factor, passed by value
/// @param s                    scaling array (read-only)
/// @param x                    data scaled in place
1340 template <typename ValueType1, typename ValueType2>
1341 void
1342 stridedBlockScale(const dftfe::uInt contiguousBlockSize,
1343 const dftfe::uInt numContiguousBlocks,
1344 const ValueType1 a,
1345 const ValueType1 *s,
1346 ValueType2 *x);
/// Declaration only: by its name, a strided copy from copyFromVec into
/// copyToVecBlock combined with scaling by the scalar a and the array s.
/// Unlike stridedCopyToBlock, source and destination share one element type
/// (ValueType2).
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param a                    scalar factor, passed by value
/// @param s                    scaling array (read-only)
/// @param copyFromVec          source vector (read-only)
/// @param copyToVecBlock       destination block buffer (written)
/// @param copyFromVecStartingContiguousBlockIds
///        read-only index array; presumably one starting offset into
///        copyFromVec per block - TODO confirm
1347 template <typename ValueType1, typename ValueType2>
1348 void
1349 stridedBlockScaleCopy(
1350 const dftfe::uInt contiguousBlockSize,
1351 const dftfe::uInt numContiguousBlocks,
1352 const ValueType1 a,
1353 const ValueType1 *s,
1354 const ValueType2 *copyFromVec,
1355 ValueType2 *copyToVecBlock,
1356 const dftfe::uInt *copyFromVecStartingContiguousBlockIds);
1357
/// Declaration only: in-place scaling of x by the read-only array beta.
/// NOTE(review): "ColumnWise" suggests one beta entry per position inside a
/// contiguous block, reused for every block - confirm in the definition.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param beta                 scaling factors (read-only)
/// @param x                    data scaled in place
1358 template <typename ValueType>
1359 void
1360 stridedBlockScaleColumnWise(const dftfe::uInt contiguousBlockSize,
1361 const dftfe::uInt numContiguousBlocks,
1362 const ValueType *beta,
1363 ValueType *x);
1364
/// Declaration only: updates y from the read-only inputs x and beta.
/// NOTE(review): by its name this presumably adds beta-scaled x to y, with
/// beta indexed by position inside a block - confirm in the definition.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param x                    input data (read-only)
/// @param beta                 scaling factors (read-only)
/// @param y                    data updated in place
1365 template <typename ValueType>
1366 void
1367 stridedBlockScaleAndAddColumnWise(const dftfe::uInt contiguousBlockSize,
1368 const dftfe::uInt numContiguousBlocks,
1369 const ValueType *x,
1370 const ValueType *beta,
1371 ValueType *y);
1372
/// Declaration only: writes z from two read-only inputs x and y, each with
/// its own read-only scaling array (alpha after x, beta after y).
/// NOTE(review): presumably z = alpha*x + beta*y with alpha/beta indexed by
/// position inside a block; whether z is overwritten or accumulated into is
/// decided by the definition - confirm there.
/// @param contiguousBlockSize  size of one contiguous block
/// @param numContiguousBlocks  number of contiguous blocks
/// @param x                    first input (read-only)
/// @param alpha                scaling factors paired with x (read-only)
/// @param y                    second input (read-only)
/// @param beta                 scaling factors paired with y (read-only)
/// @param z                    output
1373 template <typename ValueType>
1374 void
1375 stridedBlockScaleAndAddTwoVecColumnWise(
1376 const dftfe::uInt contiguousBlockSize,
1377 const dftfe::uInt numContiguousBlocks,
1378 const ValueType *x,
1379 const ValueType *alpha,
1380 const ValueType *y,
1381 const ValueType *beta,
1382 ValueType *z);
1383
/// Declaration only: by its name, scales X by a diagonal D applied from the
/// right.
/// NOTE(review): D is taken through a mutable pointer even though the name
/// suggests it is only read; check the definition and make it const if so.
/// @param numberofVectors  number of vectors stored in X
/// @param sizeOfVector     length of each vector
/// @param X                data scaled in place
/// @param D                diagonal entries
1384 template <typename ValueType1, typename ValueType2>
1385 void
1386 rightDiagonalScale(const dftfe::uInt numberofVectors,
1387 const dftfe::uInt sizeOfVector,
1388 ValueType1 *X,
1389 ValueType2 *D);
1390
/// Accessor for the device BLAS handle.
/// NOTE(review): this listing jumps from line 1390 to 1392, so the line
/// carrying the return type was lost in extraction. It presumably returns
/// dftfe::utils::deviceBlasHandle_t (by value or by reference) - restore it
/// from the original header.
1392 getDeviceBlasHandle();
1393
1394
/// Declaration only: splits one ValueType1 source array into two
/// destination arrays, one kept as ValueType1 and one stored as ValueType2.
/// NOTE(review): by its name the block-diagonal part goes to
/// valueType1DstArray and the off-diagonal part to valueType2DstArray; the
/// exact meaning of B, DRem and D is defined in the implementation.
/// @param B                    TODO confirm (presumably a block size)
/// @param DRem                 TODO confirm
/// @param D                    TODO confirm
/// @param valueType1SrcArray   source array (read-only)
/// @param valueType1DstArray   destination kept in ValueType1
/// @param valueType2DstArray   destination stored in ValueType2
1395 template <typename ValueType1, typename ValueType2>
1396 void
1397 copyBlockDiagonalValueType1OffDiagonalValueType2FromValueType1Arr(
1398 const dftfe::uInt B,
1399 const dftfe::uInt DRem,
1400 const dftfe::uInt D,
1401 const ValueType1 *valueType1SrcArray,
1402 ValueType1 *valueType1DstArray,
1403 ValueType2 *valueType2DstArray);
1404
/// Inline setter: stores opType in the private member d_opType.
/// It only records the value; how d_opType is consumed is decided by the
/// member functions defined in the implementation file.
/// @param opType  tensor-operation data type to record
1405 void
1406 setTensorOpDataType(tensorOpDataType opType)
1407 {
1408 d_opType = opType;
1409 }
1410
/// Takes a device stream id.
/// NOTE(review): this listing jumps from line 1410 to 1412, so the line
/// carrying the return type was lost in extraction (presumably
/// dftfe::utils::deviceBlasStatus_t - restore from the original header).
/// Whether the stream is bound to d_deviceBlasHandle, stored in d_streamId,
/// or both is decided by the definition.
/// @param streamId  device stream to use
1412 setStream(dftfe::utils::deviceStream_t streamId);
1413
/// Class-wide state: both members are `inline static`, so a single handle
/// and a single stream id are shared by every instance of this class. They
/// are declared before the `private:` label, i.e. they are not private
/// members.
1414 inline static dftfe::utils::deviceBlasHandle_t d_deviceBlasHandle;
1415 inline static dftfe::utils::deviceStream_t d_streamId;
1416
1417 private:
/// Private helper that is only declared when DFTFE_WITH_DEVICE_AMD is
/// defined. What it sets up is defined in the implementation file.
1418# ifdef DFTFE_WITH_DEVICE_AMD
1419 void
1420 initialize();
1421# endif
1422
1423 /// tensor-operation data type selected through setTensorOpDataType()
1423 /// (the device BLAS handle itself is the static d_deviceBlasHandle)
1424 tensorOpDataType d_opType;
1425
/// Private helper; by its name it creates the device BLAS handle.
/// NOTE(review): this listing jumps from line 1425 to 1427, so the line
/// carrying the return type was lost in extraction (presumably
/// dftfe::utils::deviceBlasStatus_t - restore from the original header).
1427 create();
1428
/// Private helper; by its name it destroys the device BLAS handle.
/// NOTE(review): this listing jumps from line 1428 to 1430, so the line
/// carrying the return type was lost in extraction (presumably
/// dftfe::utils::deviceBlasStatus_t - restore from the original header).
/// The handle is a static shared by all instances, so confirm the
/// definition guards against destroying it more than once.
1430 destroy();
1431 };
1432#endif
1433
1434 } // end of namespace linearAlgebra
1435
1436} // end of namespace dftfe
1437
1438
1439#endif // BLASWrapper_h
void xaxpy(const dftfe::uInt n, const std::complex< float > *alpha, const std::complex< float > *x, const dftfe::uInt incx, std::complex< float > *y, const dftfe::uInt incy)
void hadamardProduct(const dftfe::uInt m, const ValueType *X, const ValueType *Y, ValueType *output)
void stridedCopyToBlock(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const dftfe::uInt startingVecId, const ValueType1 *copyFromVec, ValueType2 *copyToVecBlock, const dftfe::uInt *copyFromVecStartingContiguousBlockIds)
void xgemmStridedBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const float *alpha, const float *A, const dftfe::uInt lda, long long int strideA, const float *B, const dftfe::uInt ldb, long long int strideB, const float *beta, float *C, const dftfe::uInt ldc, long long int strideC, const dftfe::Int batchCount)
void xgemmStridedBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< float > *alpha, const std::complex< float > *A, const dftfe::uInt lda, long long int strideA, const std::complex< float > *B, const dftfe::uInt ldb, long long int strideB, const std::complex< float > *beta, std::complex< float > *C, const dftfe::uInt ldc, long long int strideC, const dftfe::Int batchCount)
void xcopy(const dftfe::uInt n, const float *x, const dftfe::uInt incx, float *y, const dftfe::uInt incy)
void xgemmBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< float > *alpha, const std::complex< float > *A[], const dftfe::uInt lda, const std::complex< float > *B[], const dftfe::uInt ldb, const std::complex< float > *beta, std::complex< float > *C[], const dftfe::uInt ldc, const dftfe::Int batchCount)
void xcopy(const dftfe::uInt n, const std::complex< float > *x, const dftfe::uInt incx, std::complex< float > *y, const dftfe::uInt incy)
void stridedBlockScaleCopy(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 a, const ValueType1 *s, const ValueType2 *copyFromVec, ValueType2 *copyToVecBlock, const dftfe::uInt *copyFromVecStartingContiguousBlockIds)
void xcopy(const dftfe::uInt n, const std::complex< double > *x, const dftfe::uInt incx, std::complex< double > *y, const dftfe::uInt incy)
void xdot(const dftfe::uInt N, const double *X, const dftfe::uInt INCX, const double *Y, const dftfe::uInt INCY, double *result)
void axpby(const dftfe::uInt n, const ValueType2 alpha, const ValueType1 *x, const ValueType2 beta, ValueType1 *y)
void xdot(const dftfe::uInt N, const std::complex< double > *X, const dftfe::uInt INCX, const std::complex< double > *Y, const dftfe::uInt INCY, std::complex< double > *result)
void xdot(const dftfe::uInt N, const float *X, const dftfe::uInt INCX, const float *Y, const dftfe::uInt INCY, float *result)
void copyValueType1ArrToValueType2Arr(const dftfe::uInt size, const ValueType1 *valueType1Arr, ValueType2 *valueType2Arr)
void xgemv(const char transA, const dftfe::uInt m, const dftfe::uInt n, const std::complex< double > *alpha, const std::complex< double > *A, const dftfe::uInt lda, const std::complex< double > *x, const dftfe::uInt incx, const std::complex< double > *beta, std::complex< double > *y, const dftfe::uInt incy)
void xgemv(const char transA, const dftfe::uInt m, const dftfe::uInt n, const float *alpha, const float *A, const dftfe::uInt lda, const float *x, const dftfe::uInt incx, const float *beta, float *y, const dftfe::uInt incy)
void MultiVectorXDot(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *X, const ValueType *Y, const ValueType *onesVec, ValueType *tempVector, ValueType *tempResults, ValueType *result)
void xnrm2(const dftfe::uInt n, const std::complex< double > *x, const dftfe::uInt incx, const MPI_Comm &mpi_communicator, double *result)
void axpyStridedBlockAtomicAdd(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 a, const ValueType2 *addFromVec, ValueType3 *addToVec, const dftfe::uInt *addToVecStartingContiguousBlockIds)
void axpyStridedBlockAtomicAdd(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 a, const ValueType1 *s, const ValueType2 *addFromVec, ValueType3 *addToVec, const dftfe::uInt *addToVecStartingContiguousBlockIds)
void xnrm2(const dftfe::uInt n, const double *x, const dftfe::uInt incx, const MPI_Comm &mpi_communicator, double *result)
void xaxpy(const dftfe::uInt n, const double *alpha, const double *x, const dftfe::uInt incx, double *y, const dftfe::uInt incy)
void xgemmStridedBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const double *alpha, const double *A, const dftfe::uInt lda, long long int strideA, const double *B, const dftfe::uInt ldb, long long int strideB, const double *beta, double *C, const dftfe::uInt ldc, long long int strideC, const dftfe::Int batchCount)
void axpyStridedBlockAtomicAdd(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *addFromVec, ValueType *addToVec, const dftfe::uInt *addToVecStartingContiguousBlockIds)
void copyComplexArrToRealArrs(const dftfe::uInt size, const ValueTypeComplex *complexArr, ValueTypeReal *realArr, ValueTypeReal *imagArr)
void stridedBlockScale(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 a, const ValueType1 *s, ValueType2 *x)
void xgemmStridedBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< double > *alpha, const std::complex< double > *A, const dftfe::uInt lda, long long int strideA, const std::complex< double > *B, const dftfe::uInt ldb, long long int strideB, const std::complex< double > *beta, std::complex< double > *C, const dftfe::uInt ldc, long long int strideC, const dftfe::Int batchCount)
void xgemm(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const float *alpha, const float *A, const dftfe::uInt lda, const float *B, const dftfe::uInt ldb, const float *beta, float *C, const dftfe::uInt ldc)
void xgemv(const char transA, const dftfe::uInt m, const dftfe::uInt n, const std::complex< float > *alpha, const std::complex< float > *A, const dftfe::uInt lda, const std::complex< float > *x, const dftfe::uInt incx, const std::complex< float > *beta, std::complex< float > *y, const dftfe::uInt incy)
void xaxpy(const dftfe::uInt n, const std::complex< double > *alpha, const std::complex< double > *x, const dftfe::uInt incx, std::complex< double > *y, const dftfe::uInt incy)
void stridedBlockAxpBy(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 *addFromVec, const ValueType2 *scalingVector, const ValueType2 a, const ValueType2 b, ValueType1 *addToVec)
void xscal(ValueType1 *x, const ValueType2 alpha, const dftfe::uInt n)
void stridedCopyToBlock(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 *copyFromVec, ValueType2 *copyToVecBlock, const dftfe::uInt *copyFromVecStartingContiguousBlockIds)
void stridedCopyFromBlock(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 *copyFromVecBlock, ValueType2 *copyToVec, const dftfe::uInt *copyFromVecStartingContiguousBlockIds)
void xdot(const dftfe::uInt N, const std::complex< double > *X, const dftfe::uInt INCX, const std::complex< double > *Y, const dftfe::uInt INCY, const MPI_Comm &mpi_communicator, std::complex< double > *result)
void stridedBlockAxpy(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType1 *addFromVec, const ValueType2 *scalingVector, const ValueType2 a, ValueType1 *addToVec)
void stridedBlockScaleAndAddTwoVecColumnWise(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *x, const ValueType *alpha, const ValueType *y, const ValueType *beta, ValueType *z)
void xgemm(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< double > *alpha, const std::complex< double > *A, const dftfe::uInt lda, const std::complex< double > *B, const dftfe::uInt ldb, const std::complex< double > *beta, std::complex< double > *C, const dftfe::uInt ldc)
void hadamardProductWithConj(const dftfe::uInt m, const ValueType *X, const ValueType *Y, ValueType *output)
void xcopy(const dftfe::uInt n, const double *x, const dftfe::uInt incx, double *y, const dftfe::uInt incy)
void MultiVectorXDot(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *X, const ValueType *Y, const ValueType *onesVec, ValueType *tempVector, ValueType *tempResults, const MPI_Comm &mpi_communicator, ValueType *result)
void xaxpy(const dftfe::uInt n, const float *alpha, const float *x, const dftfe::uInt incx, float *y, const dftfe::uInt incy)
void stridedBlockScaleColumnWise(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *beta, ValueType *x)
void ApaBD(const dftfe::uInt m, const dftfe::uInt n, const ValueType0 alpha, const ValueType1 *A, const ValueType2 *B, const ValueType3 *D, ValueType4 *C)
void stridedCopyFromBlockConstantStride(const dftfe::uInt blockSizeTo, const dftfe::uInt blockSizeFrom, const dftfe::uInt numBlocks, const dftfe::uInt startingId, const ValueType1 *copyFromVec, ValueType2 *copyToVec)
void xdot(const dftfe::uInt N, const std::complex< float > *X, const dftfe::uInt INCX, const std::complex< float > *Y, const dftfe::uInt INCY, std::complex< float > *result)
void stridedBlockScaleAndAddColumnWise(const dftfe::uInt contiguousBlockSize, const dftfe::uInt numContiguousBlocks, const ValueType *x, const ValueType *beta, ValueType *y)
void xgemm(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const double *alpha, const double *A, const dftfe::uInt lda, const double *B, const dftfe::uInt ldb, const double *beta, double *C, const dftfe::uInt ldc)
void stridedCopyConstantStride(const dftfe::uInt blockSize, const dftfe::uInt strideTo, const dftfe::uInt strideFrom, const dftfe::uInt numBlocks, const dftfe::uInt startingToId, const dftfe::uInt startingFromId, const ValueType1 *copyFromVec, ValueType2 *copyToVec)
void xgemv(const char transA, const dftfe::uInt m, const dftfe::uInt n, const double *alpha, const double *A, const dftfe::uInt lda, const double *x, const dftfe::uInt incx, const double *beta, double *y, const dftfe::uInt incy)
void addVecOverContinuousIndex(const dftfe::uInt numContiguousBlocks, const dftfe::uInt contiguousBlockSize, const ValueType *input1, const ValueType *input2, ValueType *output)
void xgemmBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const double *alpha, const double *A[], const dftfe::uInt lda, const double *B[], const dftfe::uInt ldb, const double *beta, double *C[], const dftfe::uInt ldc, const dftfe::Int batchCount)
void xsymv(const char UPLO, const dftfe::uInt N, const double *alpha, const double *A, const dftfe::uInt LDA, const double *X, const dftfe::uInt INCX, const double *beta, double *C, const dftfe::uInt INCY)
void stridedCopyToBlockConstantStride(const dftfe::uInt blockSizeTo, const dftfe::uInt blockSizeFrom, const dftfe::uInt numBlocks, const dftfe::uInt startingId, const ValueType1 *copyFromVec, ValueType2 *copyToVec)
void xgemmBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const float *alpha, const float *A[], const dftfe::uInt lda, const float *B[], const dftfe::uInt ldb, const float *beta, float *C[], const dftfe::uInt ldc, const dftfe::Int batchCount)
void xdot(const dftfe::uInt N, const double *X, const dftfe::uInt INCX, const double *Y, const dftfe::uInt INCY, const MPI_Comm &mpi_communicator, double *result)
void rightDiagonalScale(const dftfe::uInt numberofVectors, const dftfe::uInt sizeOfVector, ValueType1 *X, ValueType2 *D)
void xgemm(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< float > *alpha, const std::complex< float > *A, const dftfe::uInt lda, const std::complex< float > *B, const dftfe::uInt ldb, const std::complex< float > *beta, std::complex< float > *C, const dftfe::uInt ldc)
void copyRealArrsToComplexArr(const dftfe::uInt size, const ValueTypeReal *realArr, const ValueTypeReal *imagArr, ValueTypeComplex *complexArr)
void xgemmBatched(const char transA, const char transB, const dftfe::uInt m, const dftfe::uInt n, const dftfe::uInt k, const std::complex< double > *alpha, const std::complex< double > *A[], const dftfe::uInt lda, const std::complex< double > *B[], const dftfe::uInt ldb, const std::complex< double > *beta, std::complex< double > *C[], const dftfe::uInt ldc, const dftfe::Int batchCount)
Definition BLASWrapper.h:35
Definition BLASWrapper.h:33
cudaStream_t deviceStream_t
Definition DeviceTypeConfig.cu.h:27
cublasHandle_t deviceBlasHandle_t
Definition DeviceTypeConfig.cu.h:36
@ HOST
Definition MemorySpaceType.h:34
@ DEVICE
Definition MemorySpaceType.h:36
cublasStatus_t deviceBlasStatus_t
Definition DeviceTypeConfig.cu.h:38
static cudaStream_t defaultStream
Definition DeviceTypeConfig.cu.h:62
Definition pseudoPotentialToDftfeConverter.cc:34
std::uint32_t uInt
Definition TypeConfig.h:10
@ LDA
Definition ExcSSDFunctionalBaseClass.h:35
std::int32_t Int
Definition TypeConfig.h:11