DFT-FE 1.1.0-pre
Density Functional Theory With Finite-Elements
Loading...
Searching...
No Matches
deviceDirectCCLWrapper.h
Go to the documentation of this file.
1// ---------------------------------------------------------------------
2//
3// Copyright (c) 2017-2025 The Regents of the University of Michigan and DFT-FE
4// authors.
5//
6// This file is part of the DFT-FE code.
7//
8// The DFT-FE code is free software; you can use it, redistribute
9// it, and/or modify it under the terms of the GNU Lesser General
10// Public License as published by the Free Software Foundation; either
11// version 2.1 of the License, or (at your option) any later version.
12// The full text of the license can be found in the file LICENSE at
13// the top level of the DFT-FE distribution.
14//
15// ---------------------------------------------------------------------
16// @author Sambit Das, David M. Rogers
17
18#if defined(DFTFE_WITH_DEVICE)
19# ifndef deviceDirectCCLWrapper_h
20# define deviceDirectCCLWrapper_h
21
22# include <complex>
23# include <mpi.h>
24# include <TypeConfig.h>
25# include <DeviceTypeConfig.h>
26
27# if defined(DFTFE_WITH_CUDA_NCCL)
28# include <nccl.h>
29# include <DeviceTypeConfig.h>
30# elif defined(DFTFE_WITH_HIP_RCCL)
31# include <rccl.h>
32# include <DeviceTypeConfig.h>
33# endif
34
35namespace dftfe
36{
37 namespace utils
38 {
// NCCLCHECK(cmd): evaluates the NCCL/RCCL call expression `cmd` once and, if
// the returned ncclResult_t is not ncclSuccess, prints the source file, line
// and ncclGetErrorString() text via printf, then terminates the whole process
// with exit(EXIT_FAILURE). The do { ... } while (0) wrapper lets the macro be
// used as a single statement followed by a semicolon. It is only defined when
// the build enables NCCL (CUDA) or RCCL (HIP). Being a preprocessor macro, it
// is not scoped by the enclosing namespaces.
// NOTE(review): printf/exit come from <cstdio>/<cstdlib>, which this header
// does not include directly; it relies on them being included transitively.
39# if defined(DFTFE_WITH_CUDA_NCCL) || defined(DFTFE_WITH_HIP_RCCL)
40# define NCCLCHECK(cmd) \
41 do \
42 { \
43 ncclResult_t r = cmd; \
44 if (r != ncclSuccess) \
45 { \
46 printf("Failed, NCCL error %s:%d '%s'\n", \
47 __FILE__, \
48 __LINE__, \
49 ncclGetErrorString(r)); \
50 exit(EXIT_FAILURE); \
51 } \
52 } while (0)
53# endif
54 /**
55 * @brief Wrapper class for Device Direct collective communications library.
56 * Adapted from
57 * https://code.ornl.gov/99R/olcf-cookbook/-/blob/develop/comms/nccl_allreduce.rst
58 *
 * Every collective entry point takes a deviceStream_t by reference
 * (presumably the stream the collective is enqueued on -- confirm against
 * the implementation). The out-of-line overloads are only declared here;
 * the inline overloads further down have empty bodies and do nothing.
 *
59 * @author Sambit Das, David M. Rogers
60 */
61 class DeviceCCLWrapper
62 {
63 public:
 // Default constructor; its definition is not part of this header.
64 DeviceCCLWrapper();
65
 // Binds the wrapper to mpiComm. useDCCL presumably selects whether the
 // device-direct library (NCCL/RCCL) communicator is set up -- TODO confirm
 // against the out-of-line definition.
66 void
67 init(const MPI_Comm &mpiComm, const bool useDCCL);
68
 // NOTE(review): a destructor is user-declared but copy/move operations are
 // not, so the compiler still generates copies. If the destructor tears down
 // the static communicator state below, destroying any copy would affect all
 // instances -- confirm copying is never intended.
69 ~DeviceCCLWrapper();
70
 // Presumably an all-reduce (per the name) of `size` float values from
 // `send` into `recv`; the reduction operation is not visible here.
 // NOTE(review): the return-type line of this declaration (listing line 71)
 // is missing from this rendering, as are listing lines 78, 85, 93, 102 and
 // 111 for the out-of-line declarations that follow.
72 deviceDirectAllReduceWrapper(const float *send,
73 float *recv,
74 dftfe::Int size,
75 deviceStream_t &stream);
76
77
 // Double-precision counterpart of the float overload above.
79 deviceDirectAllReduceWrapper(const double *send,
80 double *recv,
81 dftfe::Int size,
82 deviceStream_t &stream);
83
84
 // Complex<double> overload. tempReal/tempImag are caller-provided double
 // buffers, presumably scratch space for reducing the real and imaginary
 // parts separately -- confirm the required length (likely `size`).
86 deviceDirectAllReduceWrapper(const std::complex<double> *send,
87 std::complex<double> *recv,
88 dftfe::Int size,
89 double *tempReal,
90 double *tempImag,
91 deviceStream_t &stream);
92
 // Complex<float> overload; same scratch-buffer convention as above, with
 // float buffers.
94 deviceDirectAllReduceWrapper(const std::complex<float> *send,
95 std::complex<float> *recv,
96 dftfe::Int size,
97 float *tempReal,
98 float *tempImag,
99 deviceStream_t &stream);
100
101
 // Handles a double buffer (send1 -> recv1, size1 entries) and a float
 // buffer (send2 -> recv2, size2 entries) in a single call; presumably the
 // two reductions are issued as one grouped collective -- confirm.
103 deviceDirectAllReduceMixedPrecGroupWrapper(const double *send1,
104 const float *send2,
105 double *recv1,
106 float *recv2,
107 dftfe::Int size1,
108 dftfe::Int size2,
109 deviceStream_t &stream);
110
 // Complex mixed-precision variant. It takes separate real/imaginary
 // scratch buffers for the double (…1) and float (…2) parts.
112 deviceDirectAllReduceMixedPrecGroupWrapper(
113 const std::complex<double> *send1,
114 const std::complex<float> *send2,
115 std::complex<double> *recv1,
116 std::complex<float> *recv2,
117 dftfe::Int size1,
118 dftfe::Int size2,
119 double *tempReal1,
120 float *tempReal2,
121 double *tempImag1,
122 float *tempImag2,
123 deviceStream_t &stream);
124
125
126
 // The inline overloads below all have empty bodies: nothing is
 // communicated and recv is never written.
 // NOTE(review): presumably they exist only so generic callers compile for
 // argument combinations that are not supported (complex without scratch
 // buffers, real with scratch buffers). Any runtime call silently skips the
 // reduction.
127 inline void
128 deviceDirectAllReduceWrapper(const std::complex<float> *send,
129 std::complex<float> *recv,
130 dftfe::Int size,
131 deviceStream_t &stream)
132 {}
133
134
 // No-op: complex<double> without scratch buffers (empty body).
135 inline void
136 deviceDirectAllReduceWrapper(const std::complex<double> *send,
137 std::complex<double> *recv,
138 dftfe::Int size,
139 deviceStream_t &stream)
140 {}
141
 // No-op: complex mixed-precision group without scratch buffers (empty body).
142 inline void
143 deviceDirectAllReduceMixedPrecGroupWrapper(
144 const std::complex<double> *send1,
145 const std::complex<float> *send2,
146 std::complex<double> *recv1,
147 std::complex<float> *recv2,
148 dftfe::Int size1,
149 dftfe::Int size2,
150 deviceStream_t &stream)
151 {}
152
153
 // No-op: real double data with scratch buffers (empty body).
154 inline void
155 deviceDirectAllReduceWrapper(const double *send,
156 double *recv,
157 dftfe::Int size,
158 double *tempReal,
159 double *tempImag,
160 deviceStream_t &stream)
161 {}
162
 // No-op: real float data with scratch buffers (empty body).
163 inline void
164 deviceDirectAllReduceWrapper(const float *send,
165 float *recv,
166 dftfe::Int size,
167 float *tempReal,
168 float *tempImag,
169 deviceStream_t &stream)
170 {}
171
 // No-op: real mixed-precision group with scratch buffers (empty body).
172 inline void
173 deviceDirectAllReduceMixedPrecGroupWrapper(const double *send1,
174 const float *send2,
175 double *recv1,
176 float *recv2,
177 dftfe::Int size1,
178 dftfe::Int size2,
179 double *tempReal1,
180 float *tempReal2,
181 double *tempImag1,
182 float *tempImag2,
183 deviceStream_t &stream)
184 {}
185
 // Communicator state shared by every DeviceCCLWrapper instance (C++17
 // inline static members, so a single definition lives in this header).
 // Having static storage duration, the pointers and stream handle start out
 // zero/null and the flags start out false until set elsewhere.
186# if defined(DFTFE_WITH_CUDA_NCCL) || defined(DFTFE_WITH_HIP_RCCL)
187 inline static ncclUniqueId *ncclIdPtr;
188 inline static ncclComm_t *ncclCommPtr;
189# endif
190 inline static bool ncclCommInit;
191 inline static dftfe::utils::deviceStream_t d_deviceCommStream;
192 inline static bool commStreamCreated;
193
194 private:
 // Per-instance MPI data. Presumably this process's rank and the size of
 // d_mpiComm, filled in by init() -- TODO confirm. There are no in-class
 // initializers, so the values are indeterminate until set.
195 int myRank;
196 int totalRanks;
197 MPI_Comm d_mpiComm;
198 };
199 } // namespace utils
200} // namespace dftfe
201
202# endif
203#endif
Definition Cell.h:36
cudaStream_t deviceStream_t
Definition DeviceTypeConfig.cu.h:27
Definition pseudoPotentialToDftfeConverter.cc:34
std::int32_t Int
Definition TypeConfig.h:11