From 68b75153b9b76d256cf269f273fc14f183d05af7 Mon Sep 17 00:00:00 2001 From: Zebin Wu Date: Sat, 18 Apr 2026 21:58:55 +0800 Subject: [PATCH 1/8] netif: support get interface name --- bridge/meta/netif.lua | 4 ++++ bridge/src/lnetiflib.c | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/bridge/meta/netif.lua b/bridge/meta/netif.lua index f184a1f..6728383 100644 --- a/bridge/meta/netif.lua +++ b/bridge/meta/netif.lua @@ -28,6 +28,10 @@ function M.find(name) end ---@return netif function M.wait(event, netif) end +---Get interface name. +---@return string +function M.getName(netif) end + ---Whether the interface is up. ---@return boolean function M.isUp(netif) end diff --git a/bridge/src/lnetiflib.c b/bridge/src/lnetiflib.c index 251a419..48ed6f4 100644 --- a/bridge/src/lnetiflib.c +++ b/bridge/src/lnetiflib.c @@ -123,6 +123,19 @@ static int lnetif_wait(lua_State *L) { return lua_yieldk(L, 0, (lua_KContext)ctx, finishwait); } +static int lnetif_get_name(lua_State *L) { + luaL_argcheck(L, lua_islightuserdata(L, 1), 1, "not a lightuserdata"); + pal_net_if *netif = lua_touserdata(L, 1); + + char buf[PAL_NET_IF_NAME_MAX_LEN]; + pal_err err = pal_net_if_get_name(netif, buf); + if (err != PAL_ERR_OK) { + luaL_error(L, "failed to get name: %s", pal_err_string(err)); + } + lua_pushstring(L, buf); + return 1; +} + static int lnetif_is_up(lua_State *L) { luaL_argcheck(L, lua_islightuserdata(L, 1), 1, "not a lightuserdata"); pal_net_if *netif = lua_touserdata(L, 1); @@ -176,6 +189,7 @@ static const luaL_Reg lnetif_funcs[] = { {"getInterfaces", lnetif_get_interfaces}, {"find", lnetif_find}, {"wait", lnetif_wait}, + {"getName", lnetif_get_name}, {"isUp", lnetif_is_up}, {"getIpv4Addr", lnetif_get_ipv4_addr}, {"getIpv6Addrs", lnetif_get_ipv6_addrs}, From 3ee0f17b7fd7a93f3b078412feda12884113f20a Mon Sep 17 00:00:00 2001 From: Zebin Wu Date: Sun, 19 Apr 2026 00:27:17 +0800 Subject: [PATCH 2/8] core: add mq:recvUtil() --- bridge/meta/core.lua | 14 ++- bridge/src/lcorelib.c | 262 +++++++++++++++++++++++++++++++++++------- tests/testcore.lua | 82 +++++++++++++ 3 files changed, 317 insertions(+), 41 deletions(-) diff --git a/bridge/meta/core.lua b/bridge/meta/core.lua index 90387d4..5db3044 100644 --- a/bridge/meta/core.lua +++ b/bridge/meta/core.lua @@ -10,6 +10,7 @@ local timer = {} local mq = {} ---Get current time in milliseconds. +---@return integer time function core.time() end ---Cause normal program termination. @@ -47,10 +48,21 @@ function mq:send(...) end --- ---When the message queue is empty, the current coroutine ---waits here until a message is received. ----@return ... +---@return any ... ---@nodiscard function mq:recv() end +---Receive message until the deadline. +--- +---When the message queue is empty, the current coroutine +---waits here until a message is received or the deadline expires. +---Returns ``true, ...`` on success, or ``false, "timeout"`` on timeout. +---@param deadline integer Absolute deadline in milliseconds. +---@return boolean success +---@return any ... +---@nodiscard +function mq:recvUntil(deadline) end + ---Create a message queue. ---@param size integer Queue size. ---@return MessageQueue diff --git a/bridge/src/lcorelib.c b/bridge/src/lcorelib.c index a112d86..bb3ad39 100644 --- a/bridge/src/lcorelib.c +++ b/bridge/src/lcorelib.c @@ -35,6 +35,12 @@ typedef struct { size_t size; } lcore_mq; +typedef struct { + lua_State *co; + HAPPlatformTimerRef timer; + bool with_status; +} lcore_mq_wait_ctx; + static int lcore_time(lua_State *L) { lua_pushnumber(L, HAPPlatformClockGetCurrent()); return 1; @@ -272,10 +278,189 @@ static size_t lcore_mq_size(lcore_mq *obj) { return obj->first > obj->last ? obj->size - obj->first + obj->last : obj->last - obj->first; } +static void lcore_mq_resume(lua_State *L, lua_State *co, int nargs) { + int status, nres; + lua_xmove(L, co, nargs); + status = lc_resume(co, L, nargs, &nres); + if (luai_unlikely(status != LUA_OK && status != LUA_YIELD)) { + HAPLogError(&lcore_log, "%s: %s", __func__, lua_tostring(L, -1)); + } + lua_pop(L, nres); +} + +static int lcore_mq_recv_ready(lua_State *L, int mq_idx, lcore_mq *obj, bool with_status) { + mq_idx = lua_absindex(L, mq_idx); + HAPAssert(lua_getuservalue(L, mq_idx) == LUA_TTABLE); + int store_idx = lua_gettop(L); + lua_geti(L, store_idx, obj->first); + lua_pushnil(L); + lua_seti(L, store_idx, obj->first); + obj->first++; + if (obj->first > obj->size + 1) { + obj->first = 1; + } + int nargs = luaL_len(L, store_idx + 1); + if (with_status) { + lua_pushboolean(L, true); + } + for (int i = 1; i <= nargs; i++) { + lua_geti(L, store_idx + 1, i); + } + return with_status ? nargs + 1 : nargs; +} + +static bool lcore_mq_wait_remove_at(lua_State *L, int wait_idx, int pos) { + wait_idx = lua_absindex(L, wait_idx); + int wait_count = luaL_len(L, wait_idx); + if (pos != wait_count) { + lua_geti(L, wait_idx, wait_count); + lua_seti(L, wait_idx, pos); + } + lua_pushnil(L); + lua_seti(L, wait_idx, wait_count); + return wait_count == 1; +} + +static void lcore_mq_wait_remove(lua_State *L, lcore_mq_wait_ctx *ctx) { + if (lua_rawgetp(L, LUA_REGISTRYINDEX, ctx) != LUA_TUSERDATA) { + lua_pop(L, 1); + return; + } + + int ctx_idx = lua_gettop(L); + HAPAssert(lua_getiuservalue(L, ctx_idx, 1) == LUA_TUSERDATA); + int mq_idx = lua_gettop(L); + HAPAssert(lua_getuservalue(L, mq_idx) == LUA_TTABLE); + int store_idx = lua_gettop(L); + if (lua_getfield(L, store_idx, "wait") == LUA_TTABLE) { + int wait_idx = lua_gettop(L); + int wait_count = luaL_len(L, wait_idx); + for (int i = 1; i <= wait_count; i++) { + if (lua_geti(L, wait_idx, i) == LUA_TUSERDATA && lua_touserdata(L, -1) == ctx) { + lua_pop(L, 1); + if (lcore_mq_wait_remove_at(L, wait_idx, i)) { + lua_pushnil(L); + lua_setfield(L, store_idx, "wait"); + } + break; + } + lua_pop(L, 1); + } + lua_pop(L, 1); + } else { + lua_pop(L, 1); + } + lua_pop(L, 3); +} + +static void lcore_mq_waitctx_release(lua_State *L, int idx, bool cancel_timer) { + idx = lua_absindex(L, idx); + lcore_mq_wait_ctx *ctx = lua_touserdata(L, idx); + + if (cancel_timer && ctx->timer) { + HAPPlatformTimerDeregister(ctx->timer); + } + ctx->timer = 0; + + lua_pushnil(L); + lua_rawsetp(L, LUA_REGISTRYINDEX, ctx); + lua_pushnil(L); + lua_setiuservalue(L, idx, 1); + ctx->co = NULL; +} + +static int lcore_mq_wait_timeout_resume(lua_State *L) { + lcore_mq_wait_ctx *ctx = lua_touserdata(L, 1); + lua_pop(L, 1); + + if (lua_rawgetp(L, LUA_REGISTRYINDEX, ctx) != LUA_TUSERDATA) { + lua_pop(L, 1); + return 0; + } + + lua_State *co = ctx->co; + lcore_mq_wait_remove(L, ctx); + lcore_mq_waitctx_release(L, -1, false); + lua_pop(L, 1); + + if (!co) { + return 0; + } + + lua_pushboolean(L, false); + lua_pushliteral(L, "timeout"); + lcore_mq_resume(L, co, 2); + return 0; +} + +static void lcore_mq_wait_timeout_cb(HAPPlatformTimerRef timer, void *context) { + lcore_mq_wait_ctx *ctx = context; + if (!ctx->co) { + return; + } + lua_State *L = lc_getmainthread(ctx->co); + + ctx->timer = 0; + + HAPAssert(lua_gettop(L) == 0); + + lc_pushtraceback(L); + lua_pushcfunction(L, lcore_mq_wait_timeout_resume); + lua_pushlightuserdata(L, ctx); + int status = lua_pcall(L, 1, 0, 1); + if (luai_unlikely(status != LUA_OK)) { + HAPLogError(&lcore_log, "%s: %s", __func__, lua_tostring(L, -1)); + } + + lua_settop(L, 0); + lc_collectgarbage(L); +} + +static int lcore_mq_wait(lua_State *L, int mq_idx, bool with_status, HAPTime deadline) { + mq_idx = lua_absindex(L, mq_idx); + HAPAssert(lua_getuservalue(L, mq_idx) == LUA_TTABLE); + int store_idx = lua_gettop(L); + int type = lua_getfield(L, store_idx, "wait"); + if (type == LUA_TNIL) { + lua_pop(L, 1); + lua_createtable(L, 1, 0); + lua_pushvalue(L, -1); + lua_setfield(L, store_idx, "wait"); + } else { + HAPAssert(type == LUA_TTABLE); + } + + int wait_idx = lua_gettop(L); + lcore_mq_wait_ctx *ctx = lua_newuserdatauv(L, sizeof(*ctx), 1); + ctx->co = L; + ctx->timer = 0; + ctx->with_status = with_status; + lua_pushvalue(L, mq_idx); + lua_setiuservalue(L, -2, 1); + + int wait_pos = luaL_len(L, wait_idx) + 1; + lua_pushvalue(L, -1); + lua_seti(L, wait_idx, wait_pos); + if (deadline) { + lua_pushvalue(L, -1); + lua_rawsetp(L, LUA_REGISTRYINDEX, ctx); + if (luai_unlikely(HAPPlatformTimerRegister(&ctx->timer, + deadline, lcore_mq_wait_timeout_cb, ctx) != kHAPError_None)) { + if (lcore_mq_wait_remove_at(L, wait_idx, wait_pos)) { + lua_pushnil(L); + lua_setfield(L, store_idx, "wait"); + } + lcore_mq_waitctx_release(L, -1, false); + luaL_error(L, "failed to create a timer"); + } + } + lua_pop(L, 3); + return lua_yield(L, 0); +} + static int lcore_mq_send(lua_State *L) { lcore_mq *obj = luaL_checkudata(L, 1, LUA_MQ_OBJ_NAME); int narg = lua_gettop(L) - 1; - int status, nres; lua_getuservalue(L, 1); @@ -284,22 +469,25 @@ static int lcore_mq_send(lua_State *L) { lua_setfield(L, -3, "wait"); // que.wait = nil int waiting = luaL_len(L, -1); for (int i = 1; i <= waiting; i++) { - HAPAssert(lua_geti(L, -1, i) == LUA_TTHREAD); - lua_State *co = lua_tothread(L, -1); + HAPAssert(lua_geti(L, -1, i) == LUA_TUSERDATA); + lcore_mq_wait_ctx *ctx = lua_touserdata(L, -1); + lua_State *co = ctx->co; + bool with_status = ctx->with_status; + lcore_mq_waitctx_release(L, -1, true); lua_pop(L, 1); - int max = 1 + narg; - if (luai_unlikely(!lua_checkstack(L, narg))) { + int nargs = narg + with_status; + if (luai_unlikely(!lua_checkstack(L, nargs))) { luaL_error(L, "stack overflow"); } - for (int i = 2; i <= max; i++) { - lua_pushvalue(L, i); + if (with_status) { + lua_pushboolean(L, true); } - lua_xmove(L, co, narg); - status = lc_resume(co, L, narg, &nres); - if (luai_unlikely(status != LUA_OK && status != LUA_YIELD)) { - HAPLogError(&lcore_log, "%s: %s", __func__, lua_tostring(L, -1)); + for (int j = 2; j <= 1 + narg; j++) { + lua_pushvalue(L, j); + } + if (co) { + lcore_mq_resume(L, co, nargs); } - lua_pop(L, nres); } } else { if (lcore_mq_size(obj) == obj->size) { @@ -326,37 +514,30 @@ static int lcore_mq_recv(lua_State *L) { if (lua_gettop(L) != 1) { luaL_error(L, "invalid arguements"); } - lua_getuservalue(L, 1); if (obj->last == obj->first) { - int type = lua_getfield(L, 2, "wait"); - if (type == LUA_TNIL) { - lua_pop(L, 1); - lua_createtable(L, 1, 0); - lua_pushthread(L); - lua_seti(L, 3, 1); - lua_setfield(L, 2, "wait"); - } else { - HAPAssert(type == LUA_TTABLE); - lua_pushthread(L); - lua_seti(L, 3, luaL_len(L, 3) + 1); - lua_pop(L, 1); - } - lua_pop(L, 1); - return lua_yield(L, 0); + return lcore_mq_wait(L, 1, false, 0); } else { - lua_geti(L, 2, obj->first); - lua_pushnil(L); - lua_seti(L, 2, obj->first); - obj->first++; - if (obj->first > obj->size + 1) { - obj->first = 1; - } - int nargs = luaL_len(L, 3); - for (int i = 1; i <= nargs; i++) { - lua_geti(L, 3, i); - } - return nargs; + return lcore_mq_recv_ready(L, 1, obj, false); + } +} + +static int lcore_mq_recv_until(lua_State *L) { + lcore_mq *obj = luaL_checkudata(L, 1, LUA_MQ_OBJ_NAME); + lua_Integer deadline = luaL_checkinteger(L, 2); + luaL_argcheck(L, deadline >= 0, 2, "deadline out of range"); + if (lua_gettop(L) != 2) { + luaL_error(L, "invalid arguements"); + } + if (obj->last != obj->first) { + return lcore_mq_recv_ready(L, 1, obj, true); + } + + if ((HAPTime)deadline <= HAPPlatformClockGetCurrent()) { + lua_pushboolean(L, false); + lua_pushliteral(L, "timeout"); + return 2; } + return lcore_mq_wait(L, 1, true, (HAPTime)deadline); } static int lcore_mq_tostring(lua_State *L) { @@ -380,6 +561,7 @@ static const luaL_Reg lcore_mq_metameth[] = { static const luaL_Reg lcore_mq_meth[] = { {"send", lcore_mq_send}, {"recv", lcore_mq_recv}, + {"recvUntil", lcore_mq_recv_until}, {NULL, NULL}, }; diff --git a/tests/testcore.lua b/tests/testcore.lua index 2180d43..4f83a93 100644 --- a/tests/testcore.lua +++ b/tests/testcore.lua @@ -1,4 +1,9 @@ local core = require "core" +local floor = math.floor + +local function spawn(fn, ...) + core.createTimer(fn, ...):start(0) +end -- Tests repeated timer callbacks with a userdata argument. do @@ -48,3 +53,80 @@ do core.sleep(20) assert(ran == true) end + +-- Tests recvUntil times out without consuming future messages. +do + local mq = core.createMQ(1) + local ok, err = mq:recvUntil(floor(core.time())) + assert(ok == false) + assert(err == "timeout") + + core.createTimer(function(queue) + queue:send("late") + end, mq):start(20) + + ok, err = mq:recvUntil(floor(core.time()) + 5) + assert(ok == false) + assert(err == "timeout") + local success, value = mq:recvUntil(floor(core.time()) + 100) + assert(success == true) + assert(value == "late") +end + +-- Tests recvUntil receives queued messages before the deadline. +do + local mq = core.createMQ(1) + + core.createTimer(function(queue) + queue:send("ok", 42) + end, mq):start(20) + + local ok, a, b = mq:recvUntil(floor(core.time()) + 100) + assert(ok == true) + assert(a == "ok") + assert(b == 42) +end + +-- Tests recvUntil returns queued messages immediately. +do + local mq = core.createMQ(2) + + mq:send("ready", 7) + + local ok, a, b = mq:recvUntil(floor(core.time()) + 100) + assert(ok == true) + assert(a == "ready") + assert(b == 7) +end + +-- Tests a timed out waiter is removed without affecting other waiters. +do + local mq = core.createMQ(1) + local done = core.createMQ(2) + + spawn(function(queue, out) + local ok, err = queue:recvUntil(floor(core.time()) + 10) + out:send("short", ok, err) + end, mq, done) + + spawn(function(queue, out) + local ok, value = queue:recvUntil(floor(core.time()) + 100) + out:send("long", ok, value) + end, mq, done) + + core.createTimer(function(queue) + queue:send("payload") + end, mq):start(30) + + local tag1, ok1, value1 = done:recv() + local tag2, ok2, value2 = done:recv() + local results = { + [tag1] = {ok1, value1}, + [tag2] = {ok2, value2}, + } + + assert(results.short[1] == false) + assert(results.short[2] == "timeout") + assert(results.long[1] == true) + assert(results.long[2] == "payload") +end From d1dfbd262210cfa42ffff903ad1957b200d8e57d Mon Sep 17 00:00:00 2001 From: Zebin Wu Date: Sun, 19 Apr 2026 00:28:00 +0800 Subject: [PATCH 3/8] socket: bindif only support interface name --- bridge/meta/socket.lua | 4 ++-- bridge/src/lsocketlib.c | 21 +-------------------- tests/testsocket.lua | 28 ++++++++++++++++++---------- 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/bridge/meta/socket.lua b/bridge/meta/socket.lua index 61eecbc..7ca8cc1 100644 --- a/bridge/meta/socket.lua +++ b/bridge/meta/socket.lua @@ -19,8 +19,8 @@ function socket:reuseaddr() end ---Bind a socket to a network interface. --- ---Call this before `bind`, `connect` or `listen` when you need traffic to stay on a specific interface. ----@param netif netif|string Network interface object or interface name. -function socket:bindif(netif) end +---@param ifname string Network interface name. +function socket:bindif(ifname) end ---Bind a socket to a local IP address and port. ---@param addr string Local address to use. diff --git a/bridge/src/lsocketlib.c b/bridge/src/lsocketlib.c index 7822680..a88c89c 100644 --- a/bridge/src/lsocketlib.c +++ b/bridge/src/lsocketlib.c @@ -5,7 +5,6 @@ // See [CONTRIBUTORS.md] for the list of homekit-bridge project authors. #include -#include #include #include #include @@ -92,25 +91,7 @@ static int lsocket_obj_reuseaddr(lua_State *L) { static int lsocket_obj_bindif(lua_State *L) { lsocket_obj *obj = lsocket_obj_get(L, 1); - const char *netif_name; - char buf[PAL_NET_IF_NAME_MAX_LEN]; - - switch (lua_type(L, 2)) { - case LUA_TSTRING: - netif_name = lua_tostring(L, 2); - break; - case LUA_TLIGHTUSERDATA: { - pal_net_if *netif = lua_touserdata(L, 2); - pal_err err = pal_net_if_get_name(netif, buf); - if (luai_unlikely(err != PAL_ERR_OK)) { - luaL_error(L, pal_err_string(err)); - } - netif_name = buf; - break; - } - default: - return luaL_argerror(L, 2, "expected netif or interface name"); - } + const char *netif_name = luaL_checkstring(L, 2); pal_err err = pal_socket_bind_netif(&obj->socket, netif_name); if (luai_unlikely(err != PAL_ERR_OK)) { diff --git a/tests/testsocket.lua b/tests/testsocket.lua index 616790f..6681fcb 100644 --- a/tests/testsocket.lua +++ b/tests/testsocket.lua @@ -1,3 +1,4 @@ +local core = require "core" local socket = require "socket" local netif = require "netif" @@ -10,6 +11,13 @@ local function fillStr(n, fill) return s .. fill:sub(0, n - #s) end +local function bindEphemeral(sock, addr) + sock:bind(addr, 0) + local _, port = sock:getsockname() + assert(port > 0 and port <= 65535) + return port +end + ---Test socket.create() with valid parameters. for _, type in ipairs({"TCP", "UDP"}) do for _, domain in ipairs({"IPV4", "IPV6"}) do @@ -78,15 +86,15 @@ do assert(port > 0 and port <= 65535) end ----Test socket.bindif() with a netif object. +---Test socket.bindif() rejects a netif object. do local loopback = netif.find("lo") - assert(netif.getIpv4Addr(loopback) == "127.0.0.1") local sock = socket.create("UDP", "IPV4") - sock:bindif(loopback) + local success = pcall(sock.bindif, sock, loopback) + assert(success == false) end ----Test socket.bindif() with a interface name. +---Test socket.bindif() with an interface name. do local sock = socket.create("UDP", "IPV4") sock:bindif("lo") @@ -105,14 +113,14 @@ do local sock2 = socket.create("UDP", "IPV4") sock1:reuseaddr() sock2:reuseaddr() - sock1:bind("127.0.0.1", 8889) - sock2:bind("127.0.0.1", 8889) + local port = bindEphemeral(sock1, "127.0.0.1") + sock2:bind("127.0.0.1", port) end ---Test UDP socket echo do local server = socket.create("UDP", "IPV4") - server:bind("127.0.0.1", 8888) + local port = bindEphemeral(server, "127.0.0.1") core.createTimer(function () while true do local msg, addr, port = server:recvfrom(1024) @@ -124,7 +132,7 @@ do end end):start(0) local client = socket.create("UDP", "IPV4") - client:connect("127.0.0.1", 8888) + client:connect("127.0.0.1", port) for i = 1, 100, 1 do local msg = fillStr(1024) assert(client:send(msg) == #msg) @@ -136,7 +144,7 @@ end ---Test TCP socket echo do local listener = socket.create("TCP", "IPV4") - listener:bind("127.0.0.1", 8888) + local port = bindEphemeral(listener, "127.0.0.1") listener:listen(1024) core.createTimer(function () while true do @@ -152,7 +160,7 @@ do end end):start(0) local client = socket.create("TCP", "IPV4") - client:connect("127.0.0.1", 8888) + client:connect("127.0.0.1", port) for i = 1, 5, 1 do local msg = fillStr(1024) assert(client:send(msg) == #msg) From 8dc9a914eb9a2f54872f4069230e0521d3f0a2ae Mon Sep 17 00:00:00 2001 From: Zebin Wu Date: Sun, 19 Apr 2026 00:30:58 +0800 Subject: [PATCH 4/8] miio: implement transport to provide only one socket for all requests/scans --- plugins/miio/device.lua | 22 +- plugins/miio/protocol.lua | 318 ++++++++++++++++++--------- plugins/miio/transport.lua | 289 ++++++++++++++++++++++++ tests/test.lua | 3 +- tests/testmiiotransport.lua | 422 ++++++++++++++++++++++++++++++++++++ 5 files changed, 952 insertions(+), 102 deletions(-) create mode 100644 plugins/miio/transport.lua create mode 100644 tests/testmiiotransport.lua diff --git a/plugins/miio/device.lua b/plugins/miio/device.lua index b7a7564..80177c9 100644 --- a/plugins/miio/device.lua +++ b/plugins/miio/device.lua @@ -8,6 +8,7 @@ local tunpack = table.unpack local tinsert = table.insert local M = {} +local currentRuntime = nil ---@class MiotIID:table MIOT instance ID. --- @@ -153,6 +154,14 @@ function device:request(method, ...) return self.pcb:request(self.timeout, method, ...) end +---@return MiioProtocolRuntime runtime +local function ensureRuntime() + if currentRuntime == nil then + currentRuntime = protocol.create() + end + return currentRuntime +end + ---Create a device object. ---@param addr string Device address. ---@param token string Device token. @@ -166,7 +175,7 @@ function M.create(addr, token) ---@class MiioDevice local o = { logger = log.getLogger("miio.device:" .. addr), - pcb = protocol.create(addr, util.hex2bin(token)), + pcb = ensureRuntime():createPcb(addr, util.hex2bin(token)), mapping = false, addr = addr, timeout = 1000, @@ -184,9 +193,16 @@ function M.create(addr, token) end ---Initialize the miIO device module. +---@param netifs? string[] Network interface names. ---@param virtualDid? integer Virtual device ID: 64-bit. -function M.init(virtualDid) - protocol.init(virtualDid) +---@return MiioProtocolRuntime runtime +function M.init(netifs, virtualDid) + local nextRuntime = protocol.create(netifs, virtualDid) + if currentRuntime ~= nil then + currentRuntime:close() + end + currentRuntime = nextRuntime + return nextRuntime end return M diff --git a/plugins/miio/protocol.lua b/plugins/miio/protocol.lua index 1819cd8..fbbe365 100644 --- a/plugins/miio/protocol.lua +++ b/plugins/miio/protocol.lua @@ -1,13 +1,15 @@ -local socket = require "socket" local hash = require "hash" local cipher = require "cipher" local json = require "cjson" +local transport = require "miio.transport" local assert = assert +local pcall = pcall local type = type local error = error local floor = math.floor local random = math.random +local tointeger = math.tointeger local spack = string.pack local sunpack = string.unpack local schar = string.char @@ -16,7 +18,12 @@ local tconcat = table.concat local M = {} local logger = log.getLogger("miio.protocol") -local defaultVirtualDid = nil + +---@class MiioProtocolRuntime +---@field transport MiioTransport? +---@field virtualDid integer? +---@field _reqid integer +local runtime = {} --- --- Message format @@ -132,7 +139,7 @@ local function createEncryption(token) end ---Create a virtual device ID for probe packets. ----@return integer virtualDid +---@return integer did ---@nodiscard local function createVirtualDid() local now = floor(core.time()) @@ -141,14 +148,27 @@ local function createVirtualDid() return (high << 32) | low end ----Get the default virtual device ID. ----@return integer virtualDid ----@nodiscard -local function getDefaultVirtualDid() - if defaultVirtualDid == nil then - defaultVirtualDid = createVirtualDid() +---@param netifs? string[] +---@param virtualDid? integer +---@return string[]? netifs +---@return integer? virtualDid +local function normalizeCreateArgs(netifs, virtualDid) + assert(netifs == nil or type(netifs) == "table", "netifs must be a table") + if virtualDid ~= nil then + virtualDid = assert(tointeger(virtualDid), "virtualDid must be an integer") + end + return netifs, virtualDid +end + +---@param self MiioProtocolRuntime +---@return integer reqid +local function nextRequestId(self) + local reqid = self._reqid + 1 + if reqid > 9999 then + reqid = 1 end - return defaultVirtualDid + self._reqid = reqid + return reqid end ---Pack a message to a binary package. @@ -174,21 +194,21 @@ local function pack(did, stamp, token, data) end assert(#checksum == 16) - return tconcat({header, checksum, data or ""}) + return header .. checksum .. (data or "") end ---Pack a probe packet. --- ---The probe packet keeps the 32-byte miIO header layout, but replaces the ---MD5/token area with `MDID + virtualDid + 0x00000000`. ----@param virtualDid integer Virtual device ID: 64-bit. +---@param did integer Virtual device ID: 64-bit. ---@return string package ---@nodiscard -local function packProbe(virtualDid) +local function packProbe(did) return tconcat({ spack(">I2I2I8I4", 0x2131, 32, -1, 0xffffffff), "MDID", - spack(">I8I4", virtualDid, 0), + spack(">I8I4", did, 0), }) end @@ -227,10 +247,52 @@ local function unpack(package, token) data end +---@param queueSize integer +---@param currentTp MiioTransport +---@param matcher fun(packet:string, addr:string, port:integer, ifname:string):any? +---@return MessageQueue queue +---@return fun() close +local function createReceiver(queueSize, currentTp, matcher) + assert(currentTp, "transport not inited") + + local queue = core.createMQ(queueSize) + local subId = currentTp:subscribe(function(packet, addr, port, ifname) + local result = matcher(packet, addr, port, ifname) + if result ~= nil then + queue:send(result) + end + end) + local closed = false + + return queue, function() + if closed then + return + end + closed = true + currentTp:unsubscribe(subId) + end +end + +---@param queueSize integer +---@param currentTp MiioTransport +---@param matcher fun(packet:string, addr:string, port:integer, ifname:string):any? +---@param action fun(queue:MessageQueue):any +---@return any result +local function withReceiver(queueSize, currentTp, matcher, action) + local queue, close = createReceiver(queueSize, currentTp, matcher) + local ok, result = pcall(action, queue) + close() + if not ok then + error(result) + end + return result +end + ---@class ScanResult Scan Result. --- ---@field addr string Device address. ---@field devid integer Device ID: 64-bit. +---@field ifname string Network interface name. ---@field stamp integer Device time stamp. ---Scan for devices in the local network. @@ -238,58 +300,62 @@ end ---This method is used to discover supported devices by sending ---a probe message to the broadcast address on port 54321. ---If the target IP address is given, the probe will be send as an unicast packet. +---@param self MiioProtocolRuntime ---@param timeout integer Timeout period (in milliseconds). ---@param addr? string Target Address. ---@return ScanResult[] results A array of scan results. ---@nodiscard -function M.scan(timeout, addr) +function runtime:scan(timeout, addr) assert(timeout > 0, "timeout must be greater then 0") + local currentTp = assert(self.transport, "protocol closed") + local virtualDid = assert(self.virtualDid, "protocol closed") local numSend = 1 - local sock = socket.create("UDP", "IPV4") - sock:settimeout(timeout) + local probe = packProbe(virtualDid) if not addr then numSend = 3 - sock:enablebroadcast() - end - - local probe = packProbe(getDefaultVirtualDid()) - for _ = 1, numSend do - assert(sock:sendto(probe, addr or "255.255.255.255", 54321), "failed to send probe message") end local seen = {} - local results = {} - - while true do - local success, result, fromAddr, _ = pcall(sock.recvfrom, sock, 1024) - if success == false then - if addr == nil and result:find("timeout") then - return results - end - error(result) + local deadline = floor(core.time()) + timeout + return withReceiver(64, currentTp, function(packet, fromAddr, _, ifname) + if addr ~= nil and fromAddr ~= addr then + return + end + local success, did, stamp, data = pcall(unpack, packet) + if not success or did == -1 or data ~= nil or seen[fromAddr] then + return end - local did, stamp, data = unpack(result) - if did == -1 or data then - goto continue + seen[fromAddr] = true + return { + addr = fromAddr, + devid = did, + ifname = ifname, + stamp = stamp, + } + end, function(queue) + local results = {} + for _ = 1, numSend do + assert(currentTp:sendto(probe, addr or "255.255.255.255", 54321) > 0, "failed to send probe message") end - if not seen[fromAddr] then - table.insert(results, { - addr = fromAddr, - devid = did, - stamp = stamp - }) + + while true do + local ok, result = queue:recvUntil(deadline) + if not ok then + break + end + results[#results + 1] = result if addr then - return results + break end - seen[fromAddr] = true end -::continue:: - end + return results + end) end ---@class MiioPcb: table miio protocol control block. +---@field runtime MiioProtocolRuntime local pcb = {} ---@class MiioError miIO error. @@ -297,14 +363,39 @@ local pcb = {} ---@field code integer Error code. ---@field message string Error message. +---@param self MiioPcb +---@param raw string +---@param reqid integer +---@return any? decoded +local function decodeResponse(self, raw, reqid) + local success, did, _, data = pcall(unpack, raw, self.token) + if not success or did ~= self.devid or data == nil then + return + end + + local ok, payload = pcall(self.encryption.decrypt, self.encryption, data) + if not ok or payload == nil then + return + end + + local decodedOk, decoded = pcall(json.decode, payload) + if not decodedOk or decoded == nil or decoded.id ~= reqid then + return + end + + return decoded +end + ---Handshake. ---@param timeout integer Timeout period (in milliseconds). function pcb:handshake(timeout) logger:debug("Handshake ...") - local results = M.scan(timeout, self.addr) + local results = self.runtime:scan(timeout, self.addr) local result = results[1] + assert(result ~= nil, "handshake failed") logger:debug("Handshake done.") self.devid = result.devid + self.ifname = result.ifname self.stampDiff = floor(core.time() / 1000) - result.stamp end @@ -326,77 +417,86 @@ function pcb:request(timeout, method, ...) self:handshake(timeout) end - local sock = socket.create("UDP", "IPV4") - sock:settimeout(timeout) - sock:connect(self.addr, 54321) + local reqid = nextRequestId(self.runtime) - local reqid = self.reqid + 1 - if reqid == 9999 then - reqid = 1 - end - self.reqid = reqid - do - local data = json.encode({ - id = reqid, - method = method, - params = params - }) - - sock:send(pack(self.devid, floor(core.time() / 1000) - self.stampDiff, - self.token, self.encryption:encrypt(data))) - - logger:debug(("%s => %s"):format(data, self.addr)) - end + local plain = json.encode({ + id = reqid, + method = method, + params = params + }) + local packet = pack( + self.devid, + floor(core.time() / 1000) - self.stampDiff, + self.token, + self.encryption:encrypt(plain) + ) + local currentTp = assert(self.runtime.transport, "protocol closed") + local deadline = floor(core.time()) + timeout + logger:debug(("%s => %s"):format(plain, self.addr)) + local result = withReceiver(1, currentTp, function(raw, fromAddr, _, ifname) + if fromAddr ~= self.addr then + return + end + local decoded = decodeResponse(self, raw, reqid) + if decoded == nil then + return + end + self.ifname = ifname + return decoded + end, function(queue) + local sent = 0 + if self.ifname ~= nil then + local ok, sendResult = pcall(currentTp.sendto, currentTp, packet, self.addr, 54321, self.ifname) + if ok then + sent = sendResult + else + logger:debug(("send via %s failed, retry all netifs, %s"):format(self.ifname, tostring(sendResult))) + self.ifname = nil + end + end + if sent == 0 then + sent = currentTp:sendto(packet, self.addr, 54321) + end + if sent == 0 then + error("failed to send request") + end - local success, result = pcall(sock.recv, sock, 1024) - if success == false then - if result:find("timeout") then + local ok, recvResult = queue:recvUntil(deadline) + if not ok then + self.ifname = nil self.stampDiff = nil + error(recvResult) end - error(result) - end - self.errCnt = 0 - local did, _, data = unpack(result, self.token) + return recvResult + end) - if did ~= self.devid or data == nil then - error("Receive a invalid message.") - end - local s = self.encryption:decrypt(data) - if not s then - error("Failed to decrypt the message.") - end - logger:debug(("%s => %s"):format(self.addr, s)) - local payload = json.decode(s) - if not payload then - error("Failed to parse the JSON string.") - end - if payload.id ~= reqid then - error("response id ~= request id") - end + logger:debug(("%s => %s"):format(self.addr, json.encode(result))) ---@class MiioError - local err = payload.error + local err = result.error if err then error(err) end - return payload.result + return result.result end ---Create a PCB(protocol control block). +---@param self MiioProtocolRuntime ---@param addr string Device address. ---@param token string Device token: 128-bit. ---@return MiioPcb pcb Protocol control block. ---@nodiscard -function M.create(addr, token) +function runtime:createPcb(addr, token) assert(type(addr) == "string") assert(type(token) == "string") assert(#token == 16) + assert(self.transport ~= nil, "protocol closed") ---@class MiioPcb local o = { addr = addr, + runtime = self, token = token, - reqid = 0, } o.encryption = createEncryption(token) @@ -408,13 +508,35 @@ function M.create(addr, token) return o end ----Initialize the miIO protocol module. ----@param virtualDid? integer Virtual device ID: 64-bit. -function M.init(virtualDid) - if virtualDid ~= nil then - assert(type(virtualDid) == "number") +---Close the miIO protocol runtime. +---@param self MiioProtocolRuntime +function runtime:close() + local currentTp = self.transport + if currentTp ~= nil then + currentTp:close() + self.transport = nil end - defaultVirtualDid = virtualDid or createVirtualDid() + self.virtualDid = nil +end + +---Create a miIO protocol runtime. +---@param netifs? string[] Network interface names. +---@param virtualDid? integer Virtual device ID: 64-bit. +---@return MiioProtocolRuntime runtime +---@nodiscard +function M.create(netifs, virtualDid) + netifs, virtualDid = normalizeCreateArgs(netifs, virtualDid) + + ---@class MiioProtocolRuntime + local o = { + transport = transport.create(netifs), + virtualDid = virtualDid or createVirtualDid(), + _reqid = 0, + } + + return setmetatable(o, { + __index = runtime + }) end return M diff --git a/plugins/miio/transport.lua b/plugins/miio/transport.lua new file mode 100644 index 0000000..086439a --- /dev/null +++ b/plugins/miio/transport.lua @@ -0,0 +1,289 @@ +local socket = require "socket" +local netiflib = require "netif" + +local assert = assert +local error = error +local type = type +local ipairs = ipairs +local pairs = pairs +local tostring = tostring +local xpcall = xpcall +local tconcat = table.concat +local traceback = debug.traceback + +local M = {} +local logger = log.getLogger("miio.transport") + +local UDP_PORT = 54321 +local MAX_MSG_LEN = 2048 + +---@class MiioTransportSocket +---@field ifname string +---@field sock Socket +---@field reader Timer + +---@class MiioTransport +---@field sockets table +---@field localPort integer +---@field subscribers table +---@field running boolean +local transport = {} + +local function isUsableNetif(netif) + if not netiflib.isUp(netif) then + return false + end + local addr = netiflib.getIpv4Addr(netif) + return addr ~= "0.0.0.0" and addr ~= "127.0.0.1" +end + +local function resolveNetifs(netifs) + if netifs ~= nil then + assert(type(netifs) == "table", "netifs must be a table") + end + + local available = {} + local results = {} + for _, netif in ipairs(netiflib.getInterfaces()) do + local ifname = netiflib.getName(netif) + local usable = isUsableNetif(netif) + available[ifname] = usable + if netifs == nil and usable then + results[#results + 1] = ifname + end + end + + if netifs ~= nil then + local seen = {} + for i, ifname in ipairs(netifs) do + assert(type(ifname) == "string", ("network interface #%d must be a string"):format(i)) + if not seen[ifname] then + local usable = available[ifname] + if usable == nil then + error(("network interface #%d not found"):format(i)) + end + if not usable then + error(("network interface #%d is not usable"):format(i)) + end + results[#results + 1] = ifname + seen[ifname] = true + end + end + end + + assert(#results > 0, "no available IPv4 network interface") + return results +end + +local function destroySocket(ctx) + if ctx == nil or ctx.sock == nil then + return + end + if ctx.reader ~= nil then + ctx.reader:stop() + ctx.reader = nil + end + pcall(ctx.sock.destroy, ctx.sock) + ctx.sock = nil +end + +local function createSockets(netifs) + local sockets = {} + local localPort = nil + local failures = {} + + for _, ifname in ipairs(netifs) do + local sock = socket.create("UDP", "IPV4") + local addr + local success, result = pcall(function() + sock:settimeout(0) + sock:enablebroadcast() + sock:reuseaddr() + sock:bindif(ifname) + sock:bind("0.0.0.0", localPort or 0) + if localPort == nil then + local _, port = sock:getsockname() + localPort = port + end + addr = netiflib.getIpv4Addr(netiflib.find(ifname)) + end) + if success then + sockets[ifname] = { + ifname = ifname, + sock = sock, + } + logger:info(("created socket, %s, %s, %s"):format(ifname, addr, localPort)) + else + logger:error(("create socket error, %s, %s"):format(ifname, tostring(result))) + failures[#failures + 1] = ("%s: %s"):format(ifname, tostring(result)) + pcall(sock.destroy, sock) + end + end + + if localPort ~= nil and #failures == 0 then + return localPort, sockets + end + + for _, ctx in pairs(sockets) do + destroySocket(ctx) + end + if #failures > 0 then + error("failed to bind miio transport sockets: " .. tconcat(failures, "; ")) + end + error("failed to bind miio transport sockets") +end + +local function receivePacket(ctx) + local success, data, addr, port = xpcall(ctx.sock.recvfrom, traceback, ctx.sock, MAX_MSG_LEN) + if not success then + return nil, data + end + if port ~= UDP_PORT then + return false + end + return true, data, addr, port +end + +local function snapshotSubscribers(self) + if next(self.subscribers) == nil then + return nil + end + + local subscribers = {} + for id, handler in pairs(self.subscribers) do + subscribers[#subscribers + 1] = { + id = id, + handler = handler, + } + end + return subscribers +end + +---@param self MiioTransport +function transport:_dispatch(data, addr, port, ifname) + local subscribers = snapshotSubscribers(self) + if subscribers == nil then + return + end + + for _, entry in ipairs(subscribers) do + if self.subscribers[entry.id] ~= entry.handler then + goto continue + end + local success, result = xpcall(entry.handler, traceback, data, addr, port, ifname) + if not success then + logger:error(("subscriber[%s] failed: %s"):format(entry.id, tostring(result))) + end +::continue:: + end +end + +local function receiveLoop(self, ctx) + while self.running and self.sockets[ctx.ifname] == ctx do + local ok, data, addr, port = receivePacket(ctx) + if ok == nil then + if not self.running then + return + end + logger:error(("recvfrom failed, %s, %s"):format(ctx.ifname, tostring(data))) + return + end + if ok then + self:_dispatch(data, addr, port, ctx.ifname) + end + end +end + +---@param self MiioTransport +function transport:close() + if not self.running then + return + end + self.running = false + for _, ctx in pairs(self.sockets) do + destroySocket(ctx) + end + self.sockets = {} + self.subscribers = {} +end + +---@param self MiioTransport +---@param data string +---@param addr string +---@param port integer +---@param ifname? string +---@return integer sent +function transport:sendto(data, addr, port, ifname) + assert(type(data) == "string") + assert(type(addr) == "string") + assert(type(port) == "number") + + local sent = 0 + + if ifname ~= nil then + local ctx = assert(self.sockets[ifname], ("unknown netif '%s'"):format(ifname)) + local success, result = pcall(ctx.sock.sendto, ctx.sock, data, addr, port) + if not success then + error(result) + end + return 1 + end + + for name, ctx in pairs(self.sockets) do + local success, result = pcall(ctx.sock.sendto, ctx.sock, data, addr, port) + if success then + sent = sent + 1 + else + logger:debug(("sendto failed, %s, %s"):format(name, tostring(result))) + end + end + + return sent +end + +---@param self MiioTransport +---@param handler fun(data:string, addr:string, port:integer, ifname:string) +---@return integer id +function transport:subscribe(handler) + assert(type(handler) == "function") + self._subId = (self._subId or 0) + 1 + self.subscribers[self._subId] = handler + return self._subId +end + +---@param self MiioTransport +---@param id integer +function transport:unsubscribe(id) + self.subscribers[id] = nil +end + +---Create a resident miIO UDP transport. +---@param netifs? string[] +---@return MiioTransport transport +---@nodiscard +function M.create(netifs) + local resolved = resolveNetifs(netifs) + local localPort, sockets = createSockets(resolved) + + ---@type MiioTransport + local o = { + sockets = sockets, + localPort = localPort, + subscribers = {}, + running = true, + _subId = 0, + } + + setmetatable(o, { + __index = transport + }) + + for _, ctx in pairs(o.sockets) do + ctx.reader = core.createTimer(receiveLoop, o, ctx) + ctx.reader:start(0) + end + + return o +end + +return M diff --git a/tests/test.lua b/tests/test.lua index 352ddd5..c2b06fd 100644 --- a/tests/test.lua +++ b/tests/test.lua @@ -2,7 +2,8 @@ local suites = { "testcore", "testsocket", "teststream", - "testnvs" + "testnvs", + "testmiiotransport", } local function runSuite(s) diff --git a/tests/testmiiotransport.lua b/tests/testmiiotransport.lua new file mode 100644 index 0000000..8dc168a --- /dev/null +++ b/tests/testmiiotransport.lua @@ -0,0 +1,422 @@ +local core = require "core" +local socket = require "socket" +local netif = require "netif" +local hash = require "hash" +local cipher = require "cipher" +local json = require "cjson" +local transport = require "miio.transport" +local protocol = require "miio.protocol" + +local floor = math.floor +local ipairs = ipairs +local pairs = pairs +local spack = string.pack +local sunpack = string.unpack +local srep = string.rep +local schar = string.char +local tconcat = table.concat + +local UDP_PORT = 54321 +local MAX_MSG_LEN = 2048 + +local function md5(...) + local ctx = hash.create("MD5") + for i = 1, select("#", ...) do + local part = select(i, ...) + if part ~= nil and part ~= "" then + ctx:update(part) + end + end + return ctx:digest() +end + +local function createEncryption(token) + local ctx = cipher.create("AES-128-CBC") + ctx:setPadding("PKCS7") + + local key = md5(token) + local iv = md5(key, token) + + return { + encrypt = function(_, input) + return ctx:process("encrypt", key, iv, input) + end, + decrypt = function(_, input) + return ctx:process("decrypt", key, iv, input) + end, + } +end + +local function pack(did, stamp, token, data) + local len = 32 + (data and #data or 0) + local header = spack(">I2>I2>I8>I4", 0x2131, len, did, stamp) + local checksum = token and md5(header, token, data or "") or srep(schar(0xff), 16) + return tconcat({header, checksum, data or ""}) +end + +local function unpack(packet, token) + assert(sunpack(">I2", packet, 1) == 0x2131) + local len = sunpack(">I2", packet, 3) + assert(len == #packet and len >= 32) + + local data = nil + if len > 32 then + data = sunpack("c" .. len - 32, packet, 33) + end + + if token then + assert(md5(sunpack("c16", packet, 1), token, data or "") == sunpack("c16", packet, 17)) + end + + return sunpack(">I8", packet, 5), + sunpack(">I4", packet, 13), + data +end + +local function findTestNetif() + for _, iface in ipairs(netif.getInterfaces()) do + if netif.isUp(iface) then + local addr = netif.getIpv4Addr(iface) + if addr ~= "0.0.0.0" and addr ~= "127.0.0.1" then + return netif.getName(iface), addr + end + end + end + error("no usable non-loopback ipv4 interface") +end + +local function waitFor(mq, timeout) + local success, ok, result = mq:recvUntil(floor(core.time()) + timeout) + if not success then + return false, ok + end + return ok, result +end + +local function startUdpDevice(handler) + local server = socket.create("UDP", "IPV4") + server:reuseaddr() + server:bind("0.0.0.0", UDP_PORT) + + core.createTimer(function () + while true do + local msg, addr, port = server:recvfrom(MAX_MSG_LEN) + if #msg == 0 then + server:destroy() + return + end + handler(msg, addr, port, server) + end + end):start(0) + + return function(targetAddr) + local wake = socket.create("UDP", "IPV4") + wake:sendto("", targetAddr, UDP_PORT) + core.sleep(20) + end +end + +-- Tests auto-selected transport skips loopback interfaces. +do + local tp = transport.create() + local count = 0 + for ifname in pairs(tp.sockets) do + count = count + 1 + assert(netif.getIpv4Addr(netif.find(ifname)) ~= "127.0.0.1") + end + assert(count > 0) + tp:close() +end + +-- Tests transport creation fails instead of silently dropping requested interfaces. +do + local origSocketCreate = socket.create + local origGetInterfaces = netif.getInterfaces + local origGetName = netif.getName + local origIsUp = netif.isUp + local origGetIpv4Addr = netif.getIpv4Addr + local origFind = netif.find + + local destroyed = {} + local fakeIfs = { + eth0 = { + name = "eth0", + addr = "192.168.10.2", + }, + wlan0 = { + name = "wlan0", + addr = "192.168.20.2", + }, + } + local nextPort = 40000 + + local function restore() + socket.create = origSocketCreate + netif.getInterfaces = origGetInterfaces + netif.getName = origGetName + netif.isUp = origIsUp + netif.getIpv4Addr = origGetIpv4Addr + netif.find = origFind + end + + local ok, err = pcall(function() + netif.getInterfaces = function() + return {fakeIfs.eth0, fakeIfs.wlan0} + end + netif.getName = function(iface) + return iface.name + end + netif.isUp = function(_) + return true + end + netif.getIpv4Addr = function(iface) + return iface.addr + end + netif.find = function(name) + return assert(fakeIfs[name]) + end + + socket.create = function(sockType, familyName) + assert(sockType == "UDP") + assert(familyName == "IPV4") + nextPort = nextPort + 1 + + local o = { + port = nextPort, + } + + function o:settimeout(ms) + assert(ms == 0) + end + + function o:enablebroadcast() end + + function o:reuseaddr() end + + function o:bindif(ifname) + self.ifname = ifname + if ifname == "wlan0" then + error("bindif failed") + end + end + + function o:bind(addr, port) + assert(addr == "0.0.0.0") + assert(type(port) == "number") + end + + function o:getsockname() + return "0.0.0.0", self.port + end + + function o:destroy() + destroyed[self.ifname] = true + end + + return o + end + + local created, createErr = pcall(transport.create, {"eth0", "wlan0"}) + assert(created == false) + assert(createErr:find("failed to bind miio transport sockets", 1, true) ~= nil) + assert(createErr:find("wlan0", 1, true) ~= nil) + end) + + restore() + assert(ok, err) + assert(destroyed.eth0 == true) + assert(destroyed.wlan0 == true) +end + +-- Tests resident transport send/recv flow, source-port filtering and unsubscribe. +do + local bindif, addr = findTestNetif() + local tp = transport.create({bindif}) + local ctx = assert(tp.sockets[bindif]) + local _, localPort = ctx.sock:getsockname() + assert(localPort == tp.localPort) + + local recvMq = core.createMQ(4) + local stopDevice = startUdpDevice(function(msg, fromAddr, fromPort, server) + assert(fromAddr == addr) + assert(fromPort == tp.localPort) + + if msg == "ping" then + assert(server:sendto("pong", fromAddr, fromPort) == 4) + elseif msg == "notify" then + assert(server:sendto("after-unsub", fromAddr, fromPort) == 11) + end + end) + + local subId = tp:subscribe(function(data, fromAddr, port, inboundIfname) + recvMq:send(true, { + data = data, + addr = fromAddr, + port = port, + ifname = inboundIfname, + }) + end) + + assert(tp:sendto("ping", addr, UDP_PORT, bindif) == 1) + local ok, packet = waitFor(recvMq, 1000) + assert(ok == true) + assert(packet.data == "pong") + assert(packet.addr == addr) + assert(packet.port == UDP_PORT) + assert(packet.ifname == bindif) + + local other = socket.create("UDP", "IPV4") + assert(other:sendto("ignored", addr, tp.localPort) == 7) + ok = waitFor(recvMq, 100) + assert(ok == false) + + tp:unsubscribe(subId) + assert(tp:sendto("notify", addr, UDP_PORT, bindif) == 1) + ok = waitFor(recvMq, 100) + assert(ok == false) + + stopDevice(addr) + tp:close() + tp:close() +end + +-- Tests protocol.scan/request over the resident transport with a virtual miIO device. +do + local bindif, addr = findTestNetif() + local token = "0123456789abcdef" + local virtualDid = 0x1111222233334444 + local deviceDid = 0x0102030405060708 + local deviceStamp = floor(core.time()) - 5 + local enc = createEncryption(token) + local stats = { + probes = 0, + requests = 0, + } + + local stopDevice = startUdpDevice(function(msg, fromAddr, fromPort, server) + local ok, did, stamp, data = pcall(unpack, msg) + if ok and did == -1 and stamp == 0xffffffff and data == nil then + stats.probes = stats.probes + 1 + local resp = pack(deviceDid, deviceStamp) + assert(server:sendto(resp, fromAddr, fromPort) == #resp) + return + end + + did, stamp, data = unpack(msg, token) + assert(did == deviceDid) + local req = json.decode(enc:decrypt(data)) + stats.requests = stats.requests + 1 + + local payload = json.encode({ + id = req.id, + result = { + req.method, + req.params and req.params[1] or false, + stats.requests, + }, + }) + local resp = pack(deviceDid, deviceStamp, token, enc:encrypt(payload)) + assert(server:sendto(resp, fromAddr, fromPort) == #resp) + end) + + local runtime = protocol.create({bindif}, virtualDid) + + local results = runtime:scan(1000, addr) + assert(#results == 1) + assert(results[1].addr == addr) + assert(results[1].devid == deviceDid) + assert(results[1].ifname == bindif) + + local pcb = runtime:createPcb(addr, token) + local result1 = pcb:request(1000, "test.echo", "ping") + assert(result1[1] == "test.echo") + assert(result1[2] == "ping") + assert(result1[3] == 1) + assert(pcb.ifname == bindif) + + local result2 = pcb:request(1000, "test.echo", "pong") + assert(result2[1] == "test.echo") + assert(result2[2] == "pong") + assert(result2[3] == 2) + + assert(stats.probes == 2) + assert(stats.requests == 2) + + stopDevice(addr) + runtime:close() +end + +-- Tests concurrent PCBs talking to the same device keep request/response pairs isolated. +do + local bindif, addr = findTestNetif() + local token = "0123456789abcdef" + local deviceDid = 0x1112131415161718 + local deviceStamp = floor(core.time()) - 5 + local enc = createEncryption(token) + local pending = {} + + local stopDevice = startUdpDevice(function(msg, fromAddr, fromPort, server) + local ok, did, stamp, data = pcall(unpack, msg) + if ok and did == -1 and stamp == 0xffffffff and data == nil then + local resp = pack(deviceDid, deviceStamp) + assert(server:sendto(resp, fromAddr, fromPort) == #resp) + return + end + + did, stamp, data = unpack(msg, token) + assert(did == deviceDid) + + local req = json.decode(enc:decrypt(data)) + pending[#pending + 1] = { + addr = fromAddr, + port = fromPort, + req = req, + } + if #pending < 2 then + return + end + + for _, item in ipairs(pending) do + local payload = json.encode({ + id = item.req.id, + result = { + item.req.method, + item.req.params[1], + }, + }) + local resp = pack(deviceDid, deviceStamp, token, enc:encrypt(payload)) + assert(server:sendto(resp, item.addr, item.port) == #resp) + end + pending = {} + end) + + local runtime = protocol.create({bindif}, 0x5555666677778888) + local pcb1 = runtime:createPcb(addr, token) + local pcb2 = runtime:createPcb(addr, token) + local mq = core.createMQ(2) + + core.createTimer(function() + local ok, result = pcall(pcb1.request, pcb1, 1000, "test.echo", "A") + mq:send("pcb1", ok, result) + end):start(0) + + core.createTimer(function() + local ok, result = pcall(pcb2.request, pcb2, 1000, "test.echo", "B") + mq:send("pcb2", ok, result) + end):start(0) + + local results = {} + for _ = 1, 2 do + local name, ok, result = mq:recv() + assert(ok == true, result) + results[name] = result + end + + assert(results.pcb1[1] == "test.echo") + assert(results.pcb1[2] == "A") + assert(results.pcb2[1] == "test.echo") + assert(results.pcb2[2] == "B") + + stopDevice(addr) + runtime:close() +end From 81f8c8e9d0f96f9e58991e47546cf93205ddb50b Mon Sep 17 00:00:00 2001 From: Zebin Wu Date: Sun, 19 Apr 2026 00:32:25 +0800 Subject: [PATCH 5/8] stream: opimize test cases --- tests/teststream.lua | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/teststream.lua b/tests/teststream.lua index fda29ea..a390f54 100644 --- a/tests/teststream.lua +++ b/tests/teststream.lua @@ -4,9 +4,10 @@ local stream = require "stream" local TIMEOUT = 1000 -local function with_tcp_server(port, handler, test) +local function with_tcp_server(handler, test) local listener = socket.create("TCP", "IPV4") - listener:bind("127.0.0.1", port) + listener:bind("127.0.0.1", 0) + local _, port = listener:getsockname() listener:listen(1) core.createTimer(function () @@ -25,7 +26,7 @@ do end -- Tests read() returning partial data when not requesting all bytes. -with_tcp_server(8889, function (server) +with_tcp_server(function (server) server:sendall("he") core.sleep(50) server:sendall("llo") @@ -37,7 +38,7 @@ end, function (port) end) -- Tests readline() leaving unread bytes in the internal stash buffer. -with_tcp_server(8890, function (server) +with_tcp_server(function (server) server:sendall("header\r\nbody") end, function (port) local client = stream.client("TCP", "127.0.0.1", port, TIMEOUT) @@ -47,7 +48,7 @@ end, function (port) end) -- Tests readline() when the separator is split across multiple reads. -with_tcp_server(8891, function (server) +with_tcp_server(function (server) server:sendall("abc\r") core.sleep(50) server:sendall("\nxyz") @@ -59,7 +60,7 @@ end, function (port) end) -- Tests readall() collecting the remaining stream until EOF. -with_tcp_server(8892, function (server) +with_tcp_server(function (server) server:sendall("chunk1") core.sleep(20) server:sendall("chunk2") @@ -70,7 +71,7 @@ end, function (port) end) -- Tests readline() raising an error on EOF without a separator. -with_tcp_server(8893, function (server) +with_tcp_server(function (server) server:sendall("tail") end, function (port) local client = stream.client("TCP", "127.0.0.1", port, TIMEOUT) From 50593d7be93ac6888ea124c7fe8a749500745498 Mon Sep 17 00:00:00 2001 From: Zebin Wu Date: Mon, 20 Apr 2026 21:50:08 +0800 Subject: [PATCH 6/8] reduce memory usage --- plugins/miio/chuangmi/plug/212a01.lua | 10 ++++++---- plugins/miio/cuco/plug/v3.lua | 10 ++++++---- plugins/miio/dmaker/derh/22l.lua | 16 +++++++++------- plugins/miio/dmaker/fan/p9.lua | 15 +++++++++------ plugins/miio/plugin.lua | 4 ++-- plugins/miio/xiaomi/heater/ma4.lua | 16 +++++++++------- 6 files changed, 41 insertions(+), 30 deletions(-) diff --git a/plugins/miio/chuangmi/plug/212a01.lua b/plugins/miio/chuangmi/plug/212a01.lua index f77f5e8..b9e7363 100644 --- a/plugins/miio/chuangmi/plug/212a01.lua +++ b/plugins/miio/chuangmi/plug/212a01.lua @@ -1,14 +1,16 @@ local M = {} +-- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:outlet:0000A002:chuangmi-212a01:1 +local propMapping = { + power = {siid = 2, piid = 1} +} + ---Create a plug. ---@param device MiioDevice Device object. ---@param conf MiioAccessoryConf Device configuration. ---@return HAPAccessory accessory HomeKit Accessory. function M.gen(device, conf) - -- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:outlet:0000A002:chuangmi-212a01:1 - device:setMapping({ - power = {siid = 2, piid = 1} - }) + device:setMapping(propMapping) function device:getOn() return self:getProp("power") diff --git a/plugins/miio/cuco/plug/v3.lua b/plugins/miio/cuco/plug/v3.lua index 6c5fe6e..d27cbe1 100644 --- a/plugins/miio/cuco/plug/v3.lua +++ b/plugins/miio/cuco/plug/v3.lua @@ -1,14 +1,16 @@ local M = {} +-- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:outlet:0000A002:cuco-v3:1 +local propMapping = { + power = {siid = 2, piid = 1} +} + ---Create a plug. ---@param device MiioDevice Device object. ---@param conf MiioAccessoryConf Device configuration. ---@return HAPAccessory accessory HomeKit Accessory. function M.gen(device, conf) - -- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:outlet:0000A002:cuco-v3:1 - device:setMapping({ - power = {siid = 2, piid = 1} - }) + device:setMapping(propMapping) function device:getOn() return self:getProp("power") diff --git a/plugins/miio/dmaker/derh/22l.lua b/plugins/miio/dmaker/derh/22l.lua index 8f360ef..b55363a 100644 --- a/plugins/miio/dmaker/derh/22l.lua +++ b/plugins/miio/dmaker/derh/22l.lua @@ -1,17 +1,19 @@ local M = {} +-- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:dehumidifier:0000A02D:dmaker-22l:1 +local propMapping = { + power = {siid = 2, piid = 1}, + tgtHumidity = {siid = 2, piid = 5}, + curHumidity = {siid = 3, piid = 1}, + curTemp = {siid = 3, piid = 2}, +} + ---Create a dehumidifier. ---@param device MiioDevice Device object. ---@param conf MiioAccessoryConf Device configuration. ---@return HAPAccessory accessory HomeKit Accessory. function M.gen(device, conf) - -- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:dehumidifier:0000A02D:dmaker-22l:1 - device:setMapping({ - power = {siid = 2, piid = 1}, - tgtHumidity = {siid = 2, piid = 5}, - curHumidity = {siid = 3, piid = 1}, - curTemp = {siid = 3, piid = 2}, - }) + device:setMapping(propMapping) return require("miio.dmaker.derh").gen(device, conf) end diff --git a/plugins/miio/dmaker/fan/p9.lua b/plugins/miio/dmaker/fan/p9.lua index 0c3d12d..b4f2651 100644 --- a/plugins/miio/dmaker/fan/p9.lua +++ b/plugins/miio/dmaker/fan/p9.lua @@ -1,16 +1,19 @@ local M = {} +-- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:fan:0000A005:dmaker-p9:1 +local propMapping = { + power = {siid = 2, piid = 1}, + fanSpeed = {siid = 2, piid = 11}, + swingMode = {siid = 2, piid = 5} +} + ---Create a fan. ---@param device MiioDevice Device object. ---@param conf MiioAccessoryConf Device configuration. ---@return HAPAccessory accessory HomeKit Accessory. function M.gen(device, conf) - -- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:fan:0000A005:dmaker-p9:1 - device:setMapping({ - power = {siid = 2, piid = 1}, - fanSpeed = {siid = 2, piid = 11}, - swingMode = {siid = 2, piid = 5} - }) + device:setMapping(propMapping) + return require("miio.dmaker.fan").gen(device, conf) end diff --git a/plugins/miio/plugin.lua b/plugins/miio/plugin.lua index 721a99d..dca300f 100644 --- a/plugins/miio/plugin.lua +++ b/plugins/miio/plugin.lua @@ -53,7 +53,7 @@ function M.init() for _, device in ipairs(devices) do if device.ssid == ssid then local sn = device.mac:gsub(":", "") - local handle = nvs.open(sn) + local handle = nvs.open(sn) tinsert(confs, { aid = hapUtil.getBridgedAccessoryIID(handle), iids = hapUtil.getInstanceIDs(handle), @@ -79,7 +79,7 @@ function M.init() if success == false then logger:error(result) else - table.insert(accessories, result) + tinsert(accessories, result) end end return accessories diff --git a/plugins/miio/xiaomi/heater/ma4.lua b/plugins/miio/xiaomi/heater/ma4.lua index 31348e0..262c3bb 100644 --- a/plugins/miio/xiaomi/heater/ma4.lua +++ b/plugins/miio/xiaomi/heater/ma4.lua @@ -1,17 +1,19 @@ local M = {} +-- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:heater:0000A01A:xiaomi-ma4:1 +local propMapping = { + on = {siid = 2, piid = 1}, + tgtTemp = {siid = 2, piid = 5}, + curTemp = {siid = 2, piid = 6}, + curState = {siid = 2, piid = 3}, +} + ---Create a dehumidifier. ---@param device MiioDevice Device object. ---@param conf MiioAccessoryConf Device configuration. ---@return HAPAccessory accessory HomeKit Accessory. function M.gen(device, conf) - -- Source https://miot-spec.org/miot-spec-v2/instance?type=urn:miot-spec-v2:device:heater:0000A01A:xiaomi-ma4:1 - device:setMapping({ - on = {siid = 2, piid = 1}, - tgtTemp = {siid = 2, piid = 5}, - curTemp = {siid = 2, piid = 6}, - curState = {siid = 2, piid = 3}, - }) + device:setMapping(propMapping) return require("miio.xiaomi.heater").gen(device, conf) end From 1463a5a64bb95c264bc1fa5ff57e18d9c5c39fc0 Mon Sep 17 00:00:00 2001 From: Zebin Wu Date: Mon, 20 Apr 2026 23:01:53 +0800 Subject: [PATCH 7/8] reduce lwip max sockets --- platform/esp/sdkconfig.defaults | 2 -- 1 file changed, 2 deletions(-) diff --git a/platform/esp/sdkconfig.defaults b/platform/esp/sdkconfig.defaults index e57a562..0a4c952 100644 --- a/platform/esp/sdkconfig.defaults +++ b/platform/esp/sdkconfig.defaults @@ -6,8 +6,6 @@ CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="platform/esp/partitions.csv" CONFIG_FREERTOS_USE_TRACE_FACILITY=y CONFIG_FREERTOS_USE_STATS_FORMATTING_FUNCTIONS=y -CONFIG_LWIP_MAX_SOCKETS=32 - CONFIG_MBEDTLS_POLY1305_C=y CONFIG_MBEDTLS_CHACHA20_C=y CONFIG_MBEDTLS_CHACHAPOLY_C=y From 3670c3cf14a2c3c4fd6f294d3c29f55bd3e0b86b Mon Sep 17 00:00:00 2001 From: Zebin Wu Date: Tue, 21 Apr 2026 22:56:25 +0800 Subject: [PATCH 8/8] miio: merge transport.lua to protocol.lua and refactor recv data distribution --- plugins/miio/protocol.lua | 566 ++++++++++++++---- plugins/miio/transport.lua | 289 --------- tests/test.lua | 2 +- ...miiotransport.lua => testmiioprotocol.lua} | 258 +++----- 4 files changed, 536 insertions(+), 579 deletions(-) delete mode 100644 plugins/miio/transport.lua rename tests/{testmiiotransport.lua => testmiioprotocol.lua} (62%) diff --git a/plugins/miio/protocol.lua b/plugins/miio/protocol.lua index fbbe365..a4df868 100644 --- a/plugins/miio/protocol.lua +++ b/plugins/miio/protocol.lua @@ -1,26 +1,55 @@ +local socket = require "socket" +local netiflib = require "netif" local hash = require "hash" local cipher = require "cipher" local json = require "cjson" -local transport = require "miio.transport" local assert = assert -local pcall = pcall -local type = type local error = error local floor = math.floor +local ipairs = ipairs +local pairs = pairs +local pcall = pcall local random = math.random local tointeger = math.tointeger +local tostring = tostring +local type = type +local xpcall = xpcall local spack = string.pack local sunpack = string.unpack local schar = string.char local srep = string.rep local tconcat = table.concat +local traceback = debug.traceback local M = {} local logger = log.getLogger("miio.protocol") +local UDP_PORT = 54321 +local MAX_MSG_LEN = 2048 +local SCAN_ANY_ADDR = "*" + +---@class MiioProtocolSocket +---@field ifname string +---@field sock Socket +---@field reader Timer + +---@class MiioScanWaiter +---@field mq MessageQueue? + +---@class MiioRequestWaiter +---@field mq MessageQueue? +---@field devid integer +---@field reqid integer +---@field token string +---@field encryption MiioEncryption + ---@class MiioProtocolRuntime ----@field transport MiioTransport? +---@field sockets table +---@field localPort integer +---@field running boolean +---@field scanMqs table +---@field requestMqs table ---@field virtualDid integer? ---@field _reqid integer local runtime = {} @@ -138,6 +167,163 @@ local function createEncryption(token) }) end +local function isUsableNetif(netif) + if not netiflib.isUp(netif) then + return false + end + local addr = netiflib.getIpv4Addr(netif) + return addr ~= "0.0.0.0" and addr ~= "127.0.0.1" +end + +---@param netifs? string[] +---@return string[] results +local function resolveNetifs(netifs) + if netifs ~= nil then + assert(type(netifs) == "table", "netifs must be a table") + end + + local available = {} + local results = {} + for _, netif in ipairs(netiflib.getInterfaces()) do + local ifname = netiflib.getName(netif) + local usable = isUsableNetif(netif) + available[ifname] = usable + if netifs == nil and usable then + results[#results + 1] = ifname + end + end + + if netifs ~= nil then + local seen = {} + for i, ifname in ipairs(netifs) do + assert(type(ifname) == "string", ("network interface #%d must be a string"):format(i)) + if not seen[ifname] then + local usable = available[ifname] + if usable == nil then + error(("network interface #%d not found"):format(i)) + end + if not usable then + error(("network interface #%d is not usable"):format(i)) + end + results[#results + 1] = ifname + seen[ifname] = true + end + end + end + + assert(#results > 0, "no available IPv4 network interface") + return results +end + +local function destroySocket(ctx) + if ctx == nil or ctx.sock == nil then + return + end + if ctx.reader ~= nil then + ctx.reader:stop() + ctx.reader = nil + end + pcall(ctx.sock.destroy, ctx.sock) + ctx.sock = nil +end + +---@param netifs string[] +---@return integer localPort +---@return table sockets +local function createSockets(netifs) + local sockets = {} + local localPort = nil + local failures = {} + + for _, ifname in ipairs(netifs) do + local sock = socket.create("UDP", "IPV4") + local addr + local success, result = pcall(function() + sock:settimeout(0) + sock:enablebroadcast() + sock:reuseaddr() + sock:bindif(ifname) + sock:bind("0.0.0.0", localPort or 0) + if localPort == nil then + local _, port = sock:getsockname() + localPort = port + end + addr = netiflib.getIpv4Addr(netiflib.find(ifname)) + end) + if success then + sockets[ifname] = { + ifname = ifname, + sock = sock, + } + logger:info(("created socket, %s, %s, %s"):format(ifname, addr, localPort)) + else + logger:error(("create socket error, %s, %s"):format(ifname, tostring(result))) + failures[#failures + 1] = ("%s: %s"):format(ifname, tostring(result)) + pcall(sock.destroy, sock) + end + end + + if localPort ~= nil and #failures == 0 then + return localPort, sockets + end + + for _, ctx in pairs(sockets) do + destroySocket(ctx) + end + if #failures > 0 then + error("failed to bind miio transport sockets: " .. tconcat(failures, "; ")) + end + error("failed to bind miio transport sockets") +end + +local function snapshotWaiters(waiters) + if waiters == nil or #waiters == 0 then + return nil + end + local snapshot = {} + for i = 1, #waiters do + snapshot[i] = waiters[i] + end + return snapshot +end + +local function appendWaiter(waitersByKey, key, waiter) + local waiters = waitersByKey[key] + if waiters == nil then + waiters = {} + waitersByKey[key] = waiters + end + waiters[#waiters + 1] = waiter +end + +local function removeWaiter(waitersByKey, key, waiter) + local waiters = waitersByKey[key] + if waiters == nil then + return + end + for i = #waiters, 1, -1 do + if waiters[i] == waiter then + waiters[i] = waiters[#waiters] + waiters[#waiters] = nil + break + end + end + if #waiters == 0 then + waitersByKey[key] = nil + end +end + +local function sendWaiter(waiter, ...) + local mq = waiter.mq + if mq == nil then + return + end + local ok, err = pcall(mq.send, mq, ...) + if not ok then + logger:debug(("drop miio packet: %s"):format(tostring(err))) + end +end + ---Create a virtual device ID for probe packets. ---@return integer did ---@nodiscard @@ -150,25 +336,41 @@ end ---@param netifs? string[] ---@param virtualDid? integer ----@return string[]? netifs +---@return string[] netifs ---@return integer? virtualDid local function normalizeCreateArgs(netifs, virtualDid) - assert(netifs == nil or type(netifs) == "table", "netifs must be a table") if virtualDid ~= nil then virtualDid = assert(tointeger(virtualDid), "virtualDid must be an integer") end - return netifs, virtualDid + return resolveNetifs(netifs), virtualDid +end + +local function hasPendingRequest(self, reqid) + for _, waiters in pairs(self.requestMqs) do + for i = 1, #waiters do + if waiters[i].reqid == reqid then + return true + end + end + end + return false end ---@param self MiioProtocolRuntime ---@return integer reqid local function nextRequestId(self) - local reqid = self._reqid + 1 - if reqid > 9999 then - reqid = 1 + local reqid = self._reqid + for _ = 1, 9999 do + reqid = reqid + 1 + if reqid > 9999 then + reqid = 1 + end + if not hasPendingRequest(self, reqid) then + self._reqid = reqid + return reqid + end end - self._reqid = reqid - return reqid + error("too many pending requests") end ---Pack a message to a binary package. @@ -247,53 +449,163 @@ local function unpack(package, token) data end ----@param queueSize integer ----@param currentTp MiioTransport ----@param matcher fun(packet:string, addr:string, port:integer, ifname:string):any? ----@return MessageQueue queue ----@return fun() close -local function createReceiver(queueSize, currentTp, matcher) - assert(currentTp, "transport not inited") +---@class ScanResult Scan Result. +--- +---@field addr string Device address. +---@field devid integer Device ID: 64-bit. +---@field ifname string Network interface name. +---@field stamp integer Device time stamp. - local queue = core.createMQ(queueSize) - local subId = currentTp:subscribe(function(packet, addr, port, ifname) - local result = matcher(packet, addr, port, ifname) - if result ~= nil then - queue:send(result) +local function receivePacket(ctx) + local success, data, addr, port = xpcall(ctx.sock.recvfrom, traceback, ctx.sock, MAX_MSG_LEN) + if not success then + return nil, data + end + if port ~= UDP_PORT then + return false + end + return true, data, addr, port +end + +---@param self MiioProtocolRuntime +---@param packet string +---@param addr string +---@param ifname string +local function dispatchScanPacket(self, packet, addr, ifname) + local success, did, stamp, data = pcall(unpack, packet) + if not success or did == -1 or data ~= nil then + return false + end + + ---@type ScanResult + local result = { + addr = addr, + devid = did, + ifname = ifname, + stamp = stamp, + } + + local waiters = snapshotWaiters(self.scanMqs[addr]) + if waiters ~= nil then + for i = 1, #waiters do + sendWaiter(waiters[i], result) end - end) - local closed = false + end + + if addr ~= SCAN_ANY_ADDR then + waiters = snapshotWaiters(self.scanMqs[SCAN_ANY_ADDR]) + if waiters ~= nil then + for i = 1, #waiters do + sendWaiter(waiters[i], result) + end + end + end + + return true +end + +---@param waiter MiioRequestWaiter +---@param packet string +---@return any? decoded +local function decodeResponse(waiter, packet) + local success, did, _, data = pcall(unpack, packet, waiter.token) + if not success or did ~= waiter.devid or data == nil then + return + end + + local ok, payload = pcall(waiter.encryption.decrypt, waiter.encryption, data) + if not ok or payload == nil then + return + end + + local decodedOk, decoded = pcall(json.decode, payload) + if not decodedOk or decoded == nil or decoded.id ~= waiter.reqid then + return + end + + return decoded +end + +---@param self MiioProtocolRuntime +---@param packet string +---@param addr string +---@param ifname string +local function dispatchRequestPacket(self, packet, addr, ifname) + local waiters = self.requestMqs[addr] + if waiters == nil then + return + end - return queue, function() - if closed then + for i = 1, #waiters do + local waiter = waiters[i] + local decoded = decodeResponse(waiter, packet) + if decoded ~= nil then + sendWaiter(waiter, decoded, ifname) return end - closed = true - currentTp:unsubscribe(subId) end end ----@param queueSize integer ----@param currentTp MiioTransport ----@param matcher fun(packet:string, addr:string, port:integer, ifname:string):any? ----@param action fun(queue:MessageQueue):any ----@return any result -local function withReceiver(queueSize, currentTp, matcher, action) - local queue, close = createReceiver(queueSize, currentTp, matcher) - local ok, result = pcall(action, queue) - close() - if not ok then - error(result) +---@param self MiioProtocolRuntime +---@param packet string +---@param addr string +---@param ifname string +local function dispatchPacket(self, packet, addr, ifname) + if dispatchScanPacket(self, packet, addr, ifname) then + return end - return result + dispatchRequestPacket(self, packet, addr, ifname) end ----@class ScanResult Scan Result. ---- ----@field addr string Device address. ----@field devid integer Device ID: 64-bit. ----@field ifname string Network interface name. ----@field stamp integer Device time stamp. +local function receiveLoop(self, ctx) + while self.running and self.sockets[ctx.ifname] == ctx do + local ok, data, addr = receivePacket(ctx) + if ok == nil then + if not self.running then + return + end + logger:error(("recvfrom failed, %s, %s"):format(ctx.ifname, tostring(data))) + return + end + if ok then + dispatchPacket(self, data, addr, ctx.ifname) + end + end +end + +---@param self MiioProtocolRuntime +---@param data string +---@param addr string +---@param port integer +---@param ifname? string +---@return integer sent +local function sendto(self, data, addr, port, ifname) + assert(type(data) == "string") + assert(type(addr) == "string") + assert(type(port) == "number") + + local sent = 0 + + if ifname ~= nil then + local ctx = assert(self.sockets[ifname], ("unknown netif '%s'"):format(ifname)) + local success, result = pcall(ctx.sock.sendto, ctx.sock, data, addr, port) + if not success then + error(result) + end + return 1 + end + + for name, ctx in pairs(self.sockets) do + local success, result = pcall(ctx.sock.sendto, ctx.sock, data, addr, port) + if success then + sent = sent + 1 + else + logger:debug(("sendto failed, %s, %s"):format(name, tostring(result))) + end + end + + return sent +end ---Scan for devices in the local network. --- @@ -307,51 +619,48 @@ end ---@nodiscard function runtime:scan(timeout, addr) assert(timeout > 0, "timeout must be greater then 0") - local currentTp = assert(self.transport, "protocol closed") - local virtualDid = assert(self.virtualDid, "protocol closed") - - local numSend = 1 - local probe = packProbe(virtualDid) + assert(self.virtualDid ~= nil, "protocol closed") - if not addr then - numSend = 3 - end + local key = addr or SCAN_ANY_ADDR + local waiter = { + mq = core.createMQ(64), + } + appendWaiter(self.scanMqs, key, waiter) - local seen = {} - local deadline = floor(core.time()) + timeout - return withReceiver(64, currentTp, function(packet, fromAddr, _, ifname) - if addr ~= nil and fromAddr ~= addr then - return - end - local success, did, stamp, data = pcall(unpack, packet) - if not success or did == -1 or data ~= nil or seen[fromAddr] then - return - end - seen[fromAddr] = true - return { - addr = fromAddr, - devid = did, - ifname = ifname, - stamp = stamp, - } - end, function(queue) + local ok, result = pcall(function() + local numSend = addr == nil and 3 or 1 + local probe = packProbe(self.virtualDid) + local deadline = floor(core.time()) + timeout + local seen = {} local results = {} + for _ = 1, numSend do - assert(currentTp:sendto(probe, addr or "255.255.255.255", 54321) > 0, "failed to send probe message") + assert(sendto(self, probe, addr or "255.255.255.255", UDP_PORT) > 0, "failed to send probe message") end while true do - local ok, result = queue:recvUntil(deadline) - if not ok then + local recvOk, item = waiter.mq:recvUntil(deadline) + if not recvOk then break end - results[#results + 1] = result - if addr then - break + if not seen[item.addr] then + seen[item.addr] = true + results[#results + 1] = item + if addr ~= nil then + break + end end end + return results end) + + removeWaiter(self.scanMqs, key, waiter) + waiter.mq = nil + if not ok then + error(result) + end + return result end ---@class MiioPcb: table miio protocol control block. @@ -363,29 +672,6 @@ local pcb = {} ---@field code integer Error code. ---@field message string Error message. ----@param self MiioPcb ----@param raw string ----@param reqid integer ----@return any? decoded -local function decodeResponse(self, raw, reqid) - local success, did, _, data = pcall(unpack, raw, self.token) - if not success or did ~= self.devid or data == nil then - return - end - - local ok, payload = pcall(self.encryption.decrypt, self.encryption, data) - if not ok or payload == nil then - return - end - - local decodedOk, decoded = pcall(json.decode, payload) - if not decodedOk or decoded == nil or decoded.id ~= reqid then - return - end - - return decoded -end - ---Handshake. ---@param timeout integer Timeout period (in milliseconds). function pcb:handshake(timeout) @@ -418,7 +704,6 @@ function pcb:request(timeout, method, ...) end local reqid = nextRequestId(self.runtime) - local plain = json.encode({ id = reqid, method = method, @@ -430,24 +715,23 @@ function pcb:request(timeout, method, ...) self.token, self.encryption:encrypt(plain) ) - local currentTp = assert(self.runtime.transport, "protocol closed") - local deadline = floor(core.time()) + timeout + local waiter = { + mq = core.createMQ(1), + devid = self.devid, + reqid = reqid, + token = self.token, + encryption = self.encryption, + } + appendWaiter(self.runtime.requestMqs, self.addr, waiter) + logger:debug(("%s => %s"):format(plain, self.addr)) - local result = withReceiver(1, currentTp, function(raw, fromAddr, _, ifname) - if fromAddr ~= self.addr then - return - end - local decoded = decodeResponse(self, raw, reqid) - if decoded == nil then - return - end - self.ifname = ifname - return decoded - end, function(queue) + local ok, result = pcall(function() local sent = 0 + local deadline = floor(core.time()) + timeout + if self.ifname ~= nil then - local ok, sendResult = pcall(currentTp.sendto, currentTp, packet, self.addr, 54321, self.ifname) - if ok then + local sendOk, sendResult = pcall(sendto, self.runtime, packet, self.addr, UDP_PORT, self.ifname) + if sendOk then sent = sendResult else logger:debug(("send via %s failed, retry all netifs, %s"):format(self.ifname, tostring(sendResult))) @@ -455,21 +739,28 @@ function pcb:request(timeout, method, ...) end end if sent == 0 then - sent = currentTp:sendto(packet, self.addr, 54321) + sent = sendto(self.runtime, packet, self.addr, UDP_PORT) end if sent == 0 then error("failed to send request") end - local ok, recvResult = queue:recvUntil(deadline) - if not ok then + local recvOk, recvResult, ifname = waiter.mq:recvUntil(deadline) + if not recvOk then self.ifname = nil self.stampDiff = nil error(recvResult) end + self.ifname = ifname return recvResult end) + removeWaiter(self.runtime.requestMqs, self.addr, waiter) + waiter.mq = nil + if not ok then + error(result) + end + logger:debug(("%s => %s"):format(self.addr, json.encode(result))) ---@class MiioError local err = result.error @@ -490,7 +781,7 @@ function runtime:createPcb(addr, token) assert(type(addr) == "string") assert(type(token) == "string") assert(#token == 16) - assert(self.transport ~= nil, "protocol closed") + assert(self.virtualDid ~= nil, "protocol closed") ---@class MiioPcb local o = { @@ -511,11 +802,16 @@ end ---Close the miIO protocol runtime. ---@param self MiioProtocolRuntime function runtime:close() - local currentTp = self.transport - if currentTp ~= nil then - currentTp:close() - self.transport = nil + if not self.running then + return + end + self.running = false + for _, ctx in pairs(self.sockets) do + destroySocket(ctx) end + self.sockets = {} + self.scanMqs = {} + self.requestMqs = {} self.virtualDid = nil end @@ -526,17 +822,29 @@ end ---@nodiscard function M.create(netifs, virtualDid) netifs, virtualDid = normalizeCreateArgs(netifs, virtualDid) + local localPort, sockets = createSockets(netifs) ---@class MiioProtocolRuntime local o = { - transport = transport.create(netifs), + sockets = sockets, + localPort = localPort, + running = true, + scanMqs = {}, + requestMqs = {}, virtualDid = virtualDid or createVirtualDid(), _reqid = 0, } - return setmetatable(o, { + setmetatable(o, { __index = runtime }) + + for _, ctx in pairs(o.sockets) do + ctx.reader = core.createTimer(receiveLoop, o, ctx) + ctx.reader:start(0) + end + + return o end return M diff --git a/plugins/miio/transport.lua b/plugins/miio/transport.lua deleted file mode 100644 index 086439a..0000000 --- a/plugins/miio/transport.lua +++ /dev/null @@ -1,289 +0,0 @@ -local socket = require "socket" -local netiflib = require "netif" - -local assert = assert -local error = error -local type = type -local ipairs = ipairs -local pairs = pairs -local tostring = tostring -local xpcall = xpcall -local tconcat = table.concat -local traceback = debug.traceback - -local M = {} -local logger = log.getLogger("miio.transport") - -local UDP_PORT = 54321 -local MAX_MSG_LEN = 2048 - ----@class MiioTransportSocket ----@field ifname string ----@field sock Socket ----@field reader Timer - ----@class MiioTransport ----@field sockets table ----@field localPort integer ----@field subscribers table ----@field running boolean -local transport = {} - -local function isUsableNetif(netif) - if not netiflib.isUp(netif) then - return false - end - local addr = netiflib.getIpv4Addr(netif) - return addr ~= "0.0.0.0" and addr ~= "127.0.0.1" -end - -local function resolveNetifs(netifs) - if netifs ~= nil then - assert(type(netifs) == "table", "netifs must be a table") - end - - local available = {} - local results = {} - for _, netif in ipairs(netiflib.getInterfaces()) do - local ifname = netiflib.getName(netif) - local usable = isUsableNetif(netif) - available[ifname] = usable - if netifs == nil and usable then - results[#results + 1] = ifname - end - end - - if netifs ~= nil then - local seen = {} - for i, ifname in ipairs(netifs) do - assert(type(ifname) == "string", ("network interface #%d must be a string"):format(i)) - if not seen[ifname] then - local usable = available[ifname] - if usable == nil then - error(("network interface #%d not found"):format(i)) - end - if not usable then - error(("network interface #%d is not usable"):format(i)) - end - results[#results + 1] = ifname - seen[ifname] = true - end - end - end - - assert(#results > 0, "no available IPv4 network interface") - return results -end - -local function destroySocket(ctx) - if ctx == nil or ctx.sock == nil then - return - end - if ctx.reader ~= nil then - ctx.reader:stop() - ctx.reader = nil - end - pcall(ctx.sock.destroy, ctx.sock) - ctx.sock = nil -end - -local function createSockets(netifs) - local sockets = {} - local localPort = nil - local failures = {} - - for _, ifname in ipairs(netifs) do - local sock = socket.create("UDP", "IPV4") - local addr - local success, result = pcall(function() - sock:settimeout(0) - sock:enablebroadcast() - sock:reuseaddr() - sock:bindif(ifname) - sock:bind("0.0.0.0", localPort or 0) - if localPort == nil then - local _, port = sock:getsockname() - localPort = port - end - addr = netiflib.getIpv4Addr(netiflib.find(ifname)) - end) - if success then - sockets[ifname] = { - ifname = ifname, - sock = sock, - } - logger:info(("created socket, %s, %s, %s"):format(ifname, addr, localPort)) - else - logger:error(("create socket error, %s, %s"):format(ifname, tostring(result))) - failures[#failures + 1] = ("%s: %s"):format(ifname, tostring(result)) - pcall(sock.destroy, sock) - end - end - - if localPort ~= nil and #failures == 0 then - return localPort, sockets - end - - for _, ctx in pairs(sockets) do - destroySocket(ctx) - end - if #failures > 0 then - error("failed to bind miio transport sockets: " .. tconcat(failures, "; ")) - end - error("failed to bind miio transport sockets") -end - -local function receivePacket(ctx) - local success, data, addr, port = xpcall(ctx.sock.recvfrom, traceback, ctx.sock, MAX_MSG_LEN) - if not success then - return nil, data - end - if port ~= UDP_PORT then - return false - end - return true, data, addr, port -end - -local function snapshotSubscribers(self) - if next(self.subscribers) == nil then - return nil - end - - local subscribers = {} - for id, handler in pairs(self.subscribers) do - subscribers[#subscribers + 1] = { - id = id, - handler = handler, - } - end - return subscribers -end - ----@param self MiioTransport -function transport:_dispatch(data, addr, port, ifname) - local subscribers = snapshotSubscribers(self) - if subscribers == nil then - return - end - - for _, entry in ipairs(subscribers) do - if self.subscribers[entry.id] ~= entry.handler then - goto continue - end - local success, result = xpcall(entry.handler, traceback, data, addr, port, ifname) - if not success then - logger:error(("subscriber[%s] failed: %s"):format(entry.id, tostring(result))) - end -::continue:: - end -end - -local function receiveLoop(self, ctx) - while self.running and self.sockets[ctx.ifname] == ctx do - local ok, data, addr, port = receivePacket(ctx) - if ok == nil then - if not self.running then - return - end - logger:error(("recvfrom failed, %s, %s"):format(ctx.ifname, tostring(data))) - return - end - if ok then - self:_dispatch(data, addr, port, ctx.ifname) - end - end -end - ----@param self MiioTransport -function transport:close() - if not self.running then - return - end - self.running = false - for _, ctx in pairs(self.sockets) do - destroySocket(ctx) - end - self.sockets = {} - self.subscribers = {} -end - ----@param self MiioTransport ----@param data string ----@param addr string ----@param port integer ----@param ifname? string ----@return integer sent -function transport:sendto(data, addr, port, ifname) - assert(type(data) == "string") - assert(type(addr) == "string") - assert(type(port) == "number") - - local sent = 0 - - if ifname ~= nil then - local ctx = assert(self.sockets[ifname], ("unknown netif '%s'"):format(ifname)) - local success, result = pcall(ctx.sock.sendto, ctx.sock, data, addr, port) - if not success then - error(result) - end - return 1 - end - - for name, ctx in pairs(self.sockets) do - local success, result = pcall(ctx.sock.sendto, ctx.sock, data, addr, port) - if success then - sent = sent + 1 - else - logger:debug(("sendto failed, %s, %s"):format(name, tostring(result))) - end - end - - return sent -end - ----@param self MiioTransport ----@param handler fun(data:string, addr:string, port:integer, ifname:string) ----@return integer id -function transport:subscribe(handler) - assert(type(handler) == "function") - self._subId = (self._subId or 0) + 1 - self.subscribers[self._subId] = handler - return self._subId -end - ----@param self MiioTransport ----@param id integer -function transport:unsubscribe(id) - self.subscribers[id] = nil -end - ----Create a resident miIO UDP transport. ----@param netifs? string[] ----@return MiioTransport transport ----@nodiscard -function M.create(netifs) - local resolved = resolveNetifs(netifs) - local localPort, sockets = createSockets(resolved) - - ---@type MiioTransport - local o = { - sockets = sockets, - localPort = localPort, - subscribers = {}, - running = true, - _subId = 0, - } - - setmetatable(o, { - __index = transport - }) - - for _, ctx in pairs(o.sockets) do - ctx.reader = core.createTimer(receiveLoop, o, ctx) - ctx.reader:start(0) - end - - return o -end - -return M diff --git a/tests/test.lua b/tests/test.lua index c2b06fd..75abeda 100644 --- a/tests/test.lua +++ b/tests/test.lua @@ -3,7 +3,7 @@ local suites = { "testsocket", "teststream", "testnvs", - "testmiiotransport", + "testmiioprotocol", } local function runSuite(s) diff --git a/tests/testmiiotransport.lua b/tests/testmiioprotocol.lua similarity index 62% rename from tests/testmiiotransport.lua rename to tests/testmiioprotocol.lua index 8dc168a..772a656 100644 --- a/tests/testmiiotransport.lua +++ b/tests/testmiioprotocol.lua @@ -4,12 +4,10 @@ local netif = require "netif" local hash = require "hash" local cipher = require "cipher" local json = require "cjson" -local transport = require "miio.transport" local protocol = require "miio.protocol" local floor = math.floor local ipairs = ipairs -local pairs = pairs local spack = string.pack local sunpack = string.unpack local srep = string.rep @@ -93,6 +91,18 @@ local function waitFor(mq, timeout) return ok, result end +local function waitUntil(timeout, predicate) + local deadline = floor(core.time()) + timeout + while floor(core.time()) < deadline do + local result = predicate() + if result then + return result + end + core.sleep(5) + end + return predicate() +end + local function startUdpDevice(handler) local server = socket.create("UDP", "IPV4") server:reuseaddr() @@ -116,171 +126,31 @@ local function startUdpDevice(handler) end end --- Tests auto-selected transport skips loopback interfaces. +-- Tests wildcard scan waiter is registered while waiting and removed after timeout. do - local tp = transport.create() - local count = 0 - for ifname in pairs(tp.sockets) do - count = count + 1 - assert(netif.getIpv4Addr(netif.find(ifname)) ~= "127.0.0.1") - end - assert(count > 0) - tp:close() -end - --- Tests transport creation fails instead of silently dropping requested interfaces. -do - local origSocketCreate = socket.create - local origGetInterfaces = netif.getInterfaces - local origGetName = netif.getName - local origIsUp = netif.isUp - local origGetIpv4Addr = netif.getIpv4Addr - local origFind = netif.find - - local destroyed = {} - local fakeIfs = { - eth0 = { - name = "eth0", - addr = "192.168.10.2", - }, - wlan0 = { - name = "wlan0", - addr = "192.168.20.2", - }, - } - local nextPort = 40000 - - local function restore() - socket.create = origSocketCreate - netif.getInterfaces = origGetInterfaces - netif.getName = origGetName - netif.isUp = origIsUp - netif.getIpv4Addr = origGetIpv4Addr - netif.find = origFind - end - - local ok, err = pcall(function() - netif.getInterfaces = function() - return {fakeIfs.eth0, fakeIfs.wlan0} - end - netif.getName = function(iface) - return iface.name - end - netif.isUp = function(_) - return true - end - netif.getIpv4Addr = function(iface) - return iface.addr - end - netif.find = function(name) - return assert(fakeIfs[name]) - end - - socket.create = function(sockType, familyName) - assert(sockType == "UDP") - assert(familyName == "IPV4") - nextPort = nextPort + 1 - - local o = { - port = nextPort, - } - - function o:settimeout(ms) - assert(ms == 0) - end - - function o:enablebroadcast() end - - function o:reuseaddr() end - - function o:bindif(ifname) - self.ifname = ifname - if ifname == "wlan0" then - error("bindif failed") - end - end - - function o:bind(addr, port) - assert(addr == "0.0.0.0") - assert(type(port) == "number") - end - - function o:getsockname() - return "0.0.0.0", self.port - end - - function o:destroy() - destroyed[self.ifname] = true - end - - return o - end - - local created, createErr = pcall(transport.create, {"eth0", "wlan0"}) - assert(created == false) - assert(createErr:find("failed to bind miio transport sockets", 1, true) ~= nil) - assert(createErr:find("wlan0", 1, true) ~= nil) - end) - - restore() - assert(ok, err) - assert(destroyed.eth0 == true) - assert(destroyed.wlan0 == true) -end - --- Tests resident transport send/recv flow, source-port filtering and unsubscribe. -do - local bindif, addr = findTestNetif() - local tp = transport.create({bindif}) - local ctx = assert(tp.sockets[bindif]) - local _, localPort = ctx.sock:getsockname() - assert(localPort == tp.localPort) - - local recvMq = core.createMQ(4) - local stopDevice = startUdpDevice(function(msg, fromAddr, fromPort, server) - assert(fromAddr == addr) - assert(fromPort == tp.localPort) + local bindif = findTestNetif() + local runtime = protocol.create({bindif}, 0x1111222233334444) + local mq = core.createMQ(1) - if msg == "ping" then - assert(server:sendto("pong", fromAddr, fromPort) == 4) - elseif msg == "notify" then - assert(server:sendto("after-unsub", fromAddr, fromPort) == 11) - end - end) + core.createTimer(function() + local ok, results = pcall(runtime.scan, runtime, 50) + mq:send(ok, results) + end):start(0) - local subId = tp:subscribe(function(data, fromAddr, port, inboundIfname) - recvMq:send(true, { - data = data, - addr = fromAddr, - port = port, - ifname = inboundIfname, - }) - end) + assert(waitUntil(100, function() + local waiters = runtime.scanMqs["*"] + return waiters ~= nil and #waiters == 1 + end)) - assert(tp:sendto("ping", addr, UDP_PORT, bindif) == 1) - local ok, packet = waitFor(recvMq, 1000) + local ok, results = waitFor(mq, 500) assert(ok == true) - assert(packet.data == "pong") - assert(packet.addr == addr) - assert(packet.port == UDP_PORT) - assert(packet.ifname == bindif) - - local other = socket.create("UDP", "IPV4") - assert(other:sendto("ignored", addr, tp.localPort) == 7) - ok = waitFor(recvMq, 100) - assert(ok == false) - - tp:unsubscribe(subId) - assert(tp:sendto("notify", addr, UDP_PORT, bindif) == 1) - ok = waitFor(recvMq, 100) - assert(ok == false) + assert(type(results) == "table") + assert(runtime.scanMqs["*"] == nil) - stopDevice(addr) - tp:close() - tp:close() + runtime:close() end --- Tests protocol.scan/request over the resident transport with a virtual miIO device. +-- Tests protocol.scan/request over the resident protocol runtime with a virtual miIO device. do local bindif, addr = findTestNetif() local token = "0123456789abcdef" @@ -326,6 +196,8 @@ do assert(results[1].addr == addr) assert(results[1].devid == deviceDid) assert(results[1].ifname == bindif) + assert(runtime.scanMqs[addr] == nil) + assert(runtime.scanMqs["*"] == nil) local pcb = runtime:createPcb(addr, token) local result1 = pcb:request(1000, "test.echo", "ping") @@ -333,11 +205,13 @@ do assert(result1[2] == "ping") assert(result1[3] == 1) assert(pcb.ifname == bindif) + assert(runtime.requestMqs[addr] == nil) local result2 = pcb:request(1000, "test.echo", "pong") assert(result2[1] == "test.echo") assert(result2[2] == "pong") assert(result2[3] == 2) + assert(runtime.requestMqs[addr] == nil) assert(stats.probes == 2) assert(stats.requests == 2) @@ -354,6 +228,8 @@ do local deviceStamp = floor(core.time()) - 5 local enc = createEncryption(token) local pending = {} + local readyMq = core.createMQ(1) + local releaseMq = core.createMQ(1) local stopDevice = startUdpDevice(function(msg, fromAddr, fromPort, server) local ok, did, stamp, data = pcall(unpack, msg) @@ -376,7 +252,10 @@ do return end - for _, item in ipairs(pending) do + readyMq:send(true) + releaseMq:recv() + for i = 1, #pending do + local item = pending[i] local payload = json.encode({ id = item.req.id, result = { @@ -405,6 +284,15 @@ do mq:send("pcb2", ok, result) end):start(0) + local ready = waitFor(readyMq, 1000) + assert(ready == true) + assert(waitUntil(100, function() + local waiters = runtime.requestMqs[addr] + return waiters ~= nil and #waiters == 2 + end)) + + releaseMq:send(true) + local results = {} for _ = 1, 2 do local name, ok, result = mq:recv() @@ -416,6 +304,56 @@ do assert(results.pcb1[2] == "A") assert(results.pcb2[1] == "test.echo") assert(results.pcb2[2] == "B") + assert(runtime.requestMqs[addr] == nil) + + stopDevice(addr) + runtime:close() +end + +-- Tests timed out requests remove their waiter entry from the runtime. +do + local bindif, addr = findTestNetif() + local token = "0123456789abcdef" + local deviceDid = 0x2122232425262728 + local deviceStamp = floor(core.time()) - 5 + local requestSeen = core.createMQ(1) + + local stopDevice = startUdpDevice(function(msg, fromAddr, fromPort, server) + local ok, did, stamp, data = pcall(unpack, msg) + if ok and did == -1 and stamp == 0xffffffff and data == nil then + local resp = pack(deviceDid, deviceStamp) + assert(server:sendto(resp, fromAddr, fromPort) == #resp) + return + end + + did, stamp, data = unpack(msg, token) + assert(did == deviceDid) + requestSeen:send(true, fromAddr, fromPort) + end) + + local runtime = protocol.create({bindif}, 0x9999000011112222) + local pcb = runtime:createPcb(addr, token) + pcb:handshake(1000) + + local mq = core.createMQ(1) + core.createTimer(function() + local ok, err = pcall(pcb.request, pcb, 60, "test.timeout") + mq:send(ok, err) + end):start(0) + + local ok = waitFor(requestSeen, 1000) + assert(ok == true) + assert(waitUntil(100, function() + local waiters = runtime.requestMqs[addr] + return waiters ~= nil and #waiters == 1 + end)) + + local success, err = waitFor(mq, 500) + assert(success == false) + assert(tostring(err):find("timeout", 1, true) ~= nil) + assert(runtime.requestMqs[addr] == nil) + assert(pcb.ifname == nil) + assert(pcb.stampDiff == nil) stopDevice(addr) runtime:close()