refactor: move scheduler to async

This commit is contained in:
Folke Lemaitre 2024-06-26 15:11:31 +02:00
parent 0eb46e7816
commit 768de1ebf6
No known key found for this signature in database
GPG Key ID: 41F8B1FBACAE2040
2 changed files with 129 additions and 62 deletions

107
lua/lazy/async.lua Normal file
View File

@ -0,0 +1,107 @@
---@class AsyncOpts
---@field on_done? fun()
---@field on_error? fun(err:string)
---@field on_yield? fun(res:any)
local M = {}
---@type Async[]
M._queue = {}
M._executor = assert(vim.loop.new_check())
M._running = false
---@class Async
---@field co thread
---@field opts AsyncOpts
local Async = {}
---@param fn async fun()
---@param opts? AsyncOpts
---@return Async
function Async.new(fn, opts)
local self = setmetatable({}, { __index = Async })
self.co = coroutine.create(fn)
self.opts = opts or {}
return self
end
function Async:running()
return coroutine.status(self.co) ~= "dead"
end
function Async:step()
local status = coroutine.status(self.co)
if status == "suspended" then
local ok, res = coroutine.resume(self.co)
if not ok then
if self.opts.on_error then
self.opts.on_error(tostring(res))
end
elseif res then
if self.opts.on_yield then
self.opts.on_yield(res)
end
end
end
if self:running() then
return true
end
if self.opts.on_done then
self.opts.on_done()
end
end
function M.step()
M._running = true
local budget = 1 * 1e6
local start = vim.loop.hrtime()
local count = #M._queue
local i = 0
while #M._queue > 0 and vim.loop.hrtime() - start < budget do
---@type Async
local state = table.remove(M._queue, 1)
if state:step() then
table.insert(M._queue, state)
end
i = i + 1
if i >= count then
break
end
end
M._running = false
if #M._queue == 0 then
return M._executor:stop()
end
end
---@param async Async
function M.add(async)
table.insert(M._queue, async)
if not M._executor:is_active() then
M._executor:start(vim.schedule_wrap(M.step))
end
return async
end
---@param fn async fun()
---@param opts? AsyncOpts
function M.run(fn, opts)
return M.add(Async.new(fn, opts))
end
---@generic T: async fun()
---@param fn T
---@param opts? AsyncOpts
---@return T
function M.wrap(fn, opts)
return function(...)
local args = { ... }
---@async
local wrapped = function()
return fn(unpack(args))
end
return M.run(wrapped, opts)
end
end
return M

View File

@ -1,3 +1,4 @@
local Async = require("lazy.async")
local Process = require("lazy.manage.process")
---@class LazyTaskDef
@ -6,44 +7,6 @@ local Process = require("lazy.manage.process")
---@alias LazyTaskState {task:LazyTask, thread:thread}
local Scheduler = {}
---@type LazyTaskState[]
Scheduler._queue = {}
Scheduler._executor = assert(vim.loop.new_check())
Scheduler._running = false
function Scheduler.step()
Scheduler._running = true
local budget = 1 * 1e6
local start = vim.loop.hrtime()
local count = #Scheduler._queue
local i = 0
while #Scheduler._queue > 0 and vim.loop.hrtime() - start < budget do
---@type LazyTaskState
local state = table.remove(Scheduler._queue, 1)
state.task:_step(state.thread)
if coroutine.status(state.thread) ~= "dead" then
table.insert(Scheduler._queue, state)
end
i = i + 1
if i >= count then
break
end
end
Scheduler._running = false
if #Scheduler._queue == 0 then
return Scheduler._executor:stop()
end
end
---@param state LazyTaskState
function Scheduler.add(state)
table.insert(Scheduler._queue, state)
if not Scheduler._executor:is_active() then
Scheduler._executor:start(vim.schedule_wrap(Scheduler.step))
end
end
---@class LazyTask
---@field plugin LazyPlugin
---@field name string
@ -55,7 +18,7 @@ end
---@field private _started? number
---@field private _ended? number
---@field private _opts TaskOptions
---@field private _threads thread[]
---@field private _running Async[]
local Task = {}
---@class TaskOptions: {[string]:any}
@ -70,7 +33,7 @@ function Task.new(plugin, name, task, opts)
__index = Task,
})
self._opts = opts or {}
self._threads = {}
self._running = {}
self._task = task
self._started = nil
self.plugin = plugin
@ -137,34 +100,31 @@ end
---@param fn async fun()
function Task:async(fn)
local co = coroutine.create(fn)
table.insert(self._threads, co)
Scheduler.add({ task = self, thread = co })
end
---@param co thread
function Task:_step(co)
local status = coroutine.status(co)
if status == "suspended" then
local ok, res = coroutine.resume(co)
if not ok then
self:notify_error(tostring(res))
elseif res then
self:notify(tostring(res))
end
end
for _, t in ipairs(self._threads) do
if coroutine.status(t) ~= "dead" then
return
end
end
self:_done()
local async = Async.run(fn, {
on_done = function()
self:_done()
end,
on_error = function(err)
self:notify_error(err)
end,
on_yield = function(res)
self:notify(res)
end,
})
table.insert(self._running, async)
end
---@private
function Task:_done()
assert(self:has_started(), "task not started")
assert(not self:has_ended(), "task already done")
for _, t in ipairs(self._running) do
if t:running() then
return
end
end
self._ended = vim.uv.hrtime()
if self._opts.on_done then
self._opts.on_done(self)