略微优化ws库代码

This commit is contained in:
MrZ626
2021-02-13 20:42:48 +08:00
parent d6037ad15b
commit ab5d6878de

View File

@@ -1,145 +1,144 @@
--[[ --[[
websocket client pure lua implement for love2d websocket client pure lua implement for love2d
by flaribbit by flaribbit
usage: usage:
local client = require("websocket").new() local client = require("websocket").new()
client:settimeout(1) client:settimeout(1)
client:connect("127.0.0.1", 5000) client:connect("127.0.0.1", 5000)
client:settimeout(0) client:settimeout(0)
client:send("hello from love2d") client:send("hello from love2d")
res, opcode = client:read() res, opcode = client:read()
print(res) print(res)
client:close() client:close()
]] ]]
-- local debug_print=print -- local debug_print=print
local socket = require"socket" local socket = require"socket"
local bit = require"bit" local bit = require"bit"
local band, bor, bxor = bit.band, bit.bor, bit.bxor local band, bor, bxor = bit.band, bit.bor, bit.bxor
local shl, shr = bit.lshift, bit.rshift local shl, shr = bit.lshift, bit.rshift
local OPCODES = { local OPCODES = {
CONTINUE=0, CONTINUE=0,
TEXT =1, TEXT =1,
BINARY =2, BINARY =2,
CLOSE =8, CLOSE =8,
PING =9, PING =9,
PONG =10, PONG =10,
} }
local _M = { local _M = {
OPCODES = OPCODES, OPCODES = OPCODES,
} }
_M.__index = _M _M.__index = _M
function _M.new() function _M.new()
local m = {socket = socket.tcp()} local m = {socket = socket.tcp()}
setmetatable(m, _M) setmetatable(m, _M)
return m return m
end end
local seckey = "osT3F7mvlojIvf3/8uIsJQ==" local seckey = "osT3F7mvlojIvf3/8uIsJQ=="
function _M:connect(host, port, path) function _M:connect(host, port, path)
local SOCK = self.socket local SOCK = self.socket
local res, err = SOCK:connect(host, port) local res, err = SOCK:connect(host, port)
if res~=1 then return res, err end if res~=1 then return res, err end
-- debug_print("[handshake] connected") -- debug_print("[handshake] connected")
-- WebSocket handshake -- WebSocket handshake
res, err = SOCK:send("GET "..(path or"/").." HTTP/1.1\r\nHost: "..host..":"..port.."\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "..seckey.."\r\n\r\n") res, err = SOCK:send("GET "..(path or"/").." HTTP/1.1\r\nHost: "..host..":"..port.."\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "..seckey.."\r\n\r\n")
repeat res = SOCK:receive("*l") until res=="" repeat res = SOCK:receive("*l") until res==""
-- debug_print("[handshake] succeed") -- debug_print("[handshake] succeed")
end end
local function send(SOCK, opcode, message) local mask_key = {1, 14, 5, 14}
local mask_key = {1, 14, 5, 14} local function send(SOCK, opcode, message)
-- message type
-- message type SOCK:send(string.char(bor(0x80, opcode)))
SOCK:send(string.char(bor(0x80, opcode)))
if not message then
if message==nil then SOCK:send(string.char(0x80, unpack(mask_key)))
SOCK:send(string.char(0x80, unpack(mask_key))) return 0
return 0 end
end
-- length
-- length local length = #message
local length = #message -- debug_print("[encode] message length: "..length)
-- debug_print("[encode] message length: "..length) if length>65535 then
if length>65535 then SOCK:send(string.char(bor(127, 0x80), 0, 0, 0, 0,
SOCK:send(string.char(bor(127, 0x80), 0, 0, 0, 0, band(shr(length, 24), 0xff),
band(shr(length, 24), 0xff), band(shr(length, 16), 0xff),
band(shr(length, 16), 0xff), band(shr(length, 8), 0xff),
band(shr(length, 8), 0xff), band(length, 0xff)))
band(length, 0xff))) elseif length>125 then
elseif length>125 then SOCK:send(string.char(bor(126, 0x80),
SOCK:send(string.char(bor(126, 0x80), band(shr(length, 8), 0xff),
band(shr(length, 8), 0xff), band(length, 0xff)))
band(length, 0xff))) else
else SOCK:send(string.char(bor(length, 0x80)))
SOCK:send(string.char(bor(length, 0x80))) end
end -- debug_print("[encode] masking")
-- debug_print("[encode] masking") SOCK:send(string.char(unpack(mask_key)))
SOCK:send(string.char(unpack(mask_key))) local msgbyte = {message:byte(1, length)}
local msgbyte = {message:byte(1, length)} for i = 1, length do
for i = 1, length do msgbyte[i] = bxor(msgbyte[i], mask_key[(i-1)%4+1])
msgbyte[i] = bxor(msgbyte[i], mask_key[(i-1)%4+1]) end
end return SOCK:send(string.char(unpack(msgbyte)))
return SOCK:send(string.char(unpack(msgbyte))) -- debug_print("[encode] end")
-- debug_print("[encode] end") end
end
function _M:send(message)
function _M:send(message) send(self.socket, OPCODES.BINARY, message)
send(self.socket, OPCODES.BINARY, message) end
end
function _M:ping(message)
function _M:ping(message) send(self.socket, OPCODES.PING, message)
send(self.socket, OPCODES.PING, message) end
end
function _M:pong(message)
function _M:pong(message) send(self.socket, OPCODES.PONG, message)
send(self.socket, OPCODES.PONG, message) end
end
function _M:read()
function _M:read() -- byte 0-1
-- byte 0-1 local SOCK = self.socket
local SOCK = self.socket local res, err = SOCK:receive(2)
local res, err = SOCK:receive(2) if not res then return res, err end
if res==nil then return res, err end
local OPCODE = band(res:byte(), 0x0f)
-- local flag_FIN = res:byte()>=0x80 -- local flag_FIN = res:byte()>=0x80
local OPCODE = band(res:byte(), 0x0f) -- local flag_MASK = res:byte(2)>=0x80
-- local flag_MASK = res:byte(2)>=0x80 -- debug_print("[decode] FIN="..tostring(flag_FIN)..", OPCODE="..OPCODE..", MASK="..tostring(flag_MASK))
-- debug_print("[decode] FIN="..tostring(flag_FIN)..", OPCODE="..OPCODE..", MASK="..tostring(flag_MASK))
-- length
-- length local byte = res:byte(2)
local byte = res:byte(2) local length = band(byte, 0x7f)
local length = band(byte, 0x7f) if length==126 then
if length==126 then res = SOCK:receive(2)
res = SOCK:receive(2) local b1, b2 = res:byte(1, 2)
local b1, b2 = res:byte(1, 2) length = shl(b1, 8) + b2
length = shl(b1, 8) + b2 elseif length==127 then
elseif length==127 then res = SOCK:receive(8)
res = SOCK:receive(8) local b = {res:byte(1, 8)}
local b = {res:byte(1, 8)} length = shl(b[5], 32) + shl(b[6], 24) + shl(b[7], 8) + b[8]
length = shl(b[5], 32) + shl(b[6], 24) + shl(b[7], 8) + b[8] end
end -- debug_print("[decode] message length: "..length)
-- debug_print("[decode] message length: "..length)
-- data
-- data res = SOCK:receive(length)
res = SOCK:receive(length) if OPCODE==OPCODES.PING then
if OPCODE==OPCODES.PING then self:pong(res)
self:pong(res) elseif OPCODE==OPCODES.CLOSE then
elseif OPCODE==OPCODES.CLOSE then self:close()
self:close() end
end -- debug_print("[decode] string length: "..#res)
-- debug_print("[decode] string length: "..#res) -- debug_print("[decode] end")
-- debug_print("[decode] end") return res, OPCODE
return res, OPCODE end
end
function _M:close() self.socket:close() end
function _M:close() self.socket:close() end
function _M:settimeout(t) self.socket:settimeout(t) end
function _M:settimeout(t) self.socket:settimeout(t) end
return _M return _M