515 lines
14 KiB
C++
515 lines
14 KiB
C++
/* Copyright 2018 Canaan Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
#include "FreeRTOS.h"
#include "devices.h"
#include "kernel/driver_impl.hpp"
#include "network.h"
#include <lwip/sockets.h>
#include <lwip/errno.h>
#include <stdexcept>
#include <stdio.h>
#include <string.h>
|
|
|
|
using namespace sys;
|
|
|
|
static void check_lwip_error(int result)
|
|
{
|
|
if (result < 0)
|
|
{
|
|
throw errno_exception(strerror(errno), errno);
|
|
}
|
|
}
|
|
|
|
/**
 * Converts a system socket address into an lwIP sockaddr_in.
 *
 * The first four bytes of socket_addr.data hold the IPv4 octets; bytes 4 and 5
 * hold the port with the low byte first (the layout written by
 * network_socket_addr_parse).
 *
 * @param addr        Output lwIP address; fully overwritten.
 * @param socket_addr Input address, must have family AF_INTERNETWORK.
 * @throws std::runtime_error when the family is not AF_INTERNETWORK.
 */
static void to_lwip_sockaddr(sockaddr_in &addr, const socket_address_t &socket_addr)
{
    if (socket_addr.family != AF_INTERNETWORK)
        throw std::runtime_error("Invalid socket address.");

    // Start from a zeroed structure so that sin_zero is not left uninitialised.
    memset(&addr, 0, sizeof(addr));

    // Assemble the port from its two bytes instead of reading through a
    // uint16_t pointer: no aliasing/alignment assumptions and the result does
    // not depend on the host byte order.
    const uint16_t host_port = static_cast<uint16_t>(
        (socket_addr.data[4] & 0xFF) | ((socket_addr.data[5] & 0xFF) << 8));

    addr.sin_len = sizeof(addr);
    addr.sin_family = AF_INET;
    addr.sin_port = htons(host_port);
    addr.sin_addr.s_addr = LWIP_MAKEU32(socket_addr.data[3], socket_addr.data[2], socket_addr.data[1], socket_addr.data[0]);
}
|
|
|
|
/**
 * Converts an lwIP sockaddr_in into a system socket address.
 *
 * Writes the four bytes of s_addr into data[0..3] (least significant byte of
 * the stored value first) and the host-order port into data[4] (low byte) and
 * data[5] (high byte), the layout read back by network_socket_addr_to_string.
 *
 * @param addr        Output system address.
 * @param socket_addr Input lwIP address, must have sin_family AF_INET.
 * @throws std::runtime_error when sin_family is not AF_INET.
 */
static void to_sys_sockaddr(socket_address_t &addr, const sockaddr_in &socket_addr)
{
    if (socket_addr.sin_family != AF_INET)
        throw std::runtime_error("Invalid socket address.");

    const uint32_t raw_ip = socket_addr.sin_addr.s_addr;
    const uint16_t host_port = ntohs(socket_addr.sin_port);

    addr.family = AF_INTERNETWORK;
    addr.data[3] = (raw_ip >> 24) & 0xFF;
    addr.data[2] = (raw_ip >> 16) & 0xFF;
    addr.data[1] = (raw_ip >> 8) & 0xFF;
    addr.data[0] = raw_ip & 0xFF;

    // Store the port byte by byte rather than through a uint16_t pointer, which
    // avoids aliasing/alignment assumptions and fixes the byte layout
    // regardless of host endianness.
    addr.data[4] = host_port & 0xFF;
    addr.data[5] = (host_port >> 8) & 0xFF;
}
|
|
|
|
class k_network_socket : public network_socket, public heap_object, public exclusive_object_access
|
|
{
|
|
public:
|
|
k_network_socket(address_family_t address_family, socket_type_t type, protocol_type_t protocol)
|
|
{
|
|
int domain;
|
|
switch (address_family)
|
|
{
|
|
case AF_UNSPECIFIED:
|
|
case AF_INTERNETWORK:
|
|
domain = AF_INET;
|
|
break;
|
|
default:
|
|
throw std::invalid_argument("Invalid address family.");
|
|
}
|
|
|
|
int s_type;
|
|
switch (type)
|
|
{
|
|
case SOCKET_STREAM:
|
|
s_type = SOCK_STREAM;
|
|
break;
|
|
case SOCKET_DATAGRAM:
|
|
s_type = SOCK_DGRAM;
|
|
break;
|
|
default:
|
|
throw std::invalid_argument("Invalid socket type.");
|
|
}
|
|
|
|
int s_protocol;
|
|
switch (protocol)
|
|
{
|
|
case PROTCL_IP:
|
|
s_protocol = IPPROTO_IP;
|
|
break;
|
|
default:
|
|
throw std::invalid_argument("Invalid protocol type.");
|
|
}
|
|
|
|
auto sock = lwip_socket(domain, s_type, s_protocol);
|
|
check_lwip_error(sock);
|
|
sock_ = sock;
|
|
}
|
|
|
|
explicit k_network_socket(int sock)
|
|
: sock_(sock)
|
|
{
|
|
}
|
|
|
|
~k_network_socket()
|
|
{
|
|
lwip_close(sock_);
|
|
}
|
|
|
|
virtual void install() override
|
|
{
|
|
}
|
|
|
|
virtual object_accessor<network_socket> accept(socket_address_t *remote_address) override
|
|
{
|
|
object_ptr<k_network_socket> socket(std::in_place, new k_network_socket());
|
|
|
|
sockaddr_in remote;
|
|
socklen_t remote_len = sizeof(remote);
|
|
|
|
auto sock = lwip_accept(sock_, reinterpret_cast<sockaddr *>(&remote), &remote_len);
|
|
check_lwip_error(sock);
|
|
socket->sock_ = sock;
|
|
if (remote_address)
|
|
to_sys_sockaddr(*remote_address, remote);
|
|
return make_accessor(socket);
|
|
}
|
|
|
|
virtual void bind(const socket_address_t &address) override
|
|
{
|
|
sockaddr_in addr;
|
|
to_lwip_sockaddr(addr, address);
|
|
check_lwip_error(lwip_bind(sock_, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
|
|
}
|
|
|
|
virtual void connect(const socket_address_t &address) override
|
|
{
|
|
sockaddr_in addr;
|
|
to_lwip_sockaddr(addr, address);
|
|
check_lwip_error(lwip_connect(sock_, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
|
|
}
|
|
|
|
virtual void listen(uint32_t backlog) override
|
|
{
|
|
check_lwip_error(lwip_listen(sock_, backlog));
|
|
}
|
|
|
|
virtual void shutdown(socket_shutdown_t how) override
|
|
{
|
|
int s_how;
|
|
switch (how)
|
|
{
|
|
case SOCKSHTDN_SEND:
|
|
s_how = SHUT_WR;
|
|
break;
|
|
case SOCKSHTDN_RECEIVE:
|
|
s_how = SHUT_RD;
|
|
break;
|
|
case SOCKSHTDN_BOTH:
|
|
s_how = SHUT_RDWR;
|
|
break;
|
|
default:
|
|
throw std::invalid_argument("Invalid how.");
|
|
}
|
|
|
|
check_lwip_error(lwip_shutdown(sock_, s_how));
|
|
}
|
|
|
|
virtual size_t send(gsl::span<const uint8_t> buffer, socket_message_flag_t flags) override
|
|
{
|
|
uint8_t send_flags = 0;
|
|
if (flags & MESSAGE_PEEK)
|
|
send_flags |= MSG_PEEK;
|
|
if (flags & MESSAGE_WAITALL)
|
|
send_flags |= MSG_WAITALL;
|
|
if (flags & MESSAGE_OOB)
|
|
send_flags |= MSG_OOB;
|
|
if (flags & MESSAGE_DONTWAIT)
|
|
send_flags |= MSG_DONTWAIT;
|
|
if (flags & MESSAGE_MORE)
|
|
send_flags |= MSG_MORE;
|
|
|
|
auto ret = lwip_send(sock_, buffer.data(), buffer.size_bytes(), send_flags);
|
|
check_lwip_error(ret);
|
|
configASSERT(ret == buffer.size_bytes());
|
|
return ret;
|
|
}
|
|
|
|
virtual size_t receive(gsl::span<uint8_t> buffer, socket_message_flag_t flags) override
|
|
{
|
|
uint8_t recv_flags = 0;
|
|
if (flags & MESSAGE_PEEK)
|
|
recv_flags |= MSG_PEEK;
|
|
if (flags & MESSAGE_WAITALL)
|
|
recv_flags |= MSG_WAITALL;
|
|
if (flags & MESSAGE_OOB)
|
|
recv_flags |= MSG_OOB;
|
|
if (flags & MESSAGE_DONTWAIT)
|
|
recv_flags |= MSG_DONTWAIT;
|
|
if (flags & MESSAGE_MORE)
|
|
recv_flags |= MSG_MORE;
|
|
auto ret = lwip_recv(sock_, buffer.data(), buffer.size_bytes(), recv_flags);
|
|
check_lwip_error(ret);
|
|
return ret;
|
|
}
|
|
|
|
virtual size_t send_to(gsl::span<const uint8_t> buffer, socket_message_flag_t flags, const socket_address_t &to) override
|
|
{
|
|
uint8_t send_flags = 0;
|
|
if (flags & MESSAGE_PEEK)
|
|
send_flags |= MSG_PEEK;
|
|
if (flags & MESSAGE_WAITALL)
|
|
send_flags |= MSG_WAITALL;
|
|
if (flags & MESSAGE_OOB)
|
|
send_flags |= MSG_OOB;
|
|
if (flags & MESSAGE_DONTWAIT)
|
|
send_flags |= MSG_DONTWAIT;
|
|
if (flags & MESSAGE_MORE)
|
|
send_flags |= MSG_MORE;
|
|
|
|
sockaddr_in remote;
|
|
socklen_t remote_len = sizeof(remote);
|
|
to_lwip_sockaddr(remote, to);
|
|
|
|
auto ret = lwip_sendto(sock_, buffer.data(), buffer.size_bytes(), send_flags, reinterpret_cast<const sockaddr *>(&remote), remote_len);
|
|
check_lwip_error(ret);
|
|
configASSERT(ret == buffer.size_bytes());
|
|
return ret;
|
|
}
|
|
|
|
virtual size_t receive_from(gsl::span<uint8_t> buffer, socket_message_flag_t flags, socket_address_t *from) override
|
|
{
|
|
uint8_t recv_flags = 0;
|
|
if (flags & MESSAGE_PEEK)
|
|
recv_flags |= MSG_PEEK;
|
|
if (flags & MESSAGE_WAITALL)
|
|
recv_flags |= MSG_WAITALL;
|
|
if (flags & MESSAGE_OOB)
|
|
recv_flags |= MSG_OOB;
|
|
if (flags & MESSAGE_DONTWAIT)
|
|
recv_flags |= MSG_DONTWAIT;
|
|
if (flags & MESSAGE_MORE)
|
|
recv_flags |= MSG_MORE;
|
|
|
|
sockaddr_in remote;
|
|
socklen_t remote_len = sizeof(remote);
|
|
|
|
auto ret = lwip_recvfrom(sock_, buffer.data(), buffer.size_bytes(), recv_flags, reinterpret_cast<sockaddr *>(&remote), &remote_len);
|
|
check_lwip_error(ret);
|
|
if (from)
|
|
to_sys_sockaddr(*from, remote);
|
|
return ret;
|
|
}
|
|
|
|
virtual size_t read(gsl::span<uint8_t> buffer) override
|
|
{
|
|
auto ret = lwip_read(sock_, buffer.data(), buffer.size_bytes());
|
|
check_lwip_error(ret);
|
|
return ret;
|
|
}
|
|
|
|
virtual size_t write(gsl::span<const uint8_t> buffer) override
|
|
{
|
|
auto ret = lwip_write(sock_, buffer.data(), buffer.size_bytes());
|
|
check_lwip_error(ret);
|
|
configASSERT(ret == buffer.size_bytes());
|
|
return ret;
|
|
}
|
|
|
|
virtual int fcntl(int cmd, int val) override
|
|
{
|
|
auto ret = lwip_fcntl(sock_, cmd, val);
|
|
check_lwip_error(ret);
|
|
return ret;
|
|
}
|
|
|
|
virtual void select(fd_set *readset, fd_set *writeset, fd_set *exceptset, struct timeval *timeout) override
|
|
{
|
|
check_lwip_error(lwip_select(sock_ + 1, readset, writeset, exceptset, timeout));
|
|
}
|
|
|
|
virtual int control(uint32_t control_code, gsl::span<const uint8_t> write_buffer, gsl::span<uint8_t> read_buffer) override
|
|
{
|
|
int val = *reinterpret_cast<const int *>(write_buffer.data());
|
|
check_lwip_error(lwip_ioctl(sock_, (unsigned int)control_code, &val));
|
|
return 0;
|
|
}
|
|
|
|
private:
|
|
k_network_socket()
|
|
: sock_(0)
|
|
{
|
|
}
|
|
|
|
private:
|
|
int sock_;
|
|
};
|
|
|
|
/*
 * Common prologue for the C API wrappers: looks up the object behind the
 * enclosing function's `socket_handle`, asserts that it is a k_network_socket
 * and introduces `obj` (the looked-up object) and `f` (the socket view of it)
 * into the calling scope.
 */
#define SOCKET_ENTRY                                    \
    auto &obj = system_handle_to_object(socket_handle); \
    configASSERT(obj.is<k_network_socket>());           \
    auto f = obj.as<k_network_socket>();
|
|
|
|
/*
 * Handler chain appended to the try block of an int-returning C API wrapper.
 * Every exception is converted into a -1 return so that none propagates to
 * the caller:
 *   - errno_exception:       errno is set to the carried code;
 *   - std::invalid_argument: errno is set to EINVAL;
 *   - anything else:         errno is left untouched.
 */
#define CATCH_ALL                         \
    catch (errno_exception & e)           \
    {                                     \
        errno = e.code();                 \
        return -1;                        \
    }                                     \
    catch (const std::invalid_argument &) \
    {                                     \
        errno = EINVAL;                   \
        return -1;                        \
    }                                     \
    catch (...)                           \
    {                                     \
        return -1;                        \
    }
|
|
|
|
/*
 * Throws std::invalid_argument("<expr> is invalid.") when the argument
 * expression evaluates to false. The argument is parenthesised so compound
 * expressions are negated as a whole, and the body is wrapped in
 * do { } while (0) so the macro behaves as a single statement (safe inside
 * an unbraced if/else). Use it with a trailing semicolon.
 */
#define CHECK_ARG(x)                                        \
    do                                                      \
    {                                                       \
        if (!(x))                                           \
            throw std::invalid_argument(#x " is invalid."); \
    } while (0)
|
|
|
|
/**
 * Creates a socket object and registers a handle for it.
 *
 * @return The new handle, or NULL_HANDLE when construction or handle
 *         allocation throws.
 */
handle_t network_socket_open(address_family_t address_family, socket_type_t type, protocol_type_t protocol)
{
    handle_t opened = NULL_HANDLE;
    try
    {
        auto sock_obj = make_object<k_network_socket>(address_family, type, protocol);
        opened = system_alloc_handle(make_accessor<object_access>(sock_obj));
    }
    catch (...)
    {
        // Any failure is reported to the caller as a null handle.
        opened = NULL_HANDLE;
    }
    return opened;
}
|
|
|
|
/**
 * Closes a socket handle by delegating to io_close.
 *
 * @return Whatever io_close returns for the handle.
 */
handle_t network_socket_close(handle_t socket_handle)
{
    handle_t close_result = io_close(socket_handle);
    return close_result;
}
|
|
|
|
/**
 * Connects the socket behind `socket_handle` to `remote_address`.
 *
 * @return 0 on success; -1 on failure. errno is EINVAL when `remote_address`
 *         is null, or the lwIP error code when the connect call fails.
 */
int network_socket_connect(handle_t socket_handle, const socket_address_t *remote_address)
{
    // Reject a null address up front with an errno-style result; throwing
    // std::invalid_argument here would not be matched by an errno_exception
    // handler and would propagate out of this C-style entry point.
    if (!remote_address)
    {
        errno = EINVAL;
        return -1;
    }

    try
    {
        SOCKET_ENTRY;

        f->connect(*remote_address);
        return 0;
    }
    CATCH_ALL;
}
|
|
|
|
/**
 * Puts the socket behind `socket_handle` into listening state.
 *
 * @return 0 on success, -1 (with errno set) when the socket reports an error.
 */
int network_socket_listen(handle_t socket_handle, uint32_t backlog)
{
    try
    {
        SOCKET_ENTRY;
        f->listen(backlog);
    }
    CATCH_ALL;

    // Reached only when the try block finished without throwing.
    return 0;
}
|
|
|
|
/**
 * Accepts a connection on the socket behind `socket_handle`.
 *
 * @param remote_address Must be non-null; receives the peer address.
 * @return Handle of the accepted socket, or NULL_HANDLE on any failure
 *         (including a null `remote_address`).
 */
handle_t network_socket_accept(handle_t socket_handle, socket_address_t *remote_address)
{
    handle_t accepted = NULL_HANDLE;
    try
    {
        SOCKET_ENTRY;
        CHECK_ARG(remote_address);

        accepted = system_alloc_handle(f->accept(remote_address));
    }
    catch (...)
    {
        // Every error collapses into a null handle for the caller.
        accepted = NULL_HANDLE;
    }
    return accepted;
}
|
|
|
|
/**
 * Shuts down the socket behind `socket_handle` in the direction given by `how`.
 *
 * @return 0 on success, -1 (with errno set) when the socket reports an error.
 */
int network_socket_shutdown(handle_t socket_handle, socket_shutdown_t how)
{
    try
    {
        SOCKET_ENTRY;
        f->shutdown(how);
    }
    CATCH_ALL;

    // Reached only when the try block finished without throwing.
    return 0;
}
|
|
|
|
/**
 * Binds the socket behind `socket_handle` to `local_address`.
 *
 * @return 0 on success; -1 on failure. errno is EINVAL when `local_address`
 *         is null, or the lwIP error code when the bind call fails.
 */
int network_socket_bind(handle_t socket_handle, const socket_address_t *local_address)
{
    // Reject a null address up front with an errno-style result; throwing
    // std::invalid_argument here would not be matched by an errno_exception
    // handler and would propagate out of this C-style entry point.
    if (!local_address)
    {
        errno = EINVAL;
        return -1;
    }

    try
    {
        SOCKET_ENTRY;

        f->bind(*local_address);
        return 0;
    }
    CATCH_ALL;
}
|
|
|
|
/**
 * Sends `len` bytes starting at `data` on the socket behind `socket_handle`.
 *
 * @return 0 on success (not the byte count), -1 (with errno set) when the
 *         socket reports an error.
 */
int network_socket_send(handle_t socket_handle, const uint8_t *data, size_t len, socket_message_flag_t flags)
{
    try
    {
        SOCKET_ENTRY;

        gsl::span<const uint8_t> payload { data, std::ptrdiff_t(len) };
        f->send(payload, flags);
    }
    CATCH_ALL;

    // Reached only when the try block finished without throwing.
    return 0;
}
|
|
|
|
/**
 * Receives up to `len` bytes into `data` from the socket behind `socket_handle`.
 *
 * @return The socket's receive result converted to int, or -1 (with errno set)
 *         when the socket reports an error.
 */
int network_socket_receive(handle_t socket_handle, uint8_t *data, size_t len, socket_message_flag_t flags)
{
    try
    {
        SOCKET_ENTRY;

        gsl::span<uint8_t> target { data, std::ptrdiff_t(len) };
        const size_t received = f->receive(target, flags);
        return received;
    }
    CATCH_ALL;
}
|
|
|
|
/**
 * Sends `len` bytes starting at `data` to the address `to`.
 *
 * @return 0 on success (not the byte count); -1 on failure. errno is EINVAL
 *         when `to` is null, or the lwIP error code when the send fails.
 */
int network_socket_send_to(handle_t socket_handle, const uint8_t *data, size_t len, socket_message_flag_t flags, const socket_address_t *to)
{
    // `to` is dereferenced below, so a null destination must be rejected
    // instead of being passed on.
    if (!to)
    {
        errno = EINVAL;
        return -1;
    }

    try
    {
        SOCKET_ENTRY;

        f->send_to({ data, std::ptrdiff_t(len) }, flags, *to);
        return 0;
    }
    CATCH_ALL;
}
|
|
|
|
/**
 * Receives up to `len` bytes into `data` and reports the sender in `from`.
 *
 * @param from Passed through to the socket unchanged (the socket only fills
 *             it when it is non-null).
 * @return The socket's receive result converted to int, or -1 (with errno set)
 *         when the socket reports an error.
 */
int network_socket_receive_from(handle_t socket_handle, uint8_t *data, size_t len, socket_message_flag_t flags, socket_address_t *from)
{
    try
    {
        SOCKET_ENTRY;

        gsl::span<uint8_t> target { data, std::ptrdiff_t(len) };
        const size_t received = f->receive_from(target, flags, from);
        return received;
    }
    CATCH_ALL;
}
|
|
|
|
/**
 * Forwards an fcntl request to the socket behind `socket_handle`.
 *
 * @return The socket's fcntl result, or -1 (with errno set) when the socket
 *         reports an error.
 */
int network_socket_fcntl(handle_t socket_handle, int cmd, int val)
{
    try
    {
        SOCKET_ENTRY;

        const int fcntl_result = f->fcntl(cmd, val);
        return fcntl_result;
    }
    CATCH_ALL;
}
|
|
|
|
int network_socket_select(int socket_handle, fd_set *readset, fd_set *writeset, fd_set *exceptset, struct timeval *timeout)
|
|
{
|
|
try
|
|
{
|
|
SOCKET_ENTRY;
|
|
|
|
f->select(readset, writeset, exceptset, timeout);
|
|
return 0;
|
|
}
|
|
CATCH_ALL;
|
|
}
|
|
|
|
/**
 * Parses a dotted-quad IPv4 string and a port into a 6-byte socket address.
 *
 * Output layout: bytes 0..3 are the four octets in textual order, byte 4 is
 * the low byte of `port` and byte 5 its high byte.
 *
 * @param ip_addr     NUL-terminated text of the form "a.b.c.d", each part a
 *                    decimal number in 0..255 with no surrounding characters.
 * @param port        Port number; only its low 16 bits are stored.
 * @param socket_addr Output buffer of at least 6 bytes; written only on success.
 * @return 0 on success; -1 with errno = EINVAL when an argument is null or the
 *         text is not a valid dotted quad.
 */
int network_socket_addr_parse(const char *ip_addr, int port, uint8_t *socket_addr)
{
    if (!ip_addr || !socket_addr)
    {
        errno = EINVAL;
        return -1;
    }

    // Parse into a scratch array first so the caller's buffer is never left
    // half-written and can never be overrun by input with extra components.
    uint8_t octets[4];
    const char *cursor = ip_addr;
    for (int index = 0; index < 4; index++)
    {
        if (*cursor < '0' || *cursor > '9')
        {
            errno = EINVAL;
            return -1;
        }

        unsigned int value = 0;
        while (*cursor >= '0' && *cursor <= '9')
        {
            value = value * 10 + static_cast<unsigned int>(*cursor - '0');
            // Bail out as soon as the range is exceeded; this also keeps
            // `value` from overflowing on very long digit runs.
            if (value > 255)
            {
                errno = EINVAL;
                return -1;
            }
            cursor++;
        }
        octets[index] = static_cast<uint8_t>(value);

        // The first three octets must be followed by '.', the last by the
        // end of the string.
        const char expected = (index < 3) ? '.' : '\0';
        if (*cursor != expected)
        {
            errno = EINVAL;
            return -1;
        }
        if (index < 3)
            cursor++;
    }

    for (int index = 0; index < 4; index++)
        socket_addr[index] = octets[index];
    socket_addr[4] = port & 0xff;
    socket_addr[5] = (port >> 8) & 0xff;
    return 0;
}
|
|
|
|
/**
 * Formats a 6-byte socket address as dotted-quad text plus a port number.
 *
 * Bytes 0..3 of `socket_addr` are printed as "a.b.c.d"; the port is rebuilt
 * from byte 4 (low byte) and byte 5 (high byte).
 *
 * @param socket_addr Input address, at least 6 bytes.
 * @param ip_addr     Output text buffer; at most 16 bytes (including the
 *                    terminating NUL) are written.
 * @param port        Receives the port number.
 * @return 0 on success; -1 with errno = EINVAL when any argument is null.
 */
int network_socket_addr_to_string(uint8_t *socket_addr, char *ip_addr, int *port)
{
    if (!socket_addr || !ip_addr || !port)
    {
        errno = EINVAL;
        return -1;
    }

    // "255.255.255.255" plus the terminator is 16 bytes, which bounds the
    // output for any four octets.
    snprintf(ip_addr, 16, "%u.%u.%u.%u",
        static_cast<unsigned int>(socket_addr[0]),
        static_cast<unsigned int>(socket_addr[1]),
        static_cast<unsigned int>(socket_addr[2]),
        static_cast<unsigned int>(socket_addr[3]));

    *port = (int)(socket_addr[4] | (socket_addr[5] << 8));
    return 0;
}
|