beterraba

ref: master

cmd/beterrabad/socket.ha


use beterraba;
use bufio;
use dirs;
use encoding::utf8;
use errors;
use fmt;
use fs;
use io;
use log;
use net::unix;
use net;
use os;
use path;
use strings;
use unix::poll::{event};
use unix::poll;
use unix::signal;

// Error returned by exec when building or delivering a command response
// fails; read_client treats it as fatal.
type servererror = !(io::error | fs::error);

// State of the daemon's control-socket server.
type server = struct {
	// Listening unix socket created by bind.
	sock: net::socket,
	// Signal fd supplied to bind; dispatch returns false when it is readable.
	signalfd: io::file,
	// Poll set: [0] is the listening socket, [1] the signal fd, and
	// [2 + i] belongs to clients[i] (kept in lockstep by accept and
	// disconnect_client).
	pollfd: []poll::pollfd,
	// Connected clients, in the same order as pollfd[2..].
	clients: []client,
	// NOTE(review): not touched in this file; presumably managed by the
	// service helpers (start_service, status_service, ...) — confirm.
	services: []beterraba::service,
	// Set by disconnect_client so the dispatch loop knows clients/pollfd
	// were shifted under it.
	disconnected: bool
};

// A single connection on the control socket.
type client = struct {
	// Back pointer to the owning server.
	server: *server,
	// The connected client socket.
	sock: io::file,
	// Points at this client's entry in server.pollfd; must be refreshed
	// whenever that array is resized or shifted.
	pollfd: *poll::pollfd,
	// Where the client is in its read/write cycle.
	state: state,
	// Pending output not yet written to sock; owned by the client.
	wbuf: []u8,
	// Freed on disconnect; not otherwise used in this file.
	rbuf: []u8
};

// Client I/O state.
type state = enum {
	// Waiting for the next command line (polled for POLLIN).
	READ,
	// Flushing a response; returns to READ once wbuf is written.
	WRITE,
	// Flushing a response; the client is disconnected once wbuf is written.
	WRITE_ERROR,
};

// Creates the daemon's listening unix socket, named "beterrabad" inside the
// directory returned by dirs::runtime, and returns a server that polls it
// together with `fd` (the signal fd supplied by the caller).
//
// Aborts the process via log::fatalf when the runtime directory cannot be
// determined or the socket cannot be created.
fn bind(fd: io::file) server = {
	let runtime = match (dirs::runtime()) {
	case let dir: str =>
		yield dir;
	case let err: fs::error =>
		log::fatalf("Some error on trying to find runtime dir {}",
			fs::strerror(err));
	};

	let pathbuf = path::init();
	path::add(&pathbuf, runtime, "beterrabad")!;
	let sockpath = path::string(&pathbuf);

	// NOTE(review): NOCLOEXEC keeps the listening fd open across exec;
	// presumably intentional — confirm.
	const sock = match (unix::listen(sockpath, net::sockflags::NOCLOEXEC)) {
	case let unixsock: net::socket =>
		yield unixsock;
	case let err: net::error =>
		// The format string needs a placeholder, or the error is dropped.
		log::fatalf("Could not create socket: {}", net::strerror(err));
	};

	// Slot 0 is the listening socket and slot 1 the signal fd; accept
	// appends one slot per client after these.
	let pollfd = alloc([poll::pollfd {
		fd = sock,
		events = event::POLLIN,
		...
	}, poll::pollfd {
		fd = fd,
		events = event::POLLIN,
		...
	}]);

	return server {
		sock = sock,
		pollfd = pollfd,
		signalfd = fd,
		...
	};
};

// Waits for activity on the listening socket, the signal fd and every client
// socket, then handles whatever became ready.
//
// Returns false when the signal fd fired (the caller should stop its main
// loop), true otherwise. Aborts the process if poll itself fails.
fn dispatch(s: *server) bool = {
	match (poll::poll(s.pollfd, poll::INDEF)) {
	case uint =>
		if (s.pollfd[0].revents & event::POLLIN != 0) {
			accept(s);
		};

		if (s.pollfd[1].revents & event::POLLIN != 0) {
			signal::read(s.signalfd)!;
			return false;
		};

		// Client k lives at s.clients[k] and s.pollfd[k + 2]. When a
		// client disconnects, both arrays shift down by one, so the next
		// client (with its revents intact) now sits at the current
		// index. Only advance when the current client is still there;
		// restarting from the beginning would skip a client and
		// re-dispatch already-handled ones with stale revents.
		let i = 2z;
		for (i < len(s.pollfd)) {
			dispatch_client(s, &s.clients[i - 2]);

			if (s.disconnected) {
				s.disconnected = false;
			} else {
				i += 1;
			};
		};
	case let err: errors::error =>
		log::fatal("poll:", errors::strerror(err));
	};

	return true;
};

// Handles the events poll reported for one client. An error or hangup
// disconnects it, readable data is read and executed as a command, and
// writability flushes its pending output.
//
// On return, s.disconnected is true if the client was removed.
fn dispatch_client(s: *server, client: *client) void = {
	let revents = client.pollfd.revents;

	if (revents & (event::POLLERR | event::POLLHUP) != 0) {
		disconnect_client(client);
		return;
	};

	if (revents & event::POLLIN != 0) {
		read_client(s, client);
		// read_client may have disconnected the client. `client` then
		// refers to another client's slot, or past the end of the array,
		// so it must not be touched again.
		if (s.disconnected) {
			return;
		};
	};

	if (revents & event::POLLOUT != 0) {
		write_client(s, client);
	};
};

// Writes as much of the client's pending output (wbuf) as the socket
// accepts. Once wbuf is fully drained, a client in WRITE goes back to
// reading, and a client in WRITE_ERROR is disconnected. A write error
// disconnects the client.
fn write_client(s: *server, client: *client) void = {
	let sz = match (io::write(client.sock, client.wbuf)) {
	case let z: size =>
		yield z;
	case errors::again =>
		return;
	case let err: io::error =>
		log::printfln("Couldn't write to client sock due to {}", io::strerror(err));
		disconnect_client(client);
		return;
	};

	// Drop the bytes that made it out.
	delete(client.wbuf[..sz]);

	// Partial write: leave the state and POLLOUT interest unchanged and
	// send the rest on the next wakeup, instead of discarding it.
	if (len(client.wbuf) > 0) {
		return;
	};

	switch (client.state) {
	case state::WRITE =>
		client.state = state::READ;
		client.pollfd.events = event::POLLIN | event::POLLHUP;
	case state::WRITE_ERROR =>
		disconnect_client(client);
	case => abort();
	};
};

// Reads one newline-terminated command from the client and executes it.
// The client is disconnected on EOF, on a read error, or if the line is not
// valid UTF-8. A failure inside exec aborts the process.
fn read_client(s: *server, client: *client) void = {
	let bufline = match (bufio::scanline(client.sock)) {
	case let l: []u8 =>
		yield l;
	case io::EOF =>
		// Without returning here, execution would fall through to a
		// failing type assertion and abort the daemon.
		disconnect_client(client);
		return;
	case io::error =>
		disconnect_client(client);
		return;
	};

	let line = match (strings::fromutf8(bufline)) {
	case let text: str =>
		yield text;
	case utf8::invalid =>
		// Bad input from one client must not take the whole daemon down.
		log::println("invalid utf-8");
		free(bufline);
		disconnect_client(client);
		return;
	};

	// NOTE(review): bufline (which `line` borrows) is never freed. Whether
	// it can be freed after exec depends on whether the service helpers
	// keep references to `args` — confirm before adding a free.
	match (exec(line, s, client)) {
	case void =>
		void;
	case servererror =>
		log::fatal("Fudeu");
	};
};

// Parses "<cmd> <args>" from `line`, runs the matching command and queues
// its output, followed by "\nend\n", to be written to the client.
//
// Returns a servererror if formatting the response fails.
fn exec(line: str, server: *server, client: *client) (servererror | void) = {
	let buf = bufio::dynamic(io::mode::WRITE);
	// writebuf copies the response, so the stream and its buffer can be
	// released on every exit path.
	defer io::close(&buf)!;

	let sline = strings::cut(line, " ");
	let cmd = sline.0;
	let args = sline.1;

	// Command results are printed with fprint, not fprintf. They are data
	// (and may echo client-supplied text), so any '{' or '}' in them must
	// not be interpreted as format directives.
	// TODO: Flesh out command execution here
	switch (cmd) {
	case "start" =>
		let status = start_service(args, server);
		fmt::fprint(&buf, status)?;
	case "status" =>
		let status = status_service(args, server);
		fmt::fprint(&buf, status)?;
	case "started" =>
		let status = notify_service("started", args, server);
		fmt::fprint(&buf, status)?;
	case "crashed" =>
		let status = notify_service("crashed", args, server);
		fmt::fprint(&buf, status)?;
	case "stopped" =>
		let status = notify_service("stopped", args, server);
		fmt::fprint(&buf, status)?;
	case "list" =>
		let status = list_services(args, server);
		fmt::fprint(&buf, status)?;
	case "ping" =>
		fmt::fprint(&buf, "pong")?;
	case =>
		fmt::fprint(&buf, "unknown command")?;
	};

	fmt::fprintln(&buf, "\nend")?;
	writebuf(client, bufio::buffer(&buf));
};

// Writes data to the client. Takes ownership over the buffer.
//
// Switches the client to the WRITE state and polls for POLLOUT. The client
// must not already be writing.
fn commit_write(client: *client, buf: []u8) void = {
	assert(client.state != state::WRITE
		&& client.state != state::WRITE_ERROR);
	// The previous output buffer (already drained by the time the client
	// is back in READ) is still allocated; release it rather than leak it.
	free(client.wbuf);
	client.wbuf = buf;
	client.state = state::WRITE;
	client.pollfd.events = event::POLLOUT | event::POLLHUP;
};

// Queues a copy of `buf` to be sent to the client. The caller keeps
// ownership of `buf`; the client owns the copy.
fn writebuf(client: *client, buf: []u8) void = {
	let owned: []u8 = alloc(buf...);
	commit_write(client, owned);
};

// Accepts a pending connection on the listening socket. The new client is
// registered in s.clients with a matching POLLIN/POLLHUP entry in s.pollfd.
// Aborts the process if accept fails.
// TODO: Check if we reached max client connections first
fn accept(s: *server) void = {
	let clientsock = match (net::accept(s.sock)) {
	case let sock: net::socket =>
		yield sock;
	case let err: net::error =>
		log::fatalf("Couldn't grab a client sock due to {}", net::strerror(err));
	};

	append(s.pollfd, poll::pollfd {
		fd = clientsock,
		events = event::POLLIN | event::POLLHUP,
		...
	});

	append(s.clients, client {
		server = s,
		sock = clientsock,
		pollfd = &s.pollfd[len(s.pollfd) - 1],
		...
	});

	// The append to s.pollfd may have moved the array to a new allocation,
	// leaving every existing client's pollfd pointer dangling. Re-point
	// them all; client k owns pollfd slot k + 2.
	for (let i = 0z; i < len(s.clients); i += 1) {
		s.clients[i].pollfd = &s.pollfd[i + 2];
	};
};

// Immediately disconnects a client, without sending them an error message.
//
// Closes the client's socket and frees its buffers. Removes it from both
// serv.clients and serv.pollfd, then re-points the pollfd pointers of the
// clients that moved down a slot. Sets serv.disconnected so the dispatch
// loop knows the arrays shifted. `c` must point into serv.clients and must
// not be used after this returns.
fn disconnect_client(c: *client) void = {
	io::close(c.sock)!;
	free(c.rbuf);
	free(c.wbuf);

	let serv = c.server;
	// Recover c's index from its address relative to the start of the
	// clients array.
	let i = (c: uintptr - serv.clients: *[*]client: uintptr): size / size(client);
	delete(serv.clients[i]);
	// pollfd slots 0 and 1 are the listening socket and the signal fd.
	delete(serv.pollfd[i + 2z]);
	// Every client from index i onward shifted down one slot in both
	// arrays; refresh their pointers.
	for (i < len(serv.clients); i += 1) {
		serv.clients[i].pollfd = &serv.pollfd[i + 2z];
	};

	serv.disconnected = true;
};

// Tears the server down. Closes every client socket and frees the client
// buffers, frees the clients and pollfd arrays, and removes the
// "beterrabad" socket file from the runtime directory. Finally closes the
// listening socket, logging (not aborting) if that close fails.
fn shutdown(s: *server) void = {
	for (let n = 0z; n < len(s.clients); n += 1) {
		let c = &s.clients[n];
		net::close(c.sock)!;
		free(c.wbuf);
		free(c.rbuf);
	};
	free(s.clients);
	free(s.pollfd);

	let sockpath = path::init();
	let runtime = dirs::runtime()!;
	path::add(&sockpath, runtime, "beterrabad")!;
	os::remove(path::string(&sockpath))!;

	match (net::close(s.sock)) {
	case void =>
		void;
	case let err: net::error =>
		log::printfln(
			"There was some error trying to close the socket, {}",
			net::strerror(err)
		);
	};
};