DFT-FE 1.1.0-pre
Density Functional Theory With Finite-Elements
Loading...
Searching...
No Matches
deviceDirectCCLWrapper.h
Go to the documentation of this file.
1// ---------------------------------------------------------------------
2//
3// Copyright (c) 2017-2025 The Regents of the University of Michigan and DFT-FE
4// authors.
5//
6// This file is part of the DFT-FE code.
7//
8// The DFT-FE code is free software; you can use it, redistribute
9// it, and/or modify it under the terms of the GNU Lesser General
10// Public License as published by the Free Software Foundation; either
11// version 2.1 of the License, or (at your option) any later version.
12// The full text of the license can be found in the file LICENSE at
13// the top level of the DFT-FE distribution.
14//
15// ---------------------------------------------------------------------
16// @author Sambit Das, David M. Rogers
17
18#if defined(DFTFE_WITH_DEVICE)
19# ifndef deviceDirectCCLWrapper_h
20# define deviceDirectCCLWrapper_h
21
#  include <complex>
#  include <cstdio>
#  include <cstdlib>
#  include <mpi.h>
#  include <TypeConfig.h>
#  include <DeviceTypeConfig.h>
26
27# if defined(DFTFE_WITH_CUDA_NCCL)
28# include <nccl.h>
29# include <DeviceTypeConfig.h>
30# elif defined(DFTFE_WITH_HIP_RCCL)
31# include <rccl.h>
32# include <DeviceTypeConfig.h>
33# endif
34
35namespace dftfe
36{
37 namespace utils
38 {
#  if defined(DFTFE_WITH_CUDA_NCCL) || defined(DFTFE_WITH_HIP_RCCL)
/**
 * @brief Evaluate an NCCL/RCCL call and terminate the process if it fails.
 *
 * @param cmd Expression yielding an ncclResult_t (typically an nccl* call).
 *
 * If the result is not ncclSuccess, the file/line of the macro use and the
 * ncclGetErrorString() text are written to stderr and the process exits with
 * EXIT_FAILURE.
 *
 * The result is held in a deliberately unusual local name: a C++ variable is
 * in scope inside its own initializer, so a generic name such as `r` would
 * capture (and read uninitialized) any caller variable of the same name that
 * appears in `cmd`. `cmd` is parenthesized for the usual macro hygiene.
 *
 * NOTE(review): exit() only terminates this process; other MPI ranks blocked
 * in a collective may hang. Consider MPI_Abort if that is a concern.
 */
#    define NCCLCHECK(cmd)                                                \
      do                                                                  \
        {                                                                 \
          const ncclResult_t dftfeNcclCheckResult_ = (cmd);               \
          if (dftfeNcclCheckResult_ != ncclSuccess)                       \
            {                                                             \
              std::fprintf(stderr,                                        \
                           "Failed, NCCL error %s:%d '%s'\n",             \
                           __FILE__,                                      \
                           __LINE__,                                      \
                           ncclGetErrorString(dftfeNcclCheckResult_));    \
              std::exit(EXIT_FAILURE);                                    \
            }                                                             \
        }                                                                 \
      while (0)
#  endif
54 /**
55 * @brief Wrapper class for Device Direct collective communications library.
56 * Adapted from
57 * https://code.ornl.gov/99R/olcf-cookbook/-/blob/develop/comms/nccl_allreduce.rst
58 *
 * The NCCL/RCCL handles (ncclIdPtr, ncclCommPtr) are declared only when
 * DFTFE_WITH_CUDA_NCCL or DFTFE_WITH_HIP_RCCL is defined; the other static
 * members are always declared. Every static member is `inline static`, so it
 * is shared by all DeviceCCLWrapper instances and, having static storage
 * duration, starts zero-initialized (false / null / 0).
 *
 * NOTE(review): in this listing the return-type line that precedes each
 * deviceDirectAllReduce* declaration (listing lines 71, 78, 85, 91, 98 and
 * 107) is missing, so those declarations are incomplete as shown. Restore the
 * return types from the upstream header; do not guess them.
 *
59 * @author Sambit Das, David M. Rogers
60 */
61 class DeviceCCLWrapper
62 {
63 public:
64 DeviceCCLWrapper();
65
// Sets the wrapper up for collectives on mpiComm (presumably stored in
// d_mpiComm). useDCCL presumably selects whether the device-direct
// NCCL/RCCL path is initialized -- TODO confirm in the implementation.
66 void
67 init(const MPI_Comm &mpiComm, const bool useDCCL);
68
// NOTE(review): presumably releases the shared communicator/stream once the
// last instance goes away (see d_deviceDirectDCCLInstanceCounter) -- confirm.
69 ~DeviceCCLWrapper();
70
// All-reduce of `size` elements of `send` into `recv`. Assumed (confirm in
// the implementation): sum over the ranks of d_mpiComm, device-resident
// buffers, work enqueued on `stream`.
// Overloads: float, double, std::complex<double>, std::complex<float>.
72 deviceDirectAllReduceWrapper(const float *send,
73 float *recv,
74 dftfe::Int size,
75 deviceStream_t &stream);
76
77
79 deviceDirectAllReduceWrapper(const double *send,
80 double *recv,
81 dftfe::Int size,
82 deviceStream_t &stream);
83
84
86 deviceDirectAllReduceWrapper(const std::complex<double> *send,
87 std::complex<double> *recv,
88 dftfe::Int size,
89 deviceStream_t &stream);
90
92 deviceDirectAllReduceWrapper(const std::complex<float> *send,
93 std::complex<float> *recv,
94 dftfe::Int size,
95 deviceStream_t &stream);
96
97
// Reduces a higher-precision pair (send1 -> recv1, size1 elements) and a
// lower-precision pair (send2 -> recv2, size2 elements) in one call.
// Presumably issued as a single NCCL/RCCL group -- TODO confirm.
// Overloads: double/float and std::complex<double>/std::complex<float>.
99 deviceDirectAllReduceMixedPrecGroupWrapper(const double *send1,
100 const float *send2,
101 double *recv1,
102 float *recv2,
103 dftfe::Int size1,
104 dftfe::Int size2,
105 deviceStream_t &stream);
106
108 deviceDirectAllReduceMixedPrecGroupWrapper(
109 const std::complex<double> *send1,
110 const std::complex<float> *send2,
111 std::complex<double> *recv1,
112 std::complex<float> *recv2,
113 dftfe::Int size1,
114 dftfe::Int size2,
115 deviceStream_t &stream);
116
// Shared NCCL/RCCL unique id and communicator (pointers); only declared in
// NCCL/RCCL-enabled builds.
117# if defined(DFTFE_WITH_CUDA_NCCL) || defined(DFTFE_WITH_HIP_RCCL)
118 inline static ncclUniqueId *ncclIdPtr;
119 inline static ncclComm_t *ncclCommPtr;
120# endif
// Presumably set once the shared communicator has been initialized.
121 inline static bool ncclCommInit;
// Shared stream, presumably dedicated to communication; commStreamCreated
// presumably records whether it has been created.
122 inline static dftfe::utils::deviceStream_t d_deviceCommStream;
123 inline static bool commStreamCreated;
// Presumably counts live instances so shared resources are freed by the
// last one -- TODO confirm.
124 inline static dftfe::Int d_deviceDirectDCCLInstanceCounter;
125
126 private:
// This process's rank and the number of ranks, presumably in d_mpiComm.
127 int myRank;
128 int totalRanks;
// MPI communicator this wrapper was initialized with (presumably via init()).
129 MPI_Comm d_mpiComm;
130 };
130 };
131 } // namespace utils
132} // namespace dftfe
133
134# endif
135#endif
Definition Cell.h:36
deviceStream_t: alias of cudaStream_t (defined in DeviceTypeConfig.cu.h:27)
Definition pseudoPotentialToDftfeConverter.cc:34
Int: alias of std::int32_t (defined in TypeConfig.h:11)