Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6403388

Browse files
authored Nov 27, 2024
[hotfix] fix task count (#56)
* [hotfix] fix task count
* [hotfix] fix h2d count
1 parent a4d34bf commit 6403388

File tree

14 files changed

+34
-32
lines changed

14 files changed

+34
-32
lines changed
 

‎csrc/aio.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
#include "aio.h"
22

3-
AIOAsyncIO::AIOAsyncIO(unsigned int n_entries)
3+
AIOAsyncIO::AIOAsyncIO(unsigned int n_entries, unsigned int n_tasks)
44
{
55
// printf("Initializing the io Context\n");
66
this->max_nr = n_entries;

‎csrc/async_file_io.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
#include "async_file_io.h"
22

3-
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {}
3+
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend, unsigned int n_tasks) : fd(fd), aio(create_asyncio(n_entries, backend, n_tasks)) {}
44

55
void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
66
{

‎csrc/backend.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -44,23 +44,23 @@ void probe_asyncio(const std::string &backend)
4444
if (backend == "uring")
4545
{
4646
#ifndef DISABLE_URING
47-
aio.reset(new UringAsyncIO(2));
47+
aio.reset(new UringAsyncIO(2, 0));
4848
#else
4949
throw std::runtime_error("backend uring is not installed\n");
5050
#endif
5151
}
5252
else if (backend == "aio")
5353
{
5454
#ifndef DISABLE_AIO
55-
aio.reset(new AIOAsyncIO(2));
55+
aio.reset(new AIOAsyncIO(2, 0));
5656
#else
5757
throw std::runtime_error("backend aio is not installed\n");
5858
#endif
5959
}
6060
else if (backend == "pthread")
6161
{
6262
#ifndef DISABLE_PTHREAD
63-
aio.reset(new PthreadAsyncIO(2));
63+
aio.reset(new PthreadAsyncIO(2, 0));
6464
#else
6565
throw std::runtime_error("backend pthread is not installed\n");
6666
#endif
@@ -160,7 +160,7 @@ std::string get_debug_log()
160160
return std::string(env_);
161161
}
162162

163-
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
163+
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend, unsigned int n_tasks)
164164
{
165165
std::unordered_set<std::string> backends = get_backends();
166166
std::string default_backend = get_default_backend();
@@ -188,15 +188,15 @@ AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
188188

189189
#ifndef DISABLE_URING
190190
if (backend == "uring")
191-
return new UringAsyncIO(n_entries);
191+
return new UringAsyncIO(n_entries, n_tasks);
192192
#endif
193193
#ifndef DISABLE_AIO
194194
if (backend == "aio")
195-
return new AIOAsyncIO(n_entries);
195+
return new AIOAsyncIO(n_entries, n_tasks);
196196
#endif
197197
#ifndef DISABLE_PTHREAD
198198
if (backend == "pthread")
199-
return new PthreadAsyncIO(n_entries);
199+
return new PthreadAsyncIO(n_entries, n_tasks);
200200
#endif
201201
throw std::runtime_error("Unsupported backend: " + backend);
202202
}

‎csrc/offload.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ iovec *tensors_to_iovec(const std::vector<at::Tensor> &tensors)
2828

2929
Offloader::Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend) : filename(filename), space_mgr(SpaceManager(0))
3030
{
31-
this->aio = create_asyncio(n_entries, backend);
31+
this->aio = create_asyncio(n_entries, backend, 0);
3232
this->fd = open(filename.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
3333
this->aio->register_file(fd);
3434
}

‎csrc/pthread_backend.cpp

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@ void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long l
1616
auto val = pwrite(fd, buffer, n_bytes, offset);
1717
if (this->is_debug)
1818
{
19-
auto cur_tasks = this->tasks_in_progress.fetch_sub(1);
20-
if (cur_tasks == 1)
19+
auto cur_tasks = this->tasks_in_progress.fetch_add(1);
20+
if (cur_tasks + 1 == this->total_tasks)
2121
{
2222
if (this->debug_log.empty())
2323
{
@@ -117,23 +117,23 @@ void PthreadAsyncIO::register_file(int fd) {}
117117

118118
void PthreadAsyncIO::register_h2d(unsigned int num_tensors)
119119
{
120-
this->h2d_in_progress.store(num_tensors); // register tensors to write for this run
120+
this->total_h2d = num_tensors;
121121
}
122122

123123
void PthreadAsyncIO::sync_h2d()
124124
{
125125
std::unique_lock<std::mutex> lock(this->mtx);
126126
this->cv.wait(lock, [this]
127-
{ return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed
127+
{ return this->h2d_in_progress == this->total_h2d; }); // block until all in-progress h2d are completed
128128
}
129129

130130
void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned)
131131
{
132132
auto stream = c10::cuda::getCurrentCUDAStream();
133133
if (!t.is_cuda())
134134
{
135-
this->h2d_in_progress.fetch_sub(1); // already moved to cpu
136-
if (this->h2d_in_progress.load() == 0)
135+
auto cur_h2d = this->h2d_in_progress.fetch_add(1); // already moved to cpu
136+
if (cur_h2d + 1 == this->total_h2d)
137137
{ // notify when all h2d are completed and safe to optimizer.step()
138138
std::lock_guard<std::mutex> lock(this->mtx);
139139
cv.notify_one();
@@ -155,8 +155,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of
155155
{
156156
cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
157157
}
158-
this->h2d_in_progress.fetch_sub(1);
159-
if (this->h2d_in_progress.load() == 0)
158+
auto cur_h2d = this->h2d_in_progress.fetch_add(1);
159+
if (cur_h2d + 1 == this->total_h2d)
160160
{ // notify when all h2d are completed and safe to optimizer.step()
161161
std::lock_guard<std::mutex> lock(this->mtx);
162162
cv.notify_one();
@@ -171,8 +171,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of
171171
auto val = pwrite(fd, buf, n_bytes, offset);
172172
if (this->is_debug)
173173
{
174-
auto cur_tasks = this->tasks_in_progress.fetch_sub(1);
175-
if (cur_tasks == 1)
174+
auto cur_tasks = this->tasks_in_progress.fetch_add(1);
175+
if (cur_tasks + 1 == this->total_tasks)
176176
{
177177
if (this->debug_log.empty())
178178
{

‎csrc/py_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
2727
m.def("get_backends", get_backends);
2828
m.def("probe_backend", probe_backend, py::arg("backend"));
2929
py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
30-
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
30+
.def(py::init<int, unsigned int, const std::string &, unsigned int>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio", py::arg("n_tasks") = 0)
3131
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
3232
.def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none())
3333
.def("synchronize", &AsyncFileWriter::synchronize)

‎csrc/uring.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
#include <memory>
33
#include "uring.h"
44

5-
UringAsyncIO::UringAsyncIO(unsigned int n_entries) : n_write_events(0), n_read_events(0), n_entries(n_entries)
5+
UringAsyncIO::UringAsyncIO(unsigned int n_entries, unsigned int n_tasks) : n_write_events(0), n_read_events(0), n_entries(n_entries)
66
{
77
io_uring_queue_init(n_entries, &this->ring, 0);
88
}

‎include/aio.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ class AIOAsyncIO : public AsyncIO
1919
void get_event(WaitType wt);
2020

2121
public:
22-
AIOAsyncIO(unsigned int n_entries);
22+
AIOAsyncIO(unsigned int n_entries, unsigned int n_tasks);
2323
~AIOAsyncIO();
2424

2525
void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);

‎include/async_file_io.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
class AsyncFileWriter
1818
{
1919
public:
20-
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
20+
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend, unsigned int n_tasks);
2121
void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
2222
void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
2323
void synchronize();

‎include/backend.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,6 @@ std::string get_default_backend();
1414

1515
bool get_debug_flag();
1616

17-
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend);
17+
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend, unsigned int n_tasks);
1818

1919
std::string get_debug_log();

‎include/pthread_backend.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ class PthreadAsyncIO : public AsyncIO
2626
private:
2727
BS::thread_pool pool;
2828
std::atomic<unsigned int> h2d_in_progress;
29+
unsigned int total_h2d;
2930
std::condition_variable cv;
3031
std::mutex mtx;
3132
std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
@@ -34,10 +35,11 @@ class PthreadAsyncIO : public AsyncIO
3435
const std::string debug_log = get_debug_log();
3536

3637
std::atomic<unsigned int> tasks_in_progress;
38+
unsigned int total_tasks;
3739

3840
public:
39-
PthreadAsyncIO(unsigned int n_entries)
40-
: pool(n_entries), h2d_in_progress(0) {}
41+
PthreadAsyncIO(unsigned int n_entries, unsigned int n_tasks)
42+
: pool(n_entries), h2d_in_progress(0), tasks_in_progress(0), total_tasks(n_tasks), total_h2d(0) {}
4143

4244
~PthreadAsyncIO() {}
4345

‎include/uring.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ class UringAsyncIO : public AsyncIO
1313
void get_event(WaitType wt);
1414

1515
public:
16-
UringAsyncIO(unsigned int n_entries);
16+
UringAsyncIO(unsigned int n_entries, unsigned int n_tasks);
1717
~UringAsyncIO();
1818

1919
void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);

‎tensornvme/_C/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ def get_backends() -> Set[str]: ...
2020
def probe_backend(backend: str) -> bool: ...
2121

2222
class AsyncFileWriter:
23-
def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
23+
def __init__(self, fd: int, n_entries: int, backend: str = "aio", n_tasks: int = 0) -> None: ...
2424
def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ...
2525
def write_tensor(
2626
self,

‎tensornvme/async_file_io.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,14 +10,14 @@
1010

1111

1212
class AsyncFileWriter:
13-
def __init__(self, path: str, n_entries: int = 16, backend=None) -> None:
13+
def __init__(self, path: str, n_entries: int = 16, backend=None, n_tasks: int = 0) -> None:
1414
# this still takes ram buffer, which may lead to OOM
1515
# self.f = open(path, "wb", buffering=0)
1616
self.fd = os.open(path, os.O_WRONLY | os.O_CREAT, mode=0o664)
1717
if backend is not None:
18-
self.io = AsyncFileWriterC(self.fd, n_entries, backend=backend)
18+
self.io = AsyncFileWriterC(self.fd, n_entries, backend=backend, n_tasks=n_tasks)
1919
else:
20-
self.io = AsyncFileWriterC(self.fd, n_entries)
20+
self.io = AsyncFileWriterC(self.fd, n_entries, n_tasks=n_tasks)
2121
self.offset = 0
2222
# must ensure the data is not garbage collected
2323
self.buffers = []

0 commit comments

Comments
 (0)
Failed to load comments.