1  
//
1  
//
2  
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
2  
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
3  
//
3  
//
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6  
//
6  
//
7  
// Official repository: https://github.com/cppalliance/capy
7  
// Official repository: https://github.com/cppalliance/capy
8  
//
8  
//
9  

9  

10  
#include "src/ex/detail/strand_queue.hpp"
10  
#include "src/ex/detail/strand_queue.hpp"
11  
#include <boost/capy/ex/detail/strand_service.hpp>
11  
#include <boost/capy/ex/detail/strand_service.hpp>
12  
#include <atomic>
12  
#include <atomic>
13  
#include <coroutine>
13  
#include <coroutine>
14  
#include <mutex>
14  
#include <mutex>
15  
#include <thread>
15  
#include <thread>
16  
#include <utility>
16  
#include <utility>
17  

17  

18  
namespace boost {
18  
namespace boost {
19  
namespace capy {
19  
namespace capy {
20  
namespace detail {
20  
namespace detail {
21  

21  

22  
//----------------------------------------------------------
22  
//----------------------------------------------------------
23  

23  

24  
/** Implementation state for a strand.
24  
/** Implementation state for a strand.
25  

25  

26  
    Each strand_impl provides serialization for coroutines
26  
    Each strand_impl provides serialization for coroutines
27  
    dispatched through strands that share it.
27  
    dispatched through strands that share it.
28 -
// Sentinel stored in cached_frame_ after shutdown to prevent
 
29 -
// in-flight invokers from repopulating a freed cache slot.
 
30 -
inline void* const kCacheClosed = reinterpret_cast<void*>(1);
 
31 -

 
32  
*/
28  
*/
33  
struct strand_impl
29  
struct strand_impl
34  
{
30  
{
35  
    std::mutex mutex_;
31  
    std::mutex mutex_;
36  
    strand_queue pending_;
32  
    strand_queue pending_;
37  
    bool locked_ = false;
33  
    bool locked_ = false;
38  
    std::atomic<std::thread::id> dispatch_thread_{};
34  
    std::atomic<std::thread::id> dispatch_thread_{};
39 -
    std::atomic<void*> cached_frame_{nullptr};
35 +
    void* cached_frame_ = nullptr;
40  
};
36  
};
41  

37  

42  
//----------------------------------------------------------
38  
//----------------------------------------------------------
43  

39  

44  
/** Invoker coroutine for strand dispatch.
40  
/** Invoker coroutine for strand dispatch.
45  

41  

46  
    Uses custom allocator to recycle frame - one allocation
42  
    Uses custom allocator to recycle frame - one allocation
47  
    per strand_impl lifetime, stored in trailer for recovery.
43  
    per strand_impl lifetime, stored in trailer for recovery.
48  
*/
44  
*/
49  
struct strand_invoker
45  
struct strand_invoker
50  
{
46  
{
51  
    struct promise_type
47  
    struct promise_type
52  
    {
48  
    {
53  
        void* operator new(std::size_t n, strand_impl& impl)
49  
        void* operator new(std::size_t n, strand_impl& impl)
54  
        {
50  
        {
55  
            constexpr auto A = alignof(strand_impl*);
51  
            constexpr auto A = alignof(strand_impl*);
56  
            std::size_t padded = (n + A - 1) & ~(A - 1);
52  
            std::size_t padded = (n + A - 1) & ~(A - 1);
57  
            std::size_t total = padded + sizeof(strand_impl*);
53  
            std::size_t total = padded + sizeof(strand_impl*);
58  

54  

59 -
            void* p = impl.cached_frame_.exchange(
55 +
            void* p = impl.cached_frame_
60 -
                nullptr, std::memory_order_acquire);
56 +
                ? std::exchange(impl.cached_frame_, nullptr)
61 -
            if(!p || p == kCacheClosed)
57 +
                : ::operator new(total);
62 -
                p = ::operator new(total);
 
63  

58  

64  
            // Trailer lets delete recover impl
59  
            // Trailer lets delete recover impl
65  
            *reinterpret_cast<strand_impl**>(
60  
            *reinterpret_cast<strand_impl**>(
66  
                static_cast<char*>(p) + padded) = &impl;
61  
                static_cast<char*>(p) + padded) = &impl;
67  
            return p;
62  
            return p;
68  
        }
63  
        }
69  

64  

70  
        void operator delete(void* p, std::size_t n) noexcept
65  
        void operator delete(void* p, std::size_t n) noexcept
71  
        {
66  
        {
72  
            constexpr auto A = alignof(strand_impl*);
67  
            constexpr auto A = alignof(strand_impl*);
73  
            std::size_t padded = (n + A - 1) & ~(A - 1);
68  
            std::size_t padded = (n + A - 1) & ~(A - 1);
74  

69  

75  
            auto* impl = *reinterpret_cast<strand_impl**>(
70  
            auto* impl = *reinterpret_cast<strand_impl**>(
76  
                static_cast<char*>(p) + padded);
71  
                static_cast<char*>(p) + padded);
77  

72  

78 -
            void* expected = nullptr;
73 +
            if (!impl->cached_frame_)
79 -
            if(!impl->cached_frame_.compare_exchange_strong(
74 +
                impl->cached_frame_ = p;
80 -
                expected, p, std::memory_order_release))
75 +
            else
81  
                ::operator delete(p);
76  
                ::operator delete(p);
82  
        }
77  
        }
83  

78  

84  
        strand_invoker get_return_object() noexcept
79  
        strand_invoker get_return_object() noexcept
85  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
80  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
86  

81  

87  
        std::suspend_always initial_suspend() noexcept { return {}; }
82  
        std::suspend_always initial_suspend() noexcept { return {}; }
88  
        std::suspend_never final_suspend() noexcept { return {}; }
83  
        std::suspend_never final_suspend() noexcept { return {}; }
89  
        void return_void() noexcept {}
84  
        void return_void() noexcept {}
90  
        void unhandled_exception() { std::terminate(); }
85  
        void unhandled_exception() { std::terminate(); }
91  
    };
86  
    };
92  

87  

93  
    std::coroutine_handle<promise_type> h_;
88  
    std::coroutine_handle<promise_type> h_;
94  
};
89  
};
95  

90  

96  
//----------------------------------------------------------
91  
//----------------------------------------------------------
97  

92  

98  
/** Concrete implementation of strand_service.
93  
/** Concrete implementation of strand_service.
99  

94  

100  
    Holds the fixed pool of strand_impl objects.
95  
    Holds the fixed pool of strand_impl objects.
101  
*/
96  
*/
102  
class strand_service_impl : public strand_service
97  
class strand_service_impl : public strand_service
103  
{
98  
{
104  
    static constexpr std::size_t num_impls = 211;
99  
    static constexpr std::size_t num_impls = 211;
105  

100  

106  
    strand_impl impls_[num_impls];
101  
    strand_impl impls_[num_impls];
107  
    std::size_t salt_ = 0;
102  
    std::size_t salt_ = 0;
108  
    std::mutex mutex_;
103  
    std::mutex mutex_;
109  

104  

110  
public:
105  
public:
111  
    explicit
106  
    explicit
112  
    strand_service_impl(execution_context&)
107  
    strand_service_impl(execution_context&)
113  
    {
108  
    {
114  
    }
109  
    }
115  

110  

116  
    strand_impl*
111  
    strand_impl*
117  
    get_implementation() override
112  
    get_implementation() override
118  
    {
113  
    {
119  
        std::lock_guard<std::mutex> lock(mutex_);
114  
        std::lock_guard<std::mutex> lock(mutex_);
120  
        std::size_t index = salt_++;
115  
        std::size_t index = salt_++;
121  
        index = index % num_impls;
116  
        index = index % num_impls;
122  
        return &impls_[index];
117  
        return &impls_[index];
123  
    }
118  
    }
124  

119  

125  
protected:
120  
protected:
126  
    void
121  
    void
127  
    shutdown() override
122  
    shutdown() override
128  
    {
123  
    {
129  
        for(std::size_t i = 0; i < num_impls; ++i)
124  
        for(std::size_t i = 0; i < num_impls; ++i)
130  
        {
125  
        {
131  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
126  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
132  
            impls_[i].locked_ = true;
127  
            impls_[i].locked_ = true;
133  

128  

134 -
            void* p = impls_[i].cached_frame_.exchange(
129 +
            if(impls_[i].cached_frame_)
135 -
                kCacheClosed, std::memory_order_acquire);
130 +
            {
136 -
            if(p)
131 +
                ::operator delete(impls_[i].cached_frame_);
137 -
                ::operator delete(p);
132 +
                impls_[i].cached_frame_ = nullptr;
 
133 +
            }
138  
        }
134  
        }
139  
    }
135  
    }
140  

136  

141  
private:
137  
private:
142  
    static bool
138  
    static bool
143  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
139  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
144  
    {
140  
    {
145  
        std::lock_guard<std::mutex> lock(impl.mutex_);
141  
        std::lock_guard<std::mutex> lock(impl.mutex_);
146  
        impl.pending_.push(h);
142  
        impl.pending_.push(h);
147  
        if(!impl.locked_)
143  
        if(!impl.locked_)
148  
        {
144  
        {
149  
            impl.locked_ = true;
145  
            impl.locked_ = true;
150  
            return true;
146  
            return true;
151  
        }
147  
        }
152  
        return false;
148  
        return false;
153  
    }
149  
    }
154  

150  

155  
    static void
151  
    static void
156  
    dispatch_pending(strand_impl& impl)
152  
    dispatch_pending(strand_impl& impl)
157  
    {
153  
    {
158  
        strand_queue::taken_batch batch;
154  
        strand_queue::taken_batch batch;
159  
        {
155  
        {
160  
            std::lock_guard<std::mutex> lock(impl.mutex_);
156  
            std::lock_guard<std::mutex> lock(impl.mutex_);
161  
            batch = impl.pending_.take_all();
157  
            batch = impl.pending_.take_all();
162  
        }
158  
        }
163  
        impl.pending_.dispatch_batch(batch);
159  
        impl.pending_.dispatch_batch(batch);
164  
    }
160  
    }
165  

161  

166  
    static bool
162  
    static bool
167  
    try_unlock(strand_impl& impl)
163  
    try_unlock(strand_impl& impl)
168  
    {
164  
    {
169  
        std::lock_guard<std::mutex> lock(impl.mutex_);
165  
        std::lock_guard<std::mutex> lock(impl.mutex_);
170  
        if(impl.pending_.empty())
166  
        if(impl.pending_.empty())
171  
        {
167  
        {
172  
            impl.locked_ = false;
168  
            impl.locked_ = false;
173  
            return true;
169  
            return true;
174  
        }
170  
        }
175  
        return false;
171  
        return false;
176  
    }
172  
    }
177  

173  

178  
    static void
174  
    static void
179  
    set_dispatch_thread(strand_impl& impl) noexcept
175  
    set_dispatch_thread(strand_impl& impl) noexcept
180  
    {
176  
    {
181  
        impl.dispatch_thread_.store(std::this_thread::get_id());
177  
        impl.dispatch_thread_.store(std::this_thread::get_id());
182  
    }
178  
    }
183  

179  

184  
    static void
180  
    static void
185  
    clear_dispatch_thread(strand_impl& impl) noexcept
181  
    clear_dispatch_thread(strand_impl& impl) noexcept
186  
    {
182  
    {
187  
        impl.dispatch_thread_.store(std::thread::id{});
183  
        impl.dispatch_thread_.store(std::thread::id{});
188  
    }
184  
    }
189  

185  

190  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
186  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
191  
    // (repost after each batch to let other work run) - explore if starvation observed.
187  
    // (repost after each batch to let other work run) - explore if starvation observed.
192  
    static strand_invoker
188  
    static strand_invoker
193  
    make_invoker(strand_impl& impl)
189  
    make_invoker(strand_impl& impl)
194  
    {
190  
    {
195  
        strand_impl* p = &impl;
191  
        strand_impl* p = &impl;
196  
        for(;;)
192  
        for(;;)
197  
        {
193  
        {
198  
            set_dispatch_thread(*p);
194  
            set_dispatch_thread(*p);
199  
            dispatch_pending(*p);
195  
            dispatch_pending(*p);
200  
            if(try_unlock(*p))
196  
            if(try_unlock(*p))
201  
            {
197  
            {
202  
                clear_dispatch_thread(*p);
198  
                clear_dispatch_thread(*p);
203  
                co_return;
199  
                co_return;
204  
            }
200  
            }
205  
        }
201  
        }
206  
    }
202  
    }
207  

203  

208  
    friend class strand_service;
204  
    friend class strand_service;
209  
};
205  
};
210  

206  

211  
//----------------------------------------------------------
207  
//----------------------------------------------------------
212  

208  

213  
/** Construct the abstract strand service.

    Nothing is initialized here; the pool of strand_impl
    objects belongs to the concrete strand_service_impl.
*/
strand_service::strand_service()
    : service()
{
}

/** Destroy the strand service. */
strand_service::~strand_service() = default;
221  

217  

222  
bool
218  
bool
223  
strand_service::
219  
strand_service::
224  
running_in_this_thread(strand_impl& impl) noexcept
220  
running_in_this_thread(strand_impl& impl) noexcept
225  
{
221  
{
226  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
222  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
227  
}
223  
}
228  

224  

229  
/** Run `h` inside the strand of `impl`.

    If the calling thread is already draining this strand, `h` is
    returned unchanged so the caller can resume it directly.
    Otherwise `h` is queued; if that acquired the strand, an invoker
    is posted to `ex` to drain the queue. In the queued case the
    caller receives `std::noop_coroutine()`.
*/
std::coroutine_handle<>
strand_service::dispatch(
    strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
{
    if(!running_in_this_thread(impl))
    {
        bool const acquired = strand_service_impl::enqueue(impl, h);
        if(acquired)
            ex.post(strand_service_impl::make_invoker(impl).h_);
        return std::noop_coroutine();
    }
    return h;
}
240  

236  

241  
void
237  
void
242  
strand_service::
238  
strand_service::
243  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
239  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
244  
{
240  
{
245  
    if(strand_service_impl::enqueue(impl, h))
241  
    if(strand_service_impl::enqueue(impl, h))
246  
        ex.post(strand_service_impl::make_invoker(impl).h_);
242  
        ex.post(strand_service_impl::make_invoker(impl).h_);
247  
}
243  
}
248  

244  

249  
/** Return the strand service of `ctx`.

    The concrete strand_service_impl is obtained through
    `ctx.use_service` and returned as its strand_service base.
*/
strand_service&
get_strand_service(execution_context& ctx)
{
    strand_service& svc = ctx.use_service<strand_service_impl>();
    return svc;
}
254  

250  

255  
} // namespace detail
251  
} // namespace detail
256  
} // namespace capy
252  
} // namespace capy
257  
} // namespace boost
253  
} // namespace boost