更新WebSocket库使提升请求支持body,并自动判断host端口

This commit is contained in:
Particle_G
2021-02-15 18:11:18 +08:00
parent 5fd2dd5b95
commit 9946048b00

View File

@@ -1,23 +1,24 @@
--[[ --[[
websocket client pure lua implement for love2d websocket client pure lua implement for love2d
by flaribbit by flaribbit and Particle_G
usage: usage:
local client = require("websocket").new() local client = require("websocket").new()
client:settimeout(1) client:settimeout(1)
client:connect("127.0.0.1", 5000) client:connect("127.0.0.1:5000", "/test", '{"foo":"bar"}')
client:settimeout(0) client:settimeout(0)
client:send("hello from love2d") client:send("hello from love2d", OPCODES.TEXT)
res, opcode = client:read() love.timer.sleep(0.2)
opcode, res, closeCode = client:read()
print(res) print(res)
client:close() client:send("Goodbye from love2d", OPCODES.CLOSE)
]] -- local debug_print=print ]] -- local debug_print=print
local socket = require "socket" local socket = require "socket"
local band, bor, bxor = bit.band, bit.bor, bit.bxor local band, bor, bxor = bit.band, bit.bor, bit.bxor
local shl, shr = bit.lshift, bit.rshift local shl, shr = bit.lshift, bit.rshift
local OPCODES = { OPCODES = {
CONTINUE = 0, CONTINUE = 0,
TEXT = 1, TEXT = 1,
BINARY = 2, BINARY = 2,
@@ -31,6 +32,18 @@ local _M = {
} }
_M.__index = _M _M.__index = _M
function splitStr(s, sep)
local L = {}
local p1 = 1
local p2 = nil
while p1 <= #s do
p2 = string.find(s, sep, p1) or #s + 1
L[#L + 1] = string.sub(s, p1, p2 - 1)
p1 = p2 + #sep
end
return L
end
function _M.new() function _M.new()
local m = { local m = {
socket = socket.tcp() socket = socket.tcp()
@@ -40,20 +53,23 @@ function _M.new()
end end
local seckey = "osT3F7mvlojIvf3/8uIsJQ==" local seckey = "osT3F7mvlojIvf3/8uIsJQ=="
function _M:connect(host, port, path) function _M:connect(server, path, body)
local host, port = unpack(splitStr(server, ":"))
local SOCK = self.socket local SOCK = self.socket
local res, err = SOCK:connect(host, port) local res, err = SOCK:connect(host, port or 80)
if res ~= 1 then if res ~= 1 then
return res, err return res, err
end end
-- debug_print("[handshake] connected")
-- WebSocket handshake -- WebSocket handshake
res, err = SOCK:send("GET " .. (path or "/") .. " HTTP/1.1\r\nHost: " .. host .. ":" .. port .."\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: " ..seckey .. "\r\n\r\n") res, err = SOCK:send("GET " .. (path or "/") .. " HTTP/1.1\r\n" .. "Host: " .. host .. ":" .. port .. "\r\n" ..
"Connection: Upgrade\r\n" .. "Upgrade: websocket\r\n" ..
"Content-Type: application/json\r\n" .. "Content-Length: " .. (body and {#body} or {0})[1] ..
"\r\n" .. "Sec-WebSocket-Version: 13\r\n" .. "Sec-WebSocket-Key: " .. seckey .. "\r\n" ..
'\r\n' .. (body and {body} or {""})[1])
repeat repeat
res = SOCK:receive("*l") res = SOCK:receive("*l")
until res == "" until res == ""
-- debug_print("[handshake] succeed")
end end
local mask_key = {1, 14, 5, 14} local mask_key = {1, 14, 5, 14}
@@ -68,26 +84,30 @@ local function _send(SOCK, opcode, message)
-- length -- length
local length = #message local length = #message
-- debug_print("[encode] message length: "..length)
if length > 65535 then if length > 65535 then
SOCK:send(string.char(bor(127, 0x80), 0, 0, 0, 0, band(shr(length, 24), 0xff), band(shr(length, 16), 0xff),band(shr(length, 8), 0xff), band(length, 0xff))) SOCK:send(string.char(bor(127, 0x80), 0, 0, 0, 0, band(shr(length, 24), 0xff), band(shr(length, 16), 0xff),
band(shr(length, 8), 0xff), band(length, 0xff)))
elseif length > 125 then elseif length > 125 then
SOCK:send(string.char(bor(126, 0x80), band(shr(length, 8), 0xff), band(length, 0xff))) SOCK:send(string.char(bor(126, 0x80), band(shr(length, 8), 0xff), band(length, 0xff)))
else else
SOCK:send(string.char(bor(length, 0x80))) SOCK:send(string.char(bor(length, 0x80)))
end end
-- debug_print("[encode] masking")
SOCK:send(string.char(unpack(mask_key))) SOCK:send(string.char(unpack(mask_key)))
local msgbyte = {message:byte(1, length)} local msgbyte = {message:byte(1, length)}
for i = 1, length do for i = 1, length do
msgbyte[i] = bxor(msgbyte[i], mask_key[(i - 1) % 4 + 1]) msgbyte[i] = bxor(msgbyte[i], mask_key[(i - 1) % 4 + 1])
end end
return SOCK:send(string.char(unpack(msgbyte))) return SOCK:send(string.char(unpack(msgbyte)))
-- debug_print("[encode] end")
end end
function _M:send(type, message) function _M:send(message, type)
_send(self.socket, OPCODES[type or "BINARY"] or OPCODES.BINARY, message) local tempType = OPCODES.BINARY
for _, opcode in pairs(_M.OPCODES) do
if type == opcode then
tempType = type
end
end
_send(self.socket, tempType, message)
end end
function _M:read() function _M:read()
@@ -101,7 +121,6 @@ function _M:read()
local OPCODE = band(res:byte(), 0x0f) local OPCODE = band(res:byte(), 0x0f)
-- local flag_FIN = res:byte()>=0x80 -- local flag_FIN = res:byte()>=0x80
-- local flag_MASK = res:byte(2)>=0x80 -- local flag_MASK = res:byte(2)>=0x80
-- debug_print("[decode] FIN="..tostring(flag_FIN)..", OPCODE="..OPCODE..", MASK="..tostring(flag_MASK))
-- length -- length
local byte = res:byte(2) local byte = res:byte(2)
@@ -115,11 +134,10 @@ function _M:read()
local b = {res:byte(1, 8)} local b = {res:byte(1, 8)}
length = shl(b[5], 32) + shl(b[6], 24) + shl(b[7], 8) + b[8] length = shl(b[5], 32) + shl(b[6], 24) + shl(b[7], 8) + b[8]
end end
-- debug_print("[decode] message length: "..length)
-- data -- data
res = SOCK:receive(length) res = SOCK:receive(length)
local closeCode local closeCode = nil
if OPCODE == OPCODES.PING then if OPCODE == OPCODES.PING then
self:pong(res) self:pong(res)
elseif OPCODE == OPCODES.CLOSE then elseif OPCODE == OPCODES.CLOSE then
@@ -127,11 +145,14 @@ function _M:read()
res = string.sub(res, 3, string.len(res) - 2) res = string.sub(res, 3, string.len(res) - 2)
self:close() self:close()
end end
-- debug_print("[decode] string length: "..#res)
-- debug_print("[decode] end")
return OPCODE, res, closeCode return OPCODE, res, closeCode
end end
function _M:close()
self.socket:close()
end
function _M:settimeout(t) function _M:settimeout(t)
self.socket:settimeout(t) self.socket:settimeout(t)
end end