18#if defined(DFTFE_WITH_DEVICE)
19# ifndef deviceDirectCCLWrapper_h
20# define deviceDirectCCLWrapper_h
26# if defined(DFTFE_WITH_CUDA_NCCL)
29# elif defined(DFTFE_WITH_HIP_RCCL)
38# if defined(DFTFE_WITH_CUDA_NCCL) || defined(DFTFE_WITH_HIP_RCCL)
39# define NCCLCHECK(cmd) \
42 ncclResult_t r = cmd; \
43 if (r != ncclSuccess) \
45 printf("Failed, NCCL error %s:%d '%s'\n", \
48 ncclGetErrorString(r)); \
61 class DeviceCCLWrapper
67 init(
const MPI_Comm &mpiComm,
const bool useDCCL);
72 deviceDirectAllReduceWrapper(
const float * send,
75 deviceStream_t &stream);
79 deviceDirectAllReduceWrapper(
const double * send,
82 deviceStream_t &stream);
86 deviceDirectAllReduceWrapper(
const std::complex<double> *send,
87 std::complex<double> * recv,
91 deviceStream_t & stream);
94 deviceDirectAllReduceWrapper(
const std::complex<float> *send,
95 std::complex<float> * recv,
99 deviceStream_t & stream);
103 deviceDirectAllReduceMixedPrecGroupWrapper(
const double * send1,
109 deviceStream_t &stream);
112 deviceDirectAllReduceMixedPrecGroupWrapper(
113 const std::complex<double> *send1,
114 const std::complex<float> * send2,
115 std::complex<double> * recv1,
116 std::complex<float> * recv2,
123 deviceStream_t & stream);
128 deviceDirectAllReduceWrapper(
const std::complex<float> *send,
129 std::complex<float> * recv,
131 deviceStream_t & stream)
136 deviceDirectAllReduceWrapper(
const std::complex<double> *send,
137 std::complex<double> * recv,
139 deviceStream_t & stream)
143 deviceDirectAllReduceMixedPrecGroupWrapper(
144 const std::complex<double> *send1,
145 const std::complex<float> * send2,
146 std::complex<double> * recv1,
147 std::complex<float> * recv2,
150 deviceStream_t & stream)
155 deviceDirectAllReduceWrapper(
const double * send,
160 deviceStream_t &stream)
164 deviceDirectAllReduceWrapper(
const float * send,
169 deviceStream_t &stream)
173 deviceDirectAllReduceMixedPrecGroupWrapper(
const double * send1,
183 deviceStream_t &stream)
186# if defined(DFTFE_WITH_CUDA_NCCL) || defined(DFTFE_WITH_HIP_RCCL)
187 inline static ncclUniqueId *ncclIdPtr;
188 inline static ncclComm_t * ncclCommPtr;
190 inline static bool ncclCommInit;
192 inline static bool commStreamCreated;
cudaStream_t deviceStream_t
Definition DeviceTypeConfig.cu.h:27
Definition pseudoPotentialToDftfeConverter.cc:34