From c61254d7023e3f6aafbc4b70e7a82be2345b3a4e Mon Sep 17 00:00:00 2001 From: Guillaume ARM Date: Tue, 9 Jun 2026 18:22:24 +0200 Subject: [PATCH] fix(ai): poll async assistant messages --- apis/libai.lua | 34 +++++++++++++++++------ docs/opencode_api.md | 6 ++--- manifest.json | 2 +- packages/index.json | 4 +-- packages/trapos-ai/ccpm.json | 2 +- packages/trapos/ccpm.json | 2 +- tests/ai.lua | 52 +++++++++++++++++++++++++++++++----- 7 files changed, 79 insertions(+), 23 deletions(-) diff --git a/apis/libai.lua b/apis/libai.lua index 6cf3d66..2ecb08e 100644 --- a/apis/libai.lua +++ b/apis/libai.lua @@ -217,7 +217,7 @@ local function createAi(opts) local function createMessageId() local t = math.floor(nowFunc() * 1000); - return 'cc_' .. tostring(t) .. '_' .. tostring(math.random(100000, 999999)); + return 'msg_' .. tostring(t) .. '_' .. tostring(math.random(100000, 999999)); end local function isMessageComplete(message) @@ -230,14 +230,31 @@ local function createAi(opts) return type(message.info.time) == 'table' and message.info.time.completed ~= nil; end - local function decodeMessage(body) - local decoded = textutils.unserializeJSON(body); + local function decodeMessage(value) + local decoded = value; + if type(value) == 'string' then + decoded = textutils.unserializeJSON(value); + end if type(decoded) ~= 'table' or type(decoded.parts) ~= 'table' then return nil, 'reponse message invalide'; end return decoded, nil; end + local function findAssistantMessage(messages, submittedMessageId) + local seenSubmitted = false; + for _, message in ipairs(messages) do + if type(message) == 'table' and type(message.info) == 'table' then + if message.info.id == submittedMessageId then + seenSubmitted = true; + elseif seenSubmitted and message.info.role == 'assistant' then + return message; + end + end + end + return nil; + end + local function handleMissingSession(persist) if persist then settingsLib.unset('opencc.session_id'); @@ -252,17 +269,18 @@ local function createAi(opts) local function pollMessage(cfg, sessionId, messageId, persist) local deadline = nowFunc() + cfg.pollTimeoutSeconds; while true do - local body, code = doGet(cfg, '/session/' .. sessionId .. '/message/' .. messageId); + local body, code = doGet(cfg, '/session/' .. sessionId .. '/message'); if not body then return false, code; end if code == 404 then return handleMissingSession(persist); end if code and code ~= 200 then return false, 'erreur message: HTTP ' .. tostring(code); end - local decoded, decodeErr = decodeMessage(body); - if not decoded then return false, decodeErr; end - local reply = extractTextParts(decoded.parts); - if reply ~= '' and isMessageComplete(decoded) then + local messages = textutils.unserializeJSON(body); + if type(messages) ~= 'table' then return false, 'reponse message invalide'; end + local decoded = findAssistantMessage(messages, messageId); + local reply = decoded and extractTextParts(decoded.parts) or ''; + if decoded and reply ~= '' and isMessageComplete(decoded) then return true, { reply = reply, sessionId = sessionId, messageId = messageId }; end if nowFunc() >= deadline then diff --git a/docs/opencode_api.md b/docs/opencode_api.md index 30248dc..561ac6e 100644 --- a/docs/opencode_api.md +++ b/docs/opencode_api.md @@ -102,7 +102,7 @@ Parts can include non-text types (`tool-call`, `step-start`, etc.) — collect a ### `GET /session/:id/message/:messageID` -Get a message by ID. `ai` uses this to poll async prompts until the assistant message has text parts and completion metadata. +Get a message by ID. Opencode validates caller-provided message IDs; use IDs starting with `msg`. **Response** `200`: ```json @@ -130,9 +130,9 @@ Abort a running generation. ### `POST /session/:id/prompt_async` -Fire-and-forget variant. Returns `204` immediately. Include `messageID` in the request body to make the assistant response addressable by `GET /session/:id/message/:messageID`. +Fire-and-forget variant. Returns `204` immediately. Include `messageID` in the request body to make the assistant response addressable by `GET /session/:id/message/:messageID`. Opencode validates caller-provided message IDs; use IDs starting with `msg`. -`ai` uses this endpoint by default to avoid `504` failures from the blocking `/message` endpoint when the LLM takes longer than one HTTP request timeout. +`ai` uses this endpoint by default to avoid `504` failures from the blocking `/message` endpoint when the LLM takes longer than one HTTP request timeout. The submitted `messageID` identifies the user message; `ai` polls `GET /session/:id/message` and reads the completed assistant message that follows it. --- diff --git a/manifest.json b/manifest.json index fca3756..d6fab76 100644 --- a/manifest.json +++ b/manifest.json @@ -1,6 +1,6 @@ { "name": "TrapOS", - "version": "0.6.3", + "version": "0.6.4", "branch": "next", "packages": [ "trapos" diff --git a/packages/index.json b/packages/index.json index cfc547a..be34388 100644 --- a/packages/index.json +++ b/packages/index.json @@ -5,8 +5,8 @@ "trapos-boot": "0.2.2", "trapos-net": "0.2.1", "trapos-ui": "0.2.2", - "trapos-ai": "0.5.2", + "trapos-ai": "0.5.3", "trapos-sandbox": "0.1.0", - "trapos": "0.6.3" + "trapos": "0.6.4" } } diff --git a/packages/trapos-ai/ccpm.json b/packages/trapos-ai/ccpm.json index d45617f..559803d 100644 --- a/packages/trapos-ai/ccpm.json +++ b/packages/trapos-ai/ccpm.json @@ -1,6 +1,6 @@ { "name": "trapos-ai", - "version": "0.5.2", + "version": "0.5.3", "description": "TrapOS AI client for opencode serve", "dependencies": ["trapos-core"], "files": [ diff --git a/packages/trapos/ccpm.json b/packages/trapos/ccpm.json index d1cbb44..cabb747 100644 --- a/packages/trapos/ccpm.json +++ b/packages/trapos/ccpm.json @@ -1,6 +1,6 @@ { "name": "trapos", - "version": "0.6.3", + "version": "0.6.4", "description": "TrapOS full install meta-package", "dependencies": ["trapos-boot", "trapos-net", "trapos-ui", "trapos-test", "trapos-ai"], "files": [], diff --git a/tests/ai.lua b/tests/ai.lua index 4f7878b..3bdf116 100644 --- a/tests/ai.lua +++ b/tests/ai.lua @@ -81,11 +81,22 @@ local function asyncResp() return response(204, ''); end -local function pendingMessageResp(reply) - return response(200, textutils.serializeJSON({ - info = { time = {} }, +local function messageListResp(messages) + return response(200, textutils.serializeJSON(messages)); +end + +local function userMessage(id, text) + return { + info = { id = id, role = 'user' }, + parts = { { type = 'text', text = text } }, + }; +end + +local function assistantMessage(id, reply, completed) + return { + info = { id = id, role = 'assistant', time = completed and { completed = 1 } or {} }, parts = { { type = 'text', text = reply } }, - })); + }; end local function postedText(call) @@ -266,10 +277,34 @@ testlib.test('ask sends exact prompt text', function() testlib.assertEquals(body.parts[1].text, 'my prompt'); end); +testlib.test('ask generates opencode-compatible message ids', function() + local httpStub = fakeHttp( + { messageResp('reply') }, + {} + ); + local settingsStub = fakeSettings({ + ['opencc.server_url'] = 'http://host', + ['opencc.session_id'] = 'ses_1', + }); + local ai = createAi({ + http = httpStub, + settings = settingsStub, + now = function() return 123.456; end, + }); + + ai.ask('hello'); + + local body = textutils.unserializeJSON(httpStub.postCalls[1].body); + testlib.assertTrue(string.find(body.messageID, '^msg_') ~= nil); +end); + testlib.test('ask polls async message until completion', function() local httpStub = fakeHttp( { sessionResp('ses_1'), asyncResp() }, - { pendingMessageResp('partial'), messageResp('reply') } + { + messageListResp({ userMessage('msg_1', 'hello'), assistantMessage('msg_2', 'partial', false) }), + messageListResp({ userMessage('msg_1', 'hello'), assistantMessage('msg_2', 'reply', true) }), + } ); local settingsStub = fakeSettings({ ['opencc.server_url'] = 'http://host' }); local sleeps = {}; @@ -286,14 +321,17 @@ testlib.test('ask polls async message until completion', function() testlib.assertEquals(result.reply, 'reply'); testlib.assertEquals(result.messageId, 'msg_1'); testlib.assertEquals(#httpStub.getCalls, 2); - testlib.assertTrue(string.find(httpStub.getCalls[1].url, '/session/ses_1/message/msg_1', 1, true) ~= nil); + testlib.assertTrue(string.find(httpStub.getCalls[1].url, '/session/ses_1/message', 1, true) ~= nil); testlib.assertEquals(sleeps[1], 3); end); testlib.test('ask polling times out', function() local httpStub = fakeHttp( { sessionResp('ses_1'), asyncResp() }, - { pendingMessageResp('partial'), pendingMessageResp('partial') } + { + messageListResp({ userMessage('msg_1', 'hello'), assistantMessage('msg_2', 'partial', false) }), + messageListResp({ userMessage('msg_1', 'hello'), assistantMessage('msg_2', 'partial', false) }), + } ); local settingsStub = fakeSettings({ ['opencc.server_url'] = 'http://host' }); local now = 0;