DFT-FE 1.1.0-pre
Density Functional Theory With Finite-Elements
Loading...
Searching...
No Matches
deviceDirectCCLWrapper.h
Go to the documentation of this file.
1// ---------------------------------------------------------------------
2//
3// Copyright (c) 2017-2025 The Regents of the University of Michigan and DFT-FE
4// authors.
5//
6// This file is part of the DFT-FE code.
7//
8// The DFT-FE code is free software; you can use it, redistribute
9// it, and/or modify it under the terms of the GNU Lesser General
10// Public License as published by the Free Software Foundation; either
11// version 2.1 of the License, or (at your option) any later version.
12// The full text of the license can be found in the file LICENSE at
13// the top level of the DFT-FE distribution.
14//
15// ---------------------------------------------------------------------
16// @author Sambit Das, David M. Rogers
17
18#if defined(DFTFE_WITH_DEVICE)
19# ifndef deviceDirectCCLWrapper_h
20# define deviceDirectCCLWrapper_h
21
#    include <complex>
#    include <cstdio>
#    include <cstdlib>
#    include <mpi.h>
#    include <DeviceTypeConfig.h>

#    if defined(DFTFE_WITH_CUDA_NCCL)
#      include <nccl.h>
#      include <DeviceTypeConfig.h>
#    elif defined(DFTFE_WITH_HIP_RCCL)
#      include <rccl.h>
#      include <DeviceTypeConfig.h>
#    endif
33
34namespace dftfe
35{
36 namespace utils
37 {
#    if defined(DFTFE_WITH_CUDA_NCCL) || defined(DFTFE_WITH_HIP_RCCL)
/**
 * @brief Evaluates an NCCL/RCCL call and terminates the process on failure.
 *
 * @p cmd is evaluated exactly once. If its ncclResult_t is anything other
 * than ncclSuccess, the file, line and ncclGetErrorString() text are written
 * to stderr and the process exits with EXIT_FAILURE.
 *
 * The result is stored in a deliberately unusual local name. With a short
 * name such as `r`, a call like NCCLCHECK(f(r)) would expand to
 * `ncclResult_t r = f(r);` and read the macro's own uninitialized temporary
 * instead of the caller's variable.
 */
#      define NCCLCHECK(cmd)                                            \
        do                                                              \
          {                                                             \
            const ncclResult_t dftfeNcclCheckResult_ = (cmd);           \
            if (dftfeNcclCheckResult_ != ncclSuccess)                   \
              {                                                         \
                std::fprintf(stderr,                                    \
                             "Failed, NCCL error %s:%d '%s'\n",         \
                             __FILE__,                                  \
                             __LINE__,                                  \
                             ncclGetErrorString(dftfeNcclCheckResult_)); \
                std::exit(EXIT_FAILURE);                                \
              }                                                         \
          }                                                             \
        while (0)
#    endif
54 /**
55 * @brief Wrapper class for Device Direct collective communications library.
56 * Adapted from
57 * https://code.ornl.gov/99R/olcf-cookbook/-/blob/develop/comms/nccl_allreduce.rst
58 *
59 * @author Sambit Das, David M. Rogers
60 */
61 class DeviceCCLWrapper
62 {
63 public:
64 DeviceCCLWrapper();
65
66 void
67 init(const MPI_Comm &mpiComm, const bool useDCCL);
68
69 ~DeviceCCLWrapper();
70
71 int
72 deviceDirectAllReduceWrapper(const float * send,
73 float * recv,
74 int size,
75 deviceStream_t &stream);
76
77
78 int
79 deviceDirectAllReduceWrapper(const double * send,
80 double * recv,
81 int size,
82 deviceStream_t &stream);
83
84
85 int
86 deviceDirectAllReduceWrapper(const std::complex<double> *send,
87 std::complex<double> * recv,
88 int size,
89 double * tempReal,
90 double * tempImag,
91 deviceStream_t & stream);
92
93 int
94 deviceDirectAllReduceWrapper(const std::complex<float> *send,
95 std::complex<float> * recv,
96 int size,
97 float * tempReal,
98 float * tempImag,
99 deviceStream_t & stream);
100
101
102 int
103 deviceDirectAllReduceMixedPrecGroupWrapper(const double * send1,
104 const float * send2,
105 double * recv1,
106 float * recv2,
107 int size1,
108 int size2,
109 deviceStream_t &stream);
110
111 int
112 deviceDirectAllReduceMixedPrecGroupWrapper(
113 const std::complex<double> *send1,
114 const std::complex<float> * send2,
115 std::complex<double> * recv1,
116 std::complex<float> * recv2,
117 int size1,
118 int size2,
119 double * tempReal1,
120 float * tempReal2,
121 double * tempImag1,
122 float * tempImag2,
123 deviceStream_t & stream);
124
125
126
127 inline void
128 deviceDirectAllReduceWrapper(const std::complex<float> *send,
129 std::complex<float> * recv,
130 int size,
131 deviceStream_t & stream)
132 {}
133
134
135 inline void
136 deviceDirectAllReduceWrapper(const std::complex<double> *send,
137 std::complex<double> * recv,
138 int size,
139 deviceStream_t & stream)
140 {}
141
142 inline void
143 deviceDirectAllReduceMixedPrecGroupWrapper(
144 const std::complex<double> *send1,
145 const std::complex<float> * send2,
146 std::complex<double> * recv1,
147 std::complex<float> * recv2,
148 int size1,
149 int size2,
150 deviceStream_t & stream)
151 {}
152
153
154 inline void
155 deviceDirectAllReduceWrapper(const double * send,
156 double * recv,
157 int size,
158 double * tempReal,
159 double * tempImag,
160 deviceStream_t &stream)
161 {}
162
163 inline void
164 deviceDirectAllReduceWrapper(const float * send,
165 float * recv,
166 int size,
167 float * tempReal,
168 float * tempImag,
169 deviceStream_t &stream)
170 {}
171
172 inline void
173 deviceDirectAllReduceMixedPrecGroupWrapper(const double * send1,
174 const float * send2,
175 double * recv1,
176 float * recv2,
177 int size1,
178 int size2,
179 double * tempReal1,
180 float * tempReal2,
181 double * tempImag1,
182 float * tempImag2,
183 deviceStream_t &stream)
184 {}
185
186# if defined(DFTFE_WITH_CUDA_NCCL) || defined(DFTFE_WITH_HIP_RCCL)
187 inline static ncclUniqueId *ncclIdPtr;
188 inline static ncclComm_t * ncclCommPtr;
189# endif
190 inline static bool ncclCommInit;
191 inline static dftfe::utils::deviceStream_t d_deviceCommStream;
192 inline static bool commStreamCreated;
193
194 private:
195 int myRank;
196 int totalRanks;
197 MPI_Comm d_mpiComm;
198 };
199 } // namespace utils
200} // namespace dftfe
201
202# endif
203#endif
Definition Cell.h:36
cudaStream_t deviceStream_t
Definition DeviceTypeConfig.cu.h:27
Definition pseudoPotentialToDftfeConverter.cc:34