diff --git a/docs/HRANA_3_SPEC.md b/docs/HRANA_3_SPEC.md new file mode 100644 index 00000000..73dd56d9 --- /dev/null +++ b/docs/HRANA_3_SPEC.md @@ -0,0 +1,1710 @@ +# The Hrana protocol specification (version 3) + +Hrana (from Czech "hrana", which means "edge") is a protocol for connecting to a +SQLite database over the network. It is designed to be used from edge functions +and other environments where low latency and small overhead is important. + +This is a specification for version 3 of the Hrana protocol (Hrana 3). + +## Overview + +The Hrana protocol provides SQL _streams_. Each stream corresponds to a SQLite +connection and executes a sequence of SQL statements. + +### Variants (WebSocket / HTTP) + +The protocol has two variants: + +- Hrana over WebSocket, which uses WebSocket as the underlying protocol. + Multiple streams can be multiplexed over a single WebSocket. +- Hrana over HTTP, which communicates with the server using HTTP requests. This + is less efficient than WebSocket, but HTTP is the only reliable protocol in + some environments. + +Each of these variants is described later. + +### Encoding + +The protocol has two encodings: + +- [JSON][rfc8259] is the canonical encoding, backward compatible with Hrana 1 + and 2. +- Protobuf ([Protocol Buffers][protobuf]) is a more compact binary encoding, + introduced in Hrana 3. + +[rfc8259]: https://datatracker.ietf.org/doc/html/rfc8259 +[protobuf]: https://protobuf.dev/ + +This document defines protocol structures in JSON and specifies the schema using +TypeScript type notation. The Protobuf schema is described in proto3 syntax in +an appendix. + +The encoding is negotiated between the server and client. This process depends +on the variant (WebSocket or HTTP) and is described later. All Hrana 3 servers +must support both JSON and Protobuf; clients can choose which encodings to +support and use. + +Both encodings support forward compatibility: when a peer (client or server) +receives a protocol structure that includes an unrecognized field (object +property in JSON or a message field in Protobuf), it must ignore this field. + + + +## Hrana over WebSocket + +Hrana over WebSocket runs on top of the [WebSocket protocol][rfc6455]. + +### Version and encoding negotiation + +The version of the protocol and the encoding is negotiated as a WebSocket +subprotocol: the client includes a list of supported subprotocols in the +`Sec-WebSocket-Protocol` request header in the opening handshake, and the server +replies with the selected subprotocol in the same response header. + +The negotiation mechanism provides backward compatibility with older versions of +the Hrana protocol and forward compatibility with newer versions. + +[rfc6455]: https://www.rfc-editor.org/rfc/rfc6455 + +The WebSocket subprotocols defined in all Hrana versions are as follows: + +| Subprotocol | Version | Encoding | +|-------------|---------|----------| +| `hrana1` | 1 | JSON | +| `hrana2` | 2 | JSON | +| `hrana3` | 3 | JSON | +| `hrana3-protobuf` | 3 | Protobuf | + +This document describes version 3 of the Hrana protocol. Versions 1 and 2 are +described in their own specifications. + +Version 3 of Hrana over WebSocket is designed to be a strict superset of +versions 1 and 2: every server that implements Hrana 3 over WebSocket also +implements versions 1 and 2 and should accept clients that indicate subprotocol +`hrana1` or `hrana2`. + +### Overview + +The client starts the connection by sending a _hello_ message, which +authenticates the client to the server. 
The server responds with either a +confirmation or with an error message, closing the connection. The client can +choose not to wait for the confirmation and immediately send further messages to +reduce latency. + +A single connection can host an arbitrary number of streams. In effect, one +Hrana connection works as a "connection pool" in traditional SQL servers. + +After a stream is opened, the client can execute SQL statements on it. For the +purposes of this protocol, the statements are arbitrary strings with optional +parameters. + +To reduce the number of roundtrips, the protocol supports batches of statements +that are executed conditionally, based on success or failure of previous +statements. Clients can use this mechanism to implement non-interactive +transactions in a single roundtrip. + +### Messages + +If the negotiated encoding is JSON, all messages exchanged between the client +and server are sent as text frames (opcode 0x1) on the WebSocket. If the +negotiated encoding is Protobuf, messages are sent as binary frames (opcode +0x2). + +```typescript +type ClientMsg = + | HelloMsg + | RequestMsg + +type ServerMsg = + | HelloOkMsg + | HelloErrorMsg + | ResponseOkMsg + | ResponseErrorMsg +``` + +The client sends messages of type `ClientMsg`, and the server sends messages of +type `ServerMsg`. The type of the message is determined by its `type` field. + +#### Hello + +```typescript +type HelloMsg = { + "type": "hello", + "jwt": string | null, +} +``` + +The `hello` message is sent as the first message by the client. It authenticates +the client to the server using the [Json Web Token (JWT)][rfc7519] passed in the +`jwt` field. If no authentication is required (which might be useful for +development and debugging, or when authentication is performed by other means, +such as with mutual TLS), the `jwt` field might be set to `null`. + +[rfc7519]: https://www.rfc-editor.org/rfc/rfc7519 + +The client can also send the `hello` message again anytime during the lifetime +of the connection to reauthenticate, by providing a new JWT. If the provided JWT +expires and the client does not provide a new one in a `hello` message, the +server may terminate the connection. + +```typescript +type HelloOkMsg = { + "type": "hello_ok", +} + +type HelloErrorMsg = { + "type": "hello_error", + "error": Error, +} +``` + +The server waits for the `hello` message from the client and responds with a +`hello_ok` message if the client can proceed, or with a `hello_error` message +describing the failure. + +The client may choose not to wait for a response to its `hello` message before +sending more messages to save a network roundtrip. If the server responds with +`hello_error`, it must ignore all further messages sent by the client and it +should close the WebSocket immediately. + +#### Request/response + +```typescript +type RequestMsg = { + "type": "request", + "request_id": int32, + "request": Request, +} +``` + +After sending the `hello` message, the client can start sending `request` +messages. The client uses requests to open SQL streams and execute statements on +them. The client assigns an identifier to every request, which is then used to +match a response to the request. + +The `Request` structure represents the payload of the request and is defined +later. 
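+
+For illustration only, the sketch below shows what this exchange could look
+like from a client using the standard WebSocket API. The endpoint URL is
+hypothetical, and the `open_stream` request used as the payload is defined in
+the Requests section below; the sketch is not normative.
+
+```typescript
+// Non-normative sketch of the message flow, assuming the standard WebSocket
+// API and a hypothetical endpoint URL.
+const ws = new WebSocket("wss://db.example.com", ["hrana3"]);
+
+ws.onopen = () => {
+  // Authenticate first; `jwt` set to `null` means no authentication.
+  ws.send(JSON.stringify({ type: "hello", jwt: null }));
+  // A request carries a client-assigned `request_id` that the matching
+  // response echoes back.
+  ws.send(JSON.stringify({
+    type: "request",
+    request_id: 1,
+    request: { type: "open_stream", stream_id: 1 },
+  }));
+};
+
+ws.onmessage = (event) => {
+  // `hello_ok`/`hello_error` answer the hello; `response_ok`/`response_error`
+  // carry the `request_id` of the request they answer.
+  const msg = JSON.parse(event.data as string);
+  console.log(msg.type, msg.request_id);
+};
+```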
+ +```typescript +type ResponseOkMsg = { + "type": "response_ok", + "request_id": int32, + "response": Response, +} + +type ResponseErrorMsg = { + "type": "response_error", + "request_id": int32, + "error": Error, +} +``` + +When the server receives a `request` message, it must eventually send either a +`response_ok` with the response or a `response_error` that describes a failure. +The response from the server includes the same `request_id` that was provided by +the client in the request. The server can send the responses in arbitrary order. + +The request ids are arbitrary 32-bit signed integers, the server does not +interpret them in any way. + +The server should limit the number of outstanding requests to a reasonable +value, and stop receiving messages when this limit is reached. This will cause +the TCP flow control to kick in and apply back-pressure to the client. On the +other hand, the client should always receive messages, to avoid deadlock. + +### Requests + +Most of the work in the protocol happens in request/response interactions. + +```typescript +type Request = + | OpenStreamReq + | CloseStreamReq + | ExecuteReq + | BatchReq + | OpenCursorReq + | CloseCursorReq + | FetchCursorReq + | SequenceReq + | DescribeReq + | StoreSqlReq + | CloseSqlReq + | GetAutocommitReq + +type Response = + | OpenStreamResp + | CloseStreamResp + | ExecuteResp + | BatchResp + | OpenCursorResp + | CloseCursorResp + | FetchCursorResp + | SequenceResp + | DescribeResp + | StoreSqlReq + | CloseSqlReq + | GetAutocommitResp +``` + +The type of the request and response is determined by its `type` field. The +`type` of the response must always match the `type` of the request. The +individual requests and responses are defined in the rest of this section. + +#### Open stream + +```typescript +type OpenStreamReq = { + "type": "open_stream", + "stream_id": int32, +} + +type OpenStreamResp = { + "type": "open_stream", +} +``` + +The client uses the `open_stream` request to open an SQL stream, which is then +used to execute SQL statements. The streams are identified by arbitrary 32-bit +signed integers assigned by the client. + +The client can optimistically send follow-up requests on a stream before it +receives the response to its `open_stream` request. If the server receives a +request that refers to a stream that failed to open, it should respond with an +error, but it should not close the connection. + +Even if the `open_stream` request returns an error, the stream id is still +considered as used, and the client cannot reuse it until it sends a +`close_stream` request. + +The server can impose a reasonable limit to the number of streams opened at the +same time. + +> This request was introduced in Hrana 1. + +#### Close stream + +```typescript +type CloseStreamReq = { + "type": "close_stream", + "stream_id": int32, +} + +type CloseStreamResp = { + "type": "close_stream", +} +``` + +When the client is done with a stream, it should close it using the +`close_stream` request. The client can safely reuse the stream id after it +receives the response. + +The client should close even streams for which the `open_stream` request +returned an error. + +If there is an open cursor for the stream, the cursor is closed together with +the stream. + +> This request was introduced in Hrana 1. 
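+
+As a non-normative illustration of this lifecycle, the sketch below pipelines
+an `open_stream`, an `execute` (defined in the next section) and a
+`close_stream` on one stream without waiting for intermediate responses; the
+server processes requests for a single stream id in order. The helper names
+are illustrative.
+
+```typescript
+// Non-normative sketch of optimistic pipelining on a single stream.
+let nextRequestId = 1;
+
+function sendRequest(ws: WebSocket, request: unknown): number {
+  const requestId = nextRequestId++;
+  ws.send(JSON.stringify({ type: "request", request_id: requestId, request }));
+  return requestId;
+}
+
+function runStatement(ws: WebSocket, streamId: number, sql: string): void {
+  // All three requests are written back-to-back; the responses arrive later
+  // and are matched by `request_id` (see the Request/response section above).
+  sendRequest(ws, { type: "open_stream", stream_id: streamId });
+  sendRequest(ws, { type: "execute", stream_id: streamId, stmt: { sql } });
+  sendRequest(ws, { type: "close_stream", stream_id: streamId });
+}
+```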
+ +#### Execute a statement + +```typescript +type ExecuteReq = { + "type": "execute", + "stream_id": int32, + "stmt": Stmt, +} + +type ExecuteResp = { + "type": "execute", + "result": StmtResult, +} +``` + +The client sends an `execute` request to execute an SQL statement on a stream. +The server responds with the result of the statement. The `Stmt` and +`StmtResult` structures are defined later. + +If the statement fails, the server responds with an error response (message of +type `"response_error"`). + +> This request was introduced in Hrana 1. + +#### Execute a batch + +```typescript +type BatchReq = { + "type": "batch", + "stream_id": int32, + "batch": Batch, +} + +type BatchResp = { + "type": "batch", + "result": BatchResult, +} +``` + +The `batch` request runs a batch of statements on a stream. The server responds +with the result of the batch execution. + +If a statement in the batch fails, the error is returned inside the +`BatchResult` structure in a normal response (message of type `"response_ok"`). +However, if the server encounters a serious error that prevents it from +executing the batch, it responds with an error response (message of type +`"response_error"`). + +> This request was introduced in Hrana 1. + +#### Open a cursor executing a batch + +```typescript +type OpenCursorReq = { + "type": "open_cursor", + "stream_id": int32, + "cursor_id": int32, + "batch": Batch, +} + +type OpenCursorResp = { + "type": "open_cursor", +} +``` + +The `open_cursor` request runs a batch of statements like the `batch` request, +but instead of returning all statement results in the request response, it opens +a _cursor_ which the client can then use to read the results incrementally. + +The `cursor_id` is an arbitrary 32-bit integer id assigned by the client. This +id must be unique for the given connection and must not be used by another +cursor that was not yet closed using the `close_cursor` request. + +Even if the `open_cursor` request returns an error, the cursor id is still +considered as used, and the client cannot reuse it until it sends a +`close_cursor` request. + +After the `open_cursor` request, the client must not send more requests on the +stream until the cursor is closed using the `close_cursor` request. + +> This request was introduced in Hrana 3. + +#### Close a cursor + +```typescript +type CloseCursorReq = { + "type": "close_cursor", + "cursor_id": int32, +} + +type CloseCursorResp = { + "type": "close_cursor", +} +``` + +The `close_cursor` request closes a cursor opened by an `open_cursor` request +and allows the server to release resources and continue processing other +requests for the given stream. + +> This request was introduced in Hrana 3. + +#### Fetch entries from a cursor + +```typescript +type FetchCursorReq = { + "type": "fetch_cursor", + "cursor_id": int32, + "max_count": uint32, +} + +type FetchCursorResp = { + "type": "fetch_cursor", + "entries": Array, + "done": boolean, +} +``` + +The `fetch_cursor` request reads data from a cursor previously opened with the +`open_cursor` request. The cursor data is encoded as a sequence of entries +(`CursorEntry` structure). `max_count` in the request specifies the maximum +number of entries that the client wants to receive in the response; however, the +server may decide to send fewer entries. + +If the `done` field in the response is set to true, then the cursor is finished +and all subsequent calls to `fetch_cursor` are guaranteed to return zero +entries. 
The client should then close the cursor by sending the `close_cursor` +request. + +If the `cursor_id` refers to a cursor for which the `open_cursor` request +returned an error, and the cursor hasn't yet been closed with `close_cursor`, +then the server should return an error, but it must not close the connection +(i.e., this is not a protocol error). + +> This request was introduced in Hrana 3. + +#### Store an SQL text on the server + +```typescript +type StoreSqlReq = { + "type": "store_sql", + "sql_id": int32, + "sql": string, +} + +type StoreSqlResp = { + "type": "store_sql", +} +``` + +The `store_sql` request stores an SQL text on the server. The client can then +refer to this SQL text in other requests by its id, instead of repeatedly +sending the same string over the network. + +SQL text ids are arbitrary 32-bit signed integers assigned by the client. It is +a protocol error if the client tries to store an SQL text with an id which is +already in use. + +> This request was introduced in Hrana 2. + +#### Close a stored SQL text + +```typescript +type CloseSqlReq = { + "type": "close_sql", + "sql_id": int32, +} + +type CloseSqlResp = { + "type": "close_sql", +} +``` + +The `close_sql` request can be used to delete an SQL text stored on the server +with `store_sql`. The client can safely reuse the SQL text id after it receives +the response. + +It is not an error if the client attempts to close a SQL text id that is not +used. + +> This request was introduced in Hrana 2. + +#### Execute a sequence of SQL statements + +```typescript +type SequenceReq = { + "type": "sequence", + "stream_id": int32, + "sql"?: string | null, + "sql_id"?: int32 | null, +} + +type SequenceResp = { + "type": "sequence", +} +``` + +The `sequence` request executes a sequence of SQL statements separated by +semicolons on the stream given by `stream_id`. `sql` or `sql_id` specify the SQL +text; exactly one of these fields must be specified. + +Any rows returned by the statements are ignored. If any statement fails, the +subsequent statements are not executed and the request returns an error +response. + +> This request was introduced in Hrana 2. + +#### Describe a statement + +```typescript +type DescribeReq = { + "type": "describe", + "stream_id": int32, + "sql"?: string | null, + "sql_id"?: int32 | null, +} + +type DescribeResp = { + "type": "describe", + "result": DescribeResult, +} +``` + +The `describe` request is used to parse and analyze a SQL statement. `stream_id` +specifies the stream on which the statement is parsed. `sql` or `sql_id` specify +the SQL text: exactly one of these two fields must be specified, `sql` passes +the SQL directly as a string, while `sql_id` refers to a SQL text previously +stored with `store_sql`. In the response, `result` contains the result of +describing a statement. + +> This request was introduced in Hrana 2. + +#### Get the autocommit state + +```typescript +type GetAutocommitReq = { + "type": "get_autocommit", + "stream_id": int32, +} + +type GetAutocommitResp = { + "type": "get_autocommit", + "is_autocommit": bool, +} +``` + +The `get_autocommit` request can be used to check whether the stream is in +autocommit state (not inside an explicit transaction). + +> This request was introduced in Hrana 3. + +### Errors + +If either peer detects that the protocol has been violated, it should close the +WebSocket with an appropriate WebSocket close code and reason. Some examples of +protocol violations include: + +- Text message payload that is not a valid JSON. 
+- Data frame type that does not match the negotiated encoding (i.e., binary frame when + the encoding is JSON or a text frame when the encoding is Protobuf). +- Unrecognized `ClientMsg` or `ServerMsg` (the field `type` is unknown or + missing) +- Client receives a `ResponseOkMsg` or `ResponseErrorMsg` with a `request_id` + that has not been sent in a `RequestMsg` or that has already received a + response. + +### Ordering + +The protocol allows the server to reorder the responses: it is not necessary to +send the responses in the same order as the requests. However, the server must +process requests related to a single stream id in order. + +For example, this means that a client can send an `open_stream` request +immediately followed by a batch of `execute` requests on that stream and the +server will always process them in correct order. + + + +## Hrana over HTTP + +Hrana over HTTP runs on top of HTTP. Any version of the HTTP protocol can be +used. + +### Overview + +HTTP is a stateless protocol, so there is no concept of a connection like in the +WebSocket protocol. However, Hrana needs to expose stateful streams, so it needs +to ensure that requests on the same stream are tied together. + +This is accomplished by the use of a baton, which is similar to a session cookie. +The server returns a baton in every response to a request on the stream, and the +client then needs to include the baton in the subsequent request. The client +must serialize the requests on a stream: it must wait for a response to the +previous request before sending next request on the same stream. + +The server can also optionally specify a different URL that the client should +use for the requests on the stream. This can be used to ensure that stream +requests are "sticky" and reach the same server. + +If the client terminates without closing a stream, the server has no way of +finding this out: with Hrana over WebSocket, the WebSocket connection is closed +and the server can close the streams that belong to this connection, but there +is no connection in Hrana over HTTP. Therefore, the server will close streams +after a short period of inactivity, to make sure that abandoned streams don't +accumulate on the server. + +### Version and encoding negotiation + +With Hrana over HTTP, the client indicates the Hrana version and encoding in the +URI path of the HTTP request. The client can check whether the server supports a +given Hrana version by sending an HTTP request (described later). + +### Endpoints + +The client communicates with the server by sending HTTP requests with a +specified method and URL. + +#### Check support for version 3 (JSON) + +``` +GET v3 +``` + +If the server supports version 3 of Hrana over HTTP with JSON encoding, it +should return a 2xx response to this request. + +#### Check support for version 3 (Protobuf) + +``` +GET v3-protobuf +``` + +If the server supports version 3 of Hrana over HTTP with Protobuf encoding, it +should return a 2xx response to this request. 
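+
+For example, a client might probe for support with a plain HTTP request
+before choosing a protocol variant. The base URL in the sketch below is
+hypothetical and the sketch is not normative.
+
+```typescript
+// Non-normative sketch of a version probe; a 2xx response to `GET v3`
+// indicates support for Hrana 3 over HTTP with JSON encoding.
+async function supportsHrana3Json(baseUrl: string): Promise<boolean> {
+  const response = await fetch(new URL("v3", baseUrl));
+  return response.ok; // true for any 2xx status
+}
+
+// Usage (hypothetical URL): await supportsHrana3Json("https://db.example.com/");
+```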
+ +#### Execute a pipeline of requests (JSON) + +``` +POST v3/pipeline +-> JSON: PipelineReqBody +<- JSON: PipelineRespBody +``` + +```typescript +type PipelineReqBody = { + "baton": string | null, + "requests": Array, +} + +type PipelineRespBody = { + "baton": string | null, + "base_url": string | null, + "results": Array +} + +type StreamResult = + | StreamResultOk + | StreamResultError + +type StreamResultOk = { + "type": "ok", + "response": StreamResponse, +} + +type StreamResultError = { + "type": "error", + "error": Error, +} +``` + +The `v3/pipeline` endpoint is used to execute a pipeline of requests on a +stream. `baton` in the request specifies the stream. If the client sets `baton` +to `null`, the server should create a new stream. + +Server responds with another `baton` value in the response. If the `baton` value +in the response is `null`, it means that the server has closed the stream. The +client must use this value to refer to this stream in the next request (the +`baton` in the response should be different from the `baton` in the request). +This forces the client to issue the requests serially: it must wait for the +response from a previous `pipeline` request before issuing another request on +the same stream. + +The server should ensure that the `baton` values are unpredictable and +unforgeable, for example by cryptographically signing them. + +If the `base_url` in the response is not `null`, the client should use this URL +when sending further requests on this stream. If it is `null`, the client should +use the same URL that it has used for the previous request. The `base_url` +must be an absolute URL with "http" or "https" scheme. + +The `requests` array in the request specifies a sequence of stream requests that +should be executed on the stream. The server executes them in order and returns +the results in the `results` array in the response. Result is either a success +(`type` set to `"ok"`) or an error (`type` set to `"error"`). The server always +executes all requests, even if some of them return errors. + +#### Execute a pipeline of requests (Protobuf) + +``` +POST v3-protobuf/pipeline +-> Protobuf: PipelineReqBody +<- Protobuf: PipelineRespBody +``` + +The `v3-protobuf/pipeline` endpoint is the same as `v3/pipeline`, but it encodes +the request and response body using Protobuf. + +#### Execute a batch using a cursor (JSON) + +``` +POST v3/cursor +-> JSON: CursorReqBody +<- line of JSON: CursorRespBody + lines of JSON: CursorEntry +``` + +```typescript +type CursorReqBody = { + "baton": string | null, + "batch": Batch, +} + +type CursorRespBody = { + "baton": string | null, + "base_url": string | null, +} +``` + +The `v3/cursor` endpoint executes a batch of statements on a stream using a +cursor, so the results can be streamed from the server to the client. + +The HTTP response is composed of JSON structures separated with a newline. The +first line contains the `CursorRespBody` structure, and the following lines +contain `CursorEntry` structures, which encode the result of the batch. + +The `baton` field in the request and the `baton` and `base_url` fields in the +response have the same meaning as in the `v3/pipeline` endpoint. + +#### Execute a batch using a cursor (Protobuf) + +``` +POST v3-protobuf/cursor +-> Protobuf: CursorReqBody +<- length-delimited Protobuf: CursorRespBody + length-delimited Protobufs: CursorEntry +``` + +The `v3-protobuf/cursor` endpoint is the same as `v3/cursor` endpoint, but the +request and response are encoded using Protobuf. 
+ +In the response body, the structures are prefixed with a length delimiter: a +Protobuf varint that encodes the length of the structure. The first structure is +`CursorRespBody`, followed by an arbitrary number of `CursorEntry` structures. + +### Requests + +Requests in Hrana over HTTP closely mirror stream requests in Hrana over +WebSocket: + +```typescript +type StreamRequest = + | CloseStreamReq + | ExecuteStreamReq + | BatchStreamReq + | SequenceStreamReq + | DescribeStreamReq + | StoreSqlStreamReq + | CloseSqlStreamReq + | GetAutocommitStreamReq + +type StreamResponse = + | CloseStreamResp + | ExecuteStreamResp + | BatchStreamResp + | SequenceStreamResp + | DescribeStreamResp + | StoreSqlStreamResp + | CloseSqlStreamResp + | GetAutocommitStreamReq +``` + +#### Close stream + +```typescript +type CloseStreamReq = { + "type": "close", +} + +type CloseStreamResp = { + "type": "close", +} +``` + +The `close` request closes the stream. It is an error if the client tries to +execute more requests on the same stream. + +> This request was introduced in Hrana 2. + +#### Execute a statement + +```typescript +type ExecuteStreamReq = { + "type": "execute", + "stmt": Stmt, +} + +type ExecuteStreamResp = { + "type": "execute", + "result": StmtResult, +} +``` + +The `execute` request has the same semantics as the `execute` request in Hrana +over WebSocket. + +> This request was introduced in Hrana 2. + +#### Execute a batch + +```typescript +type BatchStreamReq = { + "type": "batch", + "batch": Batch, +} + +type BatchStreamResp = { + "type": "batch", + "result": BatchResult, +} +``` + +The `batch` request has the same semantics as the `batch` request in Hrana over +WebSocket. + +> This request was introduced in Hrana 2. + +#### Execute a sequence of SQL statements + +```typescript +type SequenceStreamReq = { + "type": "sequence", + "sql"?: string | null, + "sql_id"?: int32 | null, +} + +type SequenceStreamResp = { + "type": "sequence", +} +``` + +The `sequence` request has the same semantics as the `sequence` request in +Hrana over WebSocket. + +> This request was introduced in Hrana 2. + +#### Describe a statement + +```typescript +type DescribeStreamReq = { + "type": "describe", + "sql"?: string | null, + "sql_id"?: int32 | null, +} + +type DescribeStreamResp = { + "type": "describe", + "result": DescribeResult, +} +``` + +The `describe` request has the same semantics as the `describe` request in +Hrana over WebSocket. + +> This request was introduced in Hrana 2. + +#### Store an SQL text on the server + +```typescript +type StoreSqlStreamReq = { + "type": "store_sql", + "sql_id": int32, + "sql": string, +} + +type StoreSqlStreamResp = { + "type": "store_sql", +} +``` + +The `store_sql` request has the same semantics as the `store_sql` request in +Hrana over WebSocket, except that the scope of the SQL texts is just a single +stream (with WebSocket, it is the whole connection). + +> This request was introduced in Hrana 2. + +#### Close a stored SQL text + +```typescript +type CloseSqlStreamReq = { + "type": "close_sql", + "sql_id": int32, +} + +type CloseSqlStreamResp = { + "type": "close_sql", +} +``` + +The `close_sql` request has the same semantics as the `close_sql` request in +Hrana over WebSocket, except that the scope of the SQL texts is just a single +stream. + +> This request was introduced in Hrana 2. 
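+
+To tie these stream requests together, the non-normative sketch below sends a
+single `v3/pipeline` round trip that stores an SQL text, executes it, and then
+closes both the SQL text and the stream. The base URL, table and column names
+are illustrative, and authentication is omitted.
+
+```typescript
+// Non-normative sketch of one `v3/pipeline` round trip (authentication
+// omitted; URL and SQL are illustrative). `baton: null` opens a fresh stream,
+// and the final `close` request releases it, so no baton needs to be kept.
+async function runPipelineOnce(baseUrl: string): Promise<unknown[]> {
+  const body = {
+    baton: null,
+    requests: [
+      { type: "store_sql", sql_id: 1, sql: "SELECT name FROM users WHERE id = ?" },
+      {
+        type: "execute",
+        stmt: { sql_id: 1, args: [{ type: "integer", value: "42" }], want_rows: true },
+      },
+      { type: "close_sql", sql_id: 1 },
+      { type: "close" },
+    ],
+  };
+  const response = await fetch(new URL("v3/pipeline", baseUrl), {
+    method: "POST",
+    headers: { "content-type": "application/json" },
+    body: JSON.stringify(body),
+  });
+  const respBody = await response.json();
+  return respBody.results; // one StreamResult per request, in request order
+}
+```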
+ +#### Get the autocommit state + +```typescript +type GetAutocommitStreamReq = { + "type": "get_autocommit", +} + +type GetAutocommitStreamResp = { + "type": "get_autocommit", + "is_autocommit": bool, +} +``` + +The `get_autocommit` request has the same semantics as the `get_autocommit` +request in Hrana over WebSocket. + +> This request was introduced in Hrana 3. + +### Errors + +If the client receives an HTTP error (4xx or 5xx response), it means that the +server encountered an internal error and the stream is no longer valid. The +client should attempt to parse the response body as an `Error` structure (using +the encoding indicated by the `Content-Type` response header), but the client +must be able to handle responses with different bodies, such as plaintext or +HTML, which might be returned by various components in the HTTP stack. + + + +## Shared structures + +This section describes protocol structures that are common for both Hrana over +WebSocket and Hrana over HTTP. + +### Errors + +```typescript +type Error = { + "message": string, + "code"?: string | null, +} +``` + +Errors can be returned by the server in many places in the protocol, and they +are always represented with the `Error` structure. The `message` field contains +an English human-readable description of the error. The `code` contains a +machine-readable error code. + +At this moment, the error codes are not yet stabilized and depend on the server +implementation. + +> This structure was introduced in Hrana 1. + +### Statements + +```typescript +type Stmt = { + "sql"?: string | null, + "sql_id"?: int32 | null, + "args"?: Array, + "named_args"?: Array, + "want_rows"?: boolean, +} + +type NamedArg = { + "name": string, + "value": Value, +} +``` + +A SQL statement is represented by the `Stmt` structure. The text of the SQL +statement is specified either by passing a string directly in the `sql` field, +or by passing SQL text id that has previously been stored with the `store_sql` +request. Exactly one of `sql` and `sql_id` must be passed. + +The arguments in `args` are bound to parameters in the SQL statement by +position. The arguments in `named_args` are bound to parameters by name. + +In SQLite, the names of arguments include the prefix sign (`:`, `@` or `$`). If +the name of the argument does not start with this prefix, the server will try to +guess the correct prefix. If an argument is specified both as a positional +argument and as a named argument, the named argument should take precedence. + +It is an error if the request specifies an argument that is not expected by the +SQL statement, or if the request does not specify an argument that is expected +by the SQL statement. Some servers may not support specifying both positional +and named arguments. + +The `want_rows` field specifies whether the client is interested in the rows +produced by the SQL statement. If it is set to `false`, the server should always +reply with no rows, even if the statement produced some. If the field is +omitted, the default value is `true`. + +The SQL text should contain just a single statement. Issuing multiple statements +separated by a semicolon is not supported. + +> This structure was introduced in Hrana 1. In Hrana 2, the `sql_id` field was +> added and the `sql` and `want_rows` fields were made optional. 
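+
+The two values below are non-normative examples of the `Stmt` structure, one
+binding arguments by position and one by name; the table and column names are
+illustrative.
+
+```typescript
+// Non-normative `Stmt` examples (table and column names are illustrative).
+const byPosition = {
+  sql: "SELECT name FROM users WHERE id = ? AND active = ?",
+  args: [
+    { type: "integer", value: "42" }, // 64-bit integers travel as strings in JSON
+    { type: "integer", value: "1" },
+  ],
+  want_rows: true,
+};
+
+const byName = {
+  sql: "UPDATE users SET name = :name WHERE id = :id",
+  named_args: [
+    // The prefix may be included or left out; here the server infers `:`.
+    { name: "name", value: { type: "text", value: "alice" } },
+    { name: "id", value: { type: "integer", value: "42" } },
+  ],
+  want_rows: false, // an UPDATE produces no rows of interest
+};
+```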
+ +### Statement results + +```typescript +type StmtResult = { + "cols": Array, + "rows": Array>, + "affected_row_count": uint32, + "last_insert_rowid": string | null, +} + +type Col = { + "name": string | null, + "decltype": string | null, +} +``` + +The result of executing an SQL statement is represented by the `StmtResult` +structure and it contains information about the returned columns in `cols` and +the returned rows in `rows` (the array is empty if the statement did not produce +any rows or if `want_rows` was `false` in the request). + +`affected_row_count` counts the number of rows that were changed by the +statement. This is meaningful only if the statement was an INSERT, UPDATE or +DELETE, and the value is otherwise undefined. + +`last_insert_rowid` is the ROWID of the last successful insert into a rowid +table. The rowid value is a 64-bit signed integer encoded as a string in JSON. +For other statements, the value is undefined. + +> This structure was introduced in Hrana 1. The `decltype` field in the `Col` +> strucure was added in Hrana 2. + +### Batches + +```typescript +type Batch = { + "steps": Array, +} + +type BatchStep = { + "condition"?: BatchCond | null, + "stmt": Stmt, +} +``` + +A batch is represented by the `Batch` structure. It is a list of steps +(statements) which are always executed sequentially. If the `condition` of a +step is present and evaluates to false, the statement is not executed. + +> This structure was introduced in Hrana 1. + +#### Conditions + +```typescript +type BatchCond = + | { "type": "ok", "step": uint32 } + | { "type": "error", "step": uint32 } + | { "type": "not", "cond": BatchCond } + | { "type": "and", "conds": Array } + | { "type": "or", "conds": Array } + | { "type": "is_autocommit" } +``` + +Conditions are expressions that evaluate to true or false: + +- `ok` evaluates to true if the `step` (referenced by its 0-based index) was +executed successfully. If the statement was skipped, this condition evaluates to +false. +- `error` evaluates to true if the `step` (referenced by its 0-based index) has +produced an error. If the statement was skipped, this condition evaluates to +false. +- `not` evaluates `cond` and returns the logical negative. +- `and` evaluates `conds` and returns the logical conjunction of them. +- `or` evaluates `conds` and returns the logical disjunction of them. +- `is_autocommit` evaluates to true if the stream is currently in the autocommit + state (not inside an explicit transaction) + +> This structure was introduced in Hrana 1. The `is_autocommit` type was added in Hrana 3. + +### Batch results + +```typescript +type BatchResult = { + "step_results": Array, + "step_errors": Array, +} +``` + +The result of executing a batch is represented by `BatchResult`. The result +contains the results or errors of statements from each step. For the step in +`steps[i]`, `step_results[i]` contains the result of the statement if the +statement was executed and succeeded, and `step_errors[i]` contains the error if +the statement was executed and failed. If the statement was skipped because its +condition evaluated to false, both `step_results[i]` and `step_errors[i]` will +be `null`. + +> This structure was introduced in Hrana 1. + +### Cursor entries + +```typescript +type CursorEntry = + | StepBeginEntry + | StepEndEntry + | StepErrorEntry + | RowEntry + | ErrorEntry +``` + +Cursor entries are produced by cursors. 
A sequence of entries encodes the same +information as a `BatchResult`, but it is sent to the client incrementally, so +both peers don't need to keep the whole result in memory. + +> These structures were introduced in Hrana 3. + +#### Step results + +```typescript +type StepBeginEntry = { + "type": "step_begin", + "step": uint32, + "cols": Array, +} + +type StepEndEntry = { + "type": "step_end", + "affected_row_count": uint32, + "last_insert_rowid": string | null, +} + +type RowEntry = { + "type": "row", + "row": Array, +} +``` + +At the beginning of every batch step that is executed, the server produces a +`step_begin` entry. This entry specifies the index of the step (which refers to +the `steps` array in the `Batch` structure). The server sends entries for steps +in the order in which they are executed. If a step is skipped (because its +condition evalated to false), the server does not send any entry for it. + +After a `step_begin` entry, the server sends an arbitrary number of `row` +entries that encode the individual rows produced by the statement, terminated by +the `step_end` entry. Together, these entries encode the same information as the +`StmtResult` structure. + +The server can send another `step_entry` only after the previous step was +terminated by `step_end` or by `step_error`, described below. + +#### Errors + +```typescript +type StepErrorEntry = { + "type": "step_error", + "step": uint32, + "error": Error, +} + +type ErrorEntry = { + "type": "error", + "error": Error, +} +``` + +The `step_error` entry indicates that the execution of a statement failed with +an error. There are two ways in which the server may produce this entry: + +1. Before a `step_begin` entry was sent: this means that the statement failed + very early, without producing any results. The `step` field indicates which + step has failed (similar to the `step_begin` entry). +2. After a `step_begin` entry was sent: in this case, the server has started + executing the statement and produced `step_begin` (and perhaps a number of + `row` entries), but then encountered an error. The `step` field must in this + case be equal to the `step` of the currently processed step. + +The `error` entry means that the execution of the whole batch has failed. This +can be produced by the server at any time, and it is always the last entry in +the cursor. + +### Result of describing a statement + +```typescript +type DescribeResult = { + "params": Array, + "cols": Array, + "is_explain": boolean, + "is_readonly": boolean, +} +``` + +The `DescribeResult` structure is the result of describing a statement. +`is_explain` is true if the statement was an `EXPLAIN` statement, and +`is_readonly` is true if the statement does not modify the database. + +> This structure was introduced in Hrana 2. + +#### Parameters + +```typescript +type DescribeParam = { + "name": string | null, +} +``` + +Information about parameters of the statement is returned in `params`. SQLite +indexes parameters from 1, so the first object in the `params` array describes +parameter 1. + +For each parameter, the `name` field specifies the name of the parameter. For +parameters of the form `?NNN`, `:AAA`, `@AAA` and `$AAA`, the name includes the +initial `?`, `:`, `@` or `$` character. Parameters of the form `?` are nameless, +their `name` is `null`. + +It is also possible that some parameters are not referenced in the statement, in +which case the `name` is also `null`. + +> This structure was introduced in Hrana 2. 
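+
+As a non-normative illustration, describing a statement such as
+`SELECT * FROM users WHERE id = ? AND name = :name AND role = $role` (the
+table and names are hypothetical) could produce the following `params` array:
+
+```typescript
+// Hypothetical `params` for the statement above; parameters are indexed
+// from 1 and the nameless `?` parameter has `name` set to `null`.
+const params = [
+  { name: null },    // parameter 1: `?`
+  { name: ":name" }, // parameter 2, prefix included
+  { name: "$role" }, // parameter 3, prefix included
+];
+```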
+ +#### Columns + +```typescript +type DescribeCol = { + "name": string, + "decltype": string | null, +} +``` + +Information about columns of the statement is returned in `cols`. + +For each column, `name` specifies the name assigned by the SQL `AS` clause. For +columns without `AS` clause, the name is not specified. + +For result columns that directly originate from tables in the database, +`decltype` specifies the declared type of the column. For other columns (such as +results of expressions), `decltype` is `null`. + +> This structure was introduced in Hrana 2. + +### Values + +```typescript +type Value = + | { "type": "null" } + | { "type": "integer", "value": string } + | { "type": "float", "value": number } + | { "type": "text", "value": string } + | { "type": "blob", "base64": string } +``` + +SQLite values are represented by the `Value` structure. The type of the value +depends on the `type` field: + +- `null`: the SQL NULL value. +- `integer`: a 64-bit signed integer. In JSON, the `value` is a string to avoid + losing precision, because some JSON implementations treat all numbers as + 64-bit floats. +- `float`: a 64-bit float. +- `text`: a UTF-8 string. +- `blob`: a binary blob with. In JSON, the value is base64-encoded. + +> This structure was introduced in Hrana 1. + + + + +## Protobuf schema + +### Hrana over WebSocket + +```proto +syntax = "proto3"; +package hrana.ws; + +message ClientMsg { + oneof msg { + HelloMsg hello = 1; + RequestMsg request = 2; + } +} + +message ServerMsg { + oneof msg { + HelloOkMsg hello_ok = 1; + HelloErrorMsg hello_error = 2; + ResponseOkMsg response_ok = 3; + ResponseErrorMsg response_error = 4; + } +} + +message HelloMsg { + optional string jwt = 1; +} + +message HelloOkMsg { +} + +message HelloErrorMsg { + Error error = 1; +} + +message RequestMsg { + int32 request_id = 1; + oneof request { + OpenStreamReq open_stream = 2; + CloseStreamReq close_stream = 3; + ExecuteReq execute = 4; + BatchReq batch = 5; + OpenCursorReq open_cursor = 6; + CloseCursorReq close_cursor = 7; + FetchCursorReq fetch_cursor = 8; + SequenceReq sequence = 9; + DescribeReq describe = 10; + StoreSqlReq store_sql = 11; + CloseSqlReq close_sql = 12; + GetAutocommitReq get_autocommit = 13; + } +} + +message ResponseOkMsg { + int32 request_id = 1; + oneof response { + OpenStreamResp open_stream = 2; + CloseStreamResp close_stream = 3; + ExecuteResp execute = 4; + BatchResp batch = 5; + OpenCursorResp open_cursor = 6; + CloseCursorResp close_cursor = 7; + FetchCursorResp fetch_cursor = 8; + SequenceResp sequence = 9; + DescribeResp describe = 10; + StoreSqlResp store_sql = 11; + CloseSqlResp close_sql = 12; + GetAutocommitResp get_autocommit = 13; + } +} + +message ResponseErrorMsg { + int32 request_id = 1; + Error error = 2; +} + +message OpenStreamReq { + int32 stream_id = 1; +} + +message OpenStreamResp { +} + +message CloseStreamReq { + int32 stream_id = 1; +} + +message CloseStreamResp { +} + +message ExecuteReq { + int32 stream_id = 1; + Stmt stmt = 2; +} + +message ExecuteResp { + StmtResult result = 1; +} + +message BatchReq { + int32 stream_id = 1; + Batch batch = 2; +} + +message BatchResp { + BatchResult result = 1; +} + +message OpenCursorReq { + int32 stream_id = 1; + int32 cursor_id = 2; + Batch batch = 3; +} + +message OpenCursorResp { +} + +message CloseCursorReq { + int32 cursor_id = 1; +} + +message CloseCursorResp { +} + +message FetchCursorReq { + int32 cursor_id = 1; + uint32 max_count = 2; +} + +message FetchCursorResp { + repeated CursorEntry entries = 1; 
+ bool done = 2; +} + +message StoreSqlReq { + int32 sql_id = 1; + string sql = 2; +} + +message StoreSqlResp { +} + +message CloseSqlReq { + int32 sql_id = 1; +} + +message CloseSqlResp { +} + +message SequenceReq { + int32 stream_id = 1; + optional string sql = 2; + optional int32 sql_id = 3; +} + +message SequenceResp { +} + +message DescribeReq { + int32 stream_id = 1; + optional string sql = 2; + optional int32 sql_id = 3; +} + +message DescribeResp { + DescribeResult result = 1; +} + +message GetAutocommitReq { + int32 stream_id = 1; +} + +message GetAutocommitResp { + bool is_autocommit = 1; +} +``` + +### Hrana over HTTP + +```proto +syntax = "proto3"; +package hrana.http; + +message PipelineReqBody { + optional string baton = 1; + repeated StreamRequest requests = 2; +} + +message PipelineRespBody { + optional string baton = 1; + optional string base_url = 2; + repeated StreamResult results = 3; +} + +message StreamResult { + oneof result { + StreamResponse ok = 1; + Error error = 2; + } +} + +message CursorReqBody { + optional string baton = 1; + Batch batch = 2; +} + +message CursorRespBody { + optional string baton = 1; + optional string base_url = 2; +} + +message StreamRequest { + oneof request { + CloseStreamReq close = 1; + ExecuteStreamReq execute = 2; + BatchStreamReq batch = 3; + SequenceStreamReq sequence = 4; + DescribeStreamReq describe = 5; + StoreSqlStreamReq store_sql = 6; + CloseSqlStreamReq close_sql = 7; + GetAutocommitStreamReq get_autocommit = 8; + } +} + +message StreamResponse { + oneof response { + CloseStreamResp close = 1; + ExecuteStreamResp execute = 2; + BatchStreamResp batch = 3; + SequenceStreamResp sequence = 4; + DescribeStreamResp describe = 5; + StoreSqlStreamResp store_sql = 6; + CloseSqlStreamResp close_sql = 7; + GetAutocommitStreamResp get_autocommit = 8; + } +} + +message CloseStreamReq { +} + +message CloseStreamResp { +} + +message ExecuteStreamReq { + Stmt stmt = 1; +} + +message ExecuteStreamResp { + StmtResult result = 1; +} + +message BatchStreamReq { + Batch batch = 1; +} + +message BatchStreamResp { + BatchResult result = 1; +} + +message SequenceStreamReq { + optional string sql = 1; + optional int32 sql_id = 2; +} + +message SequenceStreamResp { +} + +message DescribeStreamReq { + optional string sql = 1; + optional int32 sql_id = 2; +} + +message DescribeStreamResp { + DescribeResult result = 1; +} + +message StoreSqlStreamReq { + int32 sql_id = 1; + string sql = 2; +} + +message StoreSqlStreamResp { +} + +message CloseSqlStreamReq { + int32 sql_id = 1; +} + +message CloseSqlStreamResp { +} + +message GetAutocommitStreamReq { +} + +message GetAutocommitStreamResp { + bool is_autocommit = 1; +} +``` + +### Shared structures + +```proto +syntax = "proto3"; +package hrana; + +message Error { + string message = 1; + optional string code = 2; +} + +message Stmt { + optional string sql = 1; + optional int32 sql_id = 2; + repeated Value args = 3; + repeated NamedArg named_args = 4; + optional bool want_rows = 5; +} + +message NamedArg { + string name = 1; + Value value = 2; +} + +message StmtResult { + repeated Col cols = 1; + repeated Row rows = 2; + uint64 affected_row_count = 3; + optional sint64 last_insert_rowid = 4; +} + +message Col { + optional string name = 1; + optional string decltype = 2; +} + +message Row { + repeated Value values = 1; +} + +message Batch { + repeated BatchStep steps = 1; +} + +message BatchStep { + optional BatchCond condition = 1; + Stmt stmt = 2; +} + +message BatchCond { + oneof cond { + uint32 step_ok = 
1; + uint32 step_error = 2; + BatchCond not = 3; + CondList and = 4; + CondList or = 5; + IsAutocommit is_autocommit = 6; + } + + message CondList { + repeated BatchCond conds = 1; + } + + message IsAutocommit { + } +} + +message BatchResult { + map step_results = 1; + map step_errors = 2; +} + +message CursorEntry { + oneof entry { + StepBeginEntry step_begin = 1; + StepEndEntry step_end = 2; + StepErrorEntry step_error = 3; + Row row = 4; + Error error = 5; + } +} + +message StepBeginEntry { + uint32 step = 1; + repeated Col cols = 2; +} + +message StepEndEntry { + uint64 affected_row_count = 1; + optional sint64 last_insert_rowid = 2; +} + +message StepErrorEntry { + uint32 step = 1; + Error error = 2; +} + +message DescribeResult { + repeated DescribeParam params = 1; + repeated DescribeCol cols = 2; + bool is_explain = 3; + bool is_readonly = 4; +} + +message DescribeParam { + optional string name = 1; +} + +message DescribeCol { + string name = 1; + optional string decltype = 2; +} + +message Value { + oneof value { + Null null = 1; + sint64 integer = 2; + double float = 3; + string text = 4; + bytes blob = 5; + } + + message Null {} +} +``` diff --git a/sqld/proto/proxy.proto b/sqld/proto/proxy.proto index aeb26ed6..d8cc4c3e 100644 --- a/sqld/proto/proxy.proto +++ b/sqld/proto/proxy.proto @@ -100,6 +100,7 @@ message Cond { NotCond not = 3; AndCond and = 4; OrCond or = 5; + IsAutocommitCond is_autocommit = 6; } } @@ -123,6 +124,9 @@ message OrCond { repeated Cond conds = 1; } +message IsAutocommitCond { +} + enum Authorized { READONLY = 0; FULL = 1; diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index 0011c149..43ca9142 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -314,8 +314,9 @@ impl<'a> Connection<'a> { builder: &mut impl QueryResultBuilder, ) -> Result { builder.begin_step()?; + let mut enabled = match step.cond.as_ref() { - Some(cond) => match eval_cond(cond, results) { + Some(cond) => match eval_cond(cond, results, self.is_autocommit()) { Ok(enabled) => enabled, Err(e) => { builder.step_error(e).unwrap(); @@ -453,25 +454,29 @@ impl<'a> Connection<'a> { is_readonly, }) } + + fn is_autocommit(&self) -> bool { + self.conn.is_autocommit() + } } -fn eval_cond(cond: &Cond, results: &[bool]) -> Result { +fn eval_cond(cond: &Cond, results: &[bool], is_autocommit: bool) -> Result { let get_step_res = |step: usize| -> Result { let res = results.get(step).ok_or(Error::InvalidBatchStep(step))?; - Ok(*res) }; Ok(match cond { Cond::Ok { step } => get_step_res(*step)?, Cond::Err { step } => !get_step_res(*step)?, - Cond::Not { cond } => !eval_cond(cond, results)?, - Cond::And { conds } => conds - .iter() - .try_fold(true, |x, cond| eval_cond(cond, results).map(|y| x & y))?, - Cond::Or { conds } => conds - .iter() - .try_fold(false, |x, cond| eval_cond(cond, results).map(|y| x | y))?, + Cond::Not { cond } => !eval_cond(cond, results, is_autocommit)?, + Cond::And { conds } => conds.iter().try_fold(true, |x, cond| { + eval_cond(cond, results, is_autocommit).map(|y| x & y) + })?, + Cond::Or { conds } => conds.iter().try_fold(false, |x, cond| { + eval_cond(cond, results, is_autocommit).map(|y| x | y) + })?, + Cond::IsAutocommit => is_autocommit, }) } @@ -558,6 +563,20 @@ impl super::Connection for LibSqlConnection { Ok(receiver.await?) 
} + + async fn is_autocommit(&self) -> Result { + let (resp, receiver) = oneshot::channel(); + let cb = Box::new(move |maybe_conn: Result<&mut Connection>| { + let res = maybe_conn.map(|c| c.is_autocommit()); + if resp.send(res).is_err() { + anyhow::bail!("connection closed"); + } + Ok(()) + }); + + let _: Result<_, _> = self.sender.send(cb); + receiver.await? + } } #[cfg(test)] diff --git a/sqld/src/connection/mod.rs b/sqld/src/connection/mod.rs index ec6da8ac..6545dd57 100644 --- a/sqld/src/connection/mod.rs +++ b/sqld/src/connection/mod.rs @@ -99,6 +99,9 @@ pub trait Connection: Send + Sync + 'static { /// Parse the SQL statement and return information about it. async fn describe(&self, sql: String, auth: Authenticated) -> Result; + + /// Check whether the connection is in autocommit mode. + async fn is_autocommit(&self) -> Result; } fn make_batch_program(batch: Vec) -> Vec { @@ -273,6 +276,11 @@ impl Connection for TrackedConnection { async fn describe(&self, sql: String, auth: Authenticated) -> crate::Result { self.inner.describe(sql, auth).await } + + #[inline] + async fn is_autocommit(&self) -> crate::Result { + self.inner.is_autocommit().await + } } #[cfg(test)] @@ -299,6 +307,10 @@ mod test { ) -> crate::Result { unreachable!() } + + async fn is_autocommit(&self) -> crate::Result { + unreachable!() + } } #[tokio::test] diff --git a/sqld/src/connection/program.rs b/sqld/src/connection/program.rs index c85110ac..fabfbd18 100644 --- a/sqld/src/connection/program.rs +++ b/sqld/src/connection/program.rs @@ -57,6 +57,7 @@ pub enum Cond { Not { cond: Box }, Or { conds: Vec }, And { conds: Vec }, + IsAutocommit, } pub type DescribeResult = crate::Result; diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index ddbb4045..8781ca93 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -294,6 +294,14 @@ impl Connection for WriteProxyConnection { self.wait_replication_sync().await?; self.read_db.describe(sql, auth).await } + + async fn is_autocommit(&self) -> Result { + let state = self.state.lock().await; + Ok(match *state { + State::Txn => false, + State::Init | State::Invalid => true, + }) + } } impl Drop for WriteProxyConnection { diff --git a/sqld/src/hrana/batch.rs b/sqld/src/hrana/batch.rs index 31025dab..9ee32a39 100644 --- a/sqld/src/hrana/batch.rs +++ b/sqld/src/hrana/batch.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use std::collections::HashMap; use std::sync::Arc; @@ -27,8 +27,12 @@ pub enum BatchError { ResponseTooLarge, } -fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> Result { - let try_convert_step = |step: i32| -> Result { +fn proto_cond_to_cond( + cond: &proto::BatchCond, + version: Version, + max_step_i: usize, +) -> Result { + let try_convert_step = |step: u32| -> Result { let step = usize::try_from(step).map_err(|_| ProtocolError::BatchCondBadStep)?; if step >= max_step_i { return Err(ProtocolError::BatchCondBadStep); @@ -37,6 +41,9 @@ fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> Result { + bail!(ProtocolError::NoneBatchCond) + } proto::BatchCond::Ok { step } => Cond::Ok { step: try_convert_step(*step)?, }, @@ -44,20 +51,31 @@ fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> Result Cond::Not { - cond: proto_cond_to_cond(cond, max_step_i)?.into(), + cond: proto_cond_to_cond(cond, version, max_step_i)?.into(), }, - proto::BatchCond::And { conds } => Cond::And { - conds: conds + 
proto::BatchCond::And(cond_list) => Cond::And { + conds: cond_list + .conds .iter() - .map(|cond| proto_cond_to_cond(cond, max_step_i)) + .map(|cond| proto_cond_to_cond(cond, version, max_step_i)) .collect::>()?, }, - proto::BatchCond::Or { conds } => Cond::Or { - conds: conds + proto::BatchCond::Or(cond_list) => Cond::Or { + conds: cond_list + .conds .iter() - .map(|cond| proto_cond_to_cond(cond, max_step_i)) + .map(|cond| proto_cond_to_cond(cond, version, max_step_i)) .collect::>()?, }, + proto::BatchCond::IsAutocommit {} => { + if version < Version::Hrana3 { + bail!(ProtocolError::NotSupported { + what: "BatchCond of type `is_autocommit`", + min_version: Version::Hrana3, + }) + } + Cond::IsAutocommit + } }; Ok(cond) @@ -74,7 +92,7 @@ pub fn proto_batch_to_program( let cond = step .condition .as_ref() - .map(|cond| proto_cond_to_cond(cond, step_i)) + .map(|cond| proto_cond_to_cond(cond, version, step_i)) .transpose()?; let step = Step { query, cond }; @@ -149,12 +167,12 @@ pub async fn execute_sequence( fn catch_batch_error(sqld_error: SqldError) -> anyhow::Error { match batch_error_from_sqld_error(sqld_error) { - Ok(stmt_error) => anyhow!(stmt_error), + Ok(batch_error) => anyhow!(batch_error), Err(sqld_error) => anyhow!(sqld_error), } } -fn batch_error_from_sqld_error(sqld_error: SqldError) -> Result { +pub fn batch_error_from_sqld_error(sqld_error: SqldError) -> Result { Ok(match sqld_error { SqldError::LibSqlTxTimeout => BatchError::TransactionTimeout, SqldError::LibSqlTxBusy => BatchError::TransactionBusy, @@ -165,6 +183,13 @@ fn batch_error_from_sqld_error(sqld_error: SqldError) -> Result proto::Error { + proto::Error { + message: error.to_string(), + code: error.code().into(), + } +} + impl BatchError { pub fn code(&self) -> &'static str { match self { diff --git a/sqld/src/hrana/cursor.rs b/sqld/src/hrana/cursor.rs new file mode 100644 index 00000000..1bc068a6 --- /dev/null +++ b/sqld/src/hrana/cursor.rs @@ -0,0 +1,247 @@ +use anyhow::{anyhow, Result}; +use rusqlite::types::ValueRef; +use std::mem::take; +use std::sync::Arc; +use std::task; +use tokio::sync::{mpsc, oneshot}; + +use crate::auth::Authenticated; +use crate::connection::program::Program; +use crate::connection::Connection; +use crate::query_result_builder::{ + Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, +}; + +use super::result_builder::{estimate_cols_json_size, value_json_size, value_to_proto}; +use super::{batch, proto, stmt}; + +pub struct CursorHandle { + open_tx: Option>>, + entry_rx: mpsc::Receiver>, +} + +#[derive(Debug)] +pub struct SizedEntry { + pub entry: proto::CursorEntry, + pub size: u64, +} + +struct OpenReq { + db: Arc, + auth: Authenticated, + pgm: Program, +} + +impl CursorHandle { + pub fn spawn(join_set: &mut tokio::task::JoinSet<()>) -> Self + where + C: Connection, + { + let (open_tx, open_rx) = oneshot::channel(); + let (entry_tx, entry_rx) = mpsc::channel(1); + + join_set.spawn(run_cursor(open_rx, entry_tx)); + Self { + open_tx: Some(open_tx), + entry_rx, + } + } + + pub fn open(&mut self, db: Arc, auth: Authenticated, pgm: Program) { + let open_tx = self.open_tx.take().unwrap(); + let _: Result<_, _> = open_tx.send(OpenReq { db, auth, pgm }); + } + + pub async fn fetch(&mut self) -> Result> { + self.entry_rx.recv().await.transpose() + } + + pub fn poll_fetch(&mut self, cx: &mut task::Context) -> task::Poll>> { + self.entry_rx.poll_recv(cx) + } +} + +async fn run_cursor( + open_rx: oneshot::Receiver>, + entry_tx: mpsc::Sender>, +) { + let Ok(open_req) = 
open_rx.await else { + return + }; + + let result_builder = CursorResultBuilder { + entry_tx: entry_tx.clone(), + step_i: 0, + step_state: StepState::default(), + }; + + if let Err(err) = open_req + .db + .execute_program(open_req.pgm, open_req.auth, result_builder) + .await + { + let entry = match batch::batch_error_from_sqld_error(err) { + Ok(batch_error) => Ok(SizedEntry { + entry: proto::CursorEntry::Error { + error: batch::proto_error_from_batch_error(&batch_error), + }, + size: 0, + }), + Err(sqld_error) => Err(anyhow!(sqld_error)), + }; + let _: Result<_, _> = entry_tx.send(entry).await; + } +} + +struct CursorResultBuilder { + entry_tx: mpsc::Sender>, + step_i: u32, + step_state: StepState, +} + +#[derive(Debug, Default)] +struct StepState { + emitted_begin: bool, + emitted_error: bool, + row: Vec, + row_size: u64, +} + +impl CursorResultBuilder { + fn emit_entry(&self, entry: Result) { + let _: Result<_, _> = self.entry_tx.blocking_send(entry); + } +} + +impl QueryResultBuilder for CursorResultBuilder { + type Ret = (); + + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + if self.step_state.emitted_begin && !self.step_state.emitted_error { + self.emit_entry(Ok(SizedEntry { + entry: proto::CursorEntry::StepEnd(proto::StepEndEntry { + affected_row_count, + last_insert_rowid, + }), + size: 100, // rough, order-of-magnitude estimate of the size of the entry + })); + } + + self.step_i += 1; + self.step_state = StepState::default(); + Ok(()) + } + + fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + match stmt::stmt_error_from_sqld_error(error) { + Ok(stmt_error) => { + if self.step_state.emitted_error { + return Ok(()); + } + + self.emit_entry(Ok(SizedEntry { + entry: proto::CursorEntry::StepError(proto::StepErrorEntry { + step: self.step_i, + error: stmt::proto_error_from_stmt_error(&stmt_error), + }), + size: 100, + })); + self.step_state.emitted_error = true; + } + Err(err) => { + self.emit_entry(Err(anyhow!(err))); + } + } + Ok(()) + } + + fn cols_description<'a>( + &mut self, + col_iter: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + assert!(!self.step_state.emitted_begin); + if self.step_state.emitted_error { + return Ok(()); + } + + let mut cols_size = 0; + let cols = col_iter + .into_iter() + .map(Into::into) + .map(|col| { + cols_size += estimate_cols_json_size(&col); + proto::Col { + name: Some(col.name.to_owned()), + decltype: col.decl_ty.map(ToString::to_string), + } + }) + .collect(); + + self.emit_entry(Ok(SizedEntry { + entry: proto::CursorEntry::StepBegin(proto::StepBeginEntry { + step: self.step_i, + cols, + }), + size: cols_size, + })); + self.step_state.emitted_begin = true; + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.step_state.row.is_empty()); + Ok(()) + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + if self.step_state.emitted_begin && !self.step_state.emitted_error { + self.step_state.row_size += value_json_size(&v); + self.step_state.row.push(value_to_proto(v)?); + } + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + if 
self.step_state.emitted_begin && !self.step_state.emitted_error { + let values = take(&mut self.step_state.row); + self.emit_entry(Ok(SizedEntry { + entry: proto::CursorEntry::Row { + row: proto::Row { values }, + }, + size: self.step_state.row_size, + })); + } else { + self.step_state.row.clear(); + } + + self.step_state.row_size = 0; + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.step_state.row.is_empty()); + Ok(()) + } + + fn finish(&mut self) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn into_ret(self) {} +} diff --git a/sqld/src/hrana/http/mod.rs b/sqld/src/hrana/http/mod.rs index 5c94a5d1..aa60347d 100644 --- a/sqld/src/hrana/http/mod.rs +++ b/sqld/src/hrana/http/mod.rs @@ -1,20 +1,30 @@ -use std::sync::Arc; - use anyhow::{Context, Result}; +use bytes::Bytes; +use futures::stream::Stream; use parking_lot::Mutex; use serde::{de::DeserializeOwned, Serialize}; +use std::pin::Pin; +use std::sync::Arc; +use std::task; -use super::ProtocolError; +use super::{batch, cursor, Encoding, ProtocolError, Version}; use crate::auth::Authenticated; use crate::connection::{Connection, MakeConnection}; mod proto; +mod protobuf; mod request; mod stream; -pub struct Server { +pub struct Server { self_url: Option, baton_key: [u8; 32], - stream_state: Mutex>, + stream_state: Mutex>, +} + +#[derive(Debug, Copy, Clone)] +pub enum Endpoint { + Pipeline, + Cursor, } impl Server { @@ -30,19 +40,30 @@ impl Server { stream::run_expire(self).await } - pub async fn handle_pipeline( + pub async fn handle_request( &self, + connection_maker: Arc>, auth: Authenticated, req: hyper::Request, - connection_maker: Arc>, + endpoint: Endpoint, + version: Version, + encoding: Encoding, ) -> Result> { - handle_pipeline(self, connection_maker, auth, req) - .await - .or_else(|err| { - err.downcast::() - .map(stream_error_response) - }) - .or_else(|err| err.downcast::().map(protocol_error_response)) + handle_request( + self, + connection_maker, + auth, + req, + endpoint, + version, + encoding, + ) + .await + .or_else(|err| { + err.downcast::() + .map(|err| stream_error_response(err, encoding)) + }) + .or_else(|err| err.downcast::().map(protocol_error_response)) } } @@ -53,64 +74,197 @@ pub(crate) async fn handle_index() -> hyper::Response { ) } -async fn handle_pipeline( - server: &Server, - connection_maker: Arc>, +async fn handle_request( + server: &Server, + connection_maker: Arc>, auth: Authenticated, req: hyper::Request, + endpoint: Endpoint, + version: Version, + encoding: Encoding, ) -> Result> { - let req_body: proto::PipelineRequestBody = read_request_json(req).await?; + match endpoint { + Endpoint::Pipeline => { + handle_pipeline(server, connection_maker, auth, req, version, encoding).await + } + Endpoint::Cursor => { + handle_cursor(server, connection_maker, auth, req, version, encoding).await + } + } +} + +async fn handle_pipeline( + server: &Server, + connection_maker: Arc>, + auth: Authenticated, + req: hyper::Request, + version: Version, + encoding: Encoding, +) -> Result> { + let req_body: proto::PipelineReqBody = read_decode_request(req, encoding).await?; let mut stream_guard = - stream::acquire(server, req_body.baton.as_deref(), connection_maker).await?; + stream::acquire(server, connection_maker, req_body.baton.as_deref()).await?; let mut results = Vec::with_capacity(req_body.requests.len()); for request in req_body.requests.into_iter() { - let result = request::handle(&mut stream_guard, auth, request) + let result = request::handle(&mut 
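For illustration, a minimal sketch of driving the `CursorHandle` defined in `cursor.rs` above; the HTTP `handle_cursor` endpoint further down wraps the same handle in a streaming response body rather than looping like this. `drain_cursor` is a hypothetical helper and assumes the items imported at the top of `cursor.rs` (`Connection`, `Authenticated`, `Program`, `proto`) are in scope:

```rust
use std::sync::Arc;
use anyhow::Result;

// Hypothetical helper: open a program on a freshly spawned cursor and collect every
// entry it emits. fetch() yields Some(entry) per cursor entry and None once the
// spawned cursor task has finished.
async fn drain_cursor<C: Connection>(
    mut cursor: CursorHandle<C>,
    db: Arc<C>,
    auth: Authenticated,
    pgm: Program,
) -> Result<Vec<proto::CursorEntry>> {
    cursor.open(db, auth, pgm);
    let mut entries = Vec::new();
    while let Some(sized) = cursor.fetch().await? {
        entries.push(sized.entry);
    }
    Ok(entries)
}
```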
stream_guard, auth, request, version) .await .context("Could not execute a request in pipeline")?; results.push(result); } - let resp_body = proto::PipelineResponseBody { + let resp_body = proto::PipelineRespBody { baton: stream_guard.release(), base_url: server.self_url.clone(), results, }; - Ok(json_response(hyper::StatusCode::OK, &resp_body)) + Ok(encode_response(hyper::StatusCode::OK, &resp_body, encoding)) } -async fn read_request_json(req: hyper::Request) -> Result { +async fn handle_cursor( + server: &Server, + connection_maker: Arc>, + auth: Authenticated, + req: hyper::Request, + version: Version, + encoding: Encoding, +) -> Result> { + let req_body: proto::CursorReqBody = read_decode_request(req, encoding).await?; + let stream_guard = stream::acquire(server, connection_maker, req_body.baton.as_deref()).await?; + + let mut join_set = tokio::task::JoinSet::new(); + let mut cursor_hnd = cursor::CursorHandle::spawn(&mut join_set); + let db = stream_guard.get_db_owned()?; + let sqls = stream_guard.sqls(); + let pgm = batch::proto_batch_to_program(&req_body.batch, sqls, version)?; + cursor_hnd.open(db, auth, pgm); + + let resp_body = proto::CursorRespBody { + baton: stream_guard.release(), + base_url: server.self_url.clone(), + }; + let body = hyper::Body::wrap_stream(CursorStream { + resp_body: Some(resp_body), + join_set, + cursor_hnd, + encoding, + }); + let content_type = match encoding { + Encoding::Json => "text/plain", + Encoding::Protobuf => "application/octet-stream", + }; + + Ok(hyper::Response::builder() + .status(hyper::StatusCode::OK) + .header(hyper::http::header::CONTENT_TYPE, content_type) + .body(body) + .unwrap()) +} + +struct CursorStream { + resp_body: Option, + join_set: tokio::task::JoinSet<()>, + cursor_hnd: cursor::CursorHandle, + encoding: Encoding, +} + +impl Stream for CursorStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> task::Poll>> { + let this = self.get_mut(); + + if let Some(resp_body) = this.resp_body.take() { + let chunk = encode_stream_item(&resp_body, this.encoding); + return task::Poll::Ready(Some(Ok(chunk))); + } + + match this.join_set.poll_join_next(cx) { + task::Poll::Pending => {} + task::Poll::Ready(Some(Ok(()))) => {} + task::Poll::Ready(Some(Err(err))) => panic!("Cursor task crashed: {}", err), + task::Poll::Ready(None) => {} + }; + + match this.cursor_hnd.poll_fetch(cx) { + task::Poll::Pending => task::Poll::Pending, + task::Poll::Ready(None) => task::Poll::Ready(None), + task::Poll::Ready(Some(Ok(entry))) => { + let chunk = encode_stream_item(&entry.entry, this.encoding); + task::Poll::Ready(Some(Ok(chunk))) + } + task::Poll::Ready(Some(Err(err))) => task::Poll::Ready(Some(Err(err))), + } + } +} + +fn encode_stream_item(item: &T, encoding: Encoding) -> Bytes { + let mut data: Vec; + match encoding { + Encoding::Json => { + data = serde_json::to_vec(item).unwrap(); + data.push(b'\n'); + } + Encoding::Protobuf => { + data = ::encode_length_delimited_to_vec(item); + } + } + Bytes::from(data) +} + +async fn read_decode_request( + req: hyper::Request, + encoding: Encoding, +) -> Result { let req_body = hyper::body::to_bytes(req.into_body()) .await .context("Could not read request body")?; - let req_body = serde_json::from_slice(&req_body) - .map_err(|err| ProtocolError::Deserialize { source: err }) - .context("Could not deserialize JSON request body")?; - Ok(req_body) + match encoding { + Encoding::Json => serde_json::from_slice(&req_body) + .map_err(|err| ProtocolError::JsonDeserialize 
{ source: err }) + .context("Could not deserialize JSON request body"), + Encoding::Protobuf => ::decode(req_body) + .map_err(|err| ProtocolError::ProtobufDecode { source: err }) + .context("Could not decode Protobuf request body"), + } } fn protocol_error_response(err: ProtocolError) -> hyper::Response { text_response(hyper::StatusCode::BAD_REQUEST, err.to_string()) } -fn stream_error_response(err: stream::StreamError) -> hyper::Response { - json_response( +fn stream_error_response( + err: stream::StreamError, + encoding: Encoding, +) -> hyper::Response { + encode_response( hyper::StatusCode::INTERNAL_SERVER_ERROR, &proto::Error { message: err.to_string(), code: err.code().into(), }, + encoding, ) } -fn json_response( +fn encode_response( status: hyper::StatusCode, resp_body: &T, + encoding: Encoding, ) -> hyper::Response { - let resp_body = serde_json::to_vec(resp_body).unwrap(); + let (resp_body, content_type) = match encoding { + Encoding::Json => (serde_json::to_vec(resp_body).unwrap(), "application/json"), + Encoding::Protobuf => ( + ::encode_to_vec(resp_body), + "application/x-protobuf", + ), + }; hyper::Response::builder() .status(status) - .header(hyper::http::header::CONTENT_TYPE, "application/json") + .header(hyper::http::header::CONTENT_TYPE, content_type) .body(hyper::Body::from(resp_body)) .unwrap() } diff --git a/sqld/src/hrana/http/proto.rs b/sqld/src/hrana/http/proto.rs index ba1285f1..82e6f23e 100644 --- a/sqld/src/hrana/http/proto.rs +++ b/sqld/src/hrana/http/proto.rs @@ -3,29 +3,59 @@ pub use super::super::proto::*; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Debug)] -pub struct PipelineRequestBody { +#[derive(Deserialize, prost::Message)] +pub struct PipelineReqBody { + #[prost(string, optional, tag = "1")] pub baton: Option, + #[prost(message, repeated, tag = "2")] pub requests: Vec, } -#[derive(Serialize, Debug)] -pub struct PipelineResponseBody { +#[derive(Serialize, prost::Message)] +pub struct PipelineRespBody { + #[prost(string, optional, tag = "1")] pub baton: Option, + #[prost(string, optional, tag = "2")] pub base_url: Option, + #[prost(message, repeated, tag = "3")] pub results: Vec, } -#[derive(Serialize, Debug)] +#[derive(Serialize, Default, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum StreamResult { - Ok { response: StreamResponse }, - Error { error: Error }, + #[default] + None, + Ok { + response: StreamResponse, + }, + Error { + error: Error, + }, +} + +#[derive(Deserialize, prost::Message)] +pub struct CursorReqBody { + #[prost(string, optional, tag = "1")] + pub baton: Option, + #[prost(message, required, tag = "2")] + pub batch: Batch, +} + +#[derive(Serialize, prost::Message)] +pub struct CursorRespBody { + #[prost(string, optional, tag = "1")] + pub baton: Option, + #[prost(string, optional, tag = "2")] + pub base_url: Option, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Default)] #[serde(tag = "type", rename_all = "snake_case")] pub enum StreamRequest { + #[serde(skip_deserializing)] + #[default] + None, Close(CloseStreamReq), Execute(ExecuteStreamReq), Batch(BatchStreamReq), @@ -33,6 +63,7 @@ pub enum StreamRequest { Describe(DescribeStreamReq), StoreSql(StoreSqlStreamReq), CloseSql(CloseSqlStreamReq), + GetAutocommit(GetAutocommitStreamReq), } #[derive(Serialize, Debug)] @@ -45,71 +76,93 @@ pub enum StreamResponse { Describe(DescribeStreamResp), StoreSql(StoreSqlStreamResp), CloseSql(CloseSqlStreamResp), + GetAutocommit(GetAutocommitStreamResp), } -#[derive(Deserialize, Debug)] 
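As an illustration of the wire format the `PipelineReqBody`/`StreamRequest` structures below accept when the negotiated encoding is JSON, a client might POST a body like the following; the SQL text is hypothetical and the `baton` is null as when starting a new stream:

```rust
// Builds a minimal JSON pipeline body: a baton plus a list of tagged stream requests
// (an `execute` followed by a `close`), matching the serde attributes on these structs.
fn example_pipeline_body() -> serde_json::Value {
    serde_json::json!({
        "baton": null,
        "requests": [
            { "type": "execute", "stmt": { "sql": "SELECT 1", "want_rows": true } },
            { "type": "close" }
        ]
    })
}
```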
+#[derive(Deserialize, prost::Message)] pub struct CloseStreamReq {} -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct CloseStreamResp {} -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct ExecuteStreamReq { + #[prost(message, required, tag = "1")] pub stmt: Stmt, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct ExecuteStreamResp { + #[prost(message, required, tag = "1")] pub result: StmtResult, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct BatchStreamReq { + #[prost(message, required, tag = "1")] pub batch: Batch, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct BatchStreamResp { + #[prost(message, required, tag = "1")] pub result: BatchResult, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct SequenceStreamReq { #[serde(default)] + #[prost(string, optional, tag = "1")] pub sql: Option, #[serde(default)] + #[prost(int32, optional, tag = "2")] pub sql_id: Option, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct SequenceStreamResp {} -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct DescribeStreamReq { #[serde(default)] + #[prost(string, optional, tag = "1")] pub sql: Option, #[serde(default)] + #[prost(int32, optional, tag = "2")] pub sql_id: Option, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct DescribeStreamResp { + #[prost(message, required, tag = "1")] pub result: DescribeResult, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct StoreSqlStreamReq { + #[prost(int32, tag = "1")] pub sql_id: i32, + #[prost(string, tag = "2")] pub sql: String, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct StoreSqlStreamResp {} -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct CloseSqlStreamReq { + #[prost(int32, tag = "1")] pub sql_id: i32, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct CloseSqlStreamResp {} + +#[derive(Deserialize, prost::Message)] +pub struct GetAutocommitStreamReq {} + +#[derive(Serialize, prost::Message)] +pub struct GetAutocommitStreamResp { + #[prost(bool, tag = "1")] + pub is_autocommit: bool, +} diff --git a/sqld/src/hrana/http/protobuf.rs b/sqld/src/hrana/http/protobuf.rs new file mode 100644 index 00000000..b108316d --- /dev/null +++ b/sqld/src/hrana/http/protobuf.rs @@ -0,0 +1,149 @@ +use super::proto::{StreamRequest, StreamResponse, StreamResult}; +use ::bytes::{Buf, BufMut}; +use prost::encoding::{message, skip_field, DecodeContext, WireType}; +use prost::DecodeError; +use std::mem::replace; + +impl prost::Message for StreamResult { + fn encode_raw(&self, buf: &mut B) + where + B: BufMut, + Self: Sized, + { + match self { + StreamResult::None => {} + StreamResult::Ok { response } => message::encode(1, response, buf), + StreamResult::Error { error } => message::encode(2, error, buf), + } + } + + fn encoded_len(&self) -> usize { + match self { + StreamResult::None => 0, + StreamResult::Ok { response } => message::encoded_len(1, response), + StreamResult::Error { error } => message::encoded_len(2, error), + } + } + + fn merge_field( + &mut self, + _tag: u32, + _wire_type: WireType, + _buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + panic!("StreamResult can only be encoded, not decoded") + } + + fn clear(&mut 
self) { + panic!("StreamResult can only be encoded, not decoded") + } +} + +impl prost::Message for StreamRequest { + fn encode_raw(&self, _buf: &mut B) + where + B: BufMut, + Self: Sized, + { + panic!("StreamRequest can only be decoded, not encoded") + } + + fn encoded_len(&self) -> usize { + panic!("StreamRequest can only be decoded, not encoded") + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + macro_rules! merge { + ($variant:ident) => {{ + let mut msg = match replace(self, StreamRequest::None) { + StreamRequest::$variant(msg) => msg, + _ => Default::default(), + }; + message::merge(wire_type, &mut msg, buf, ctx)?; + *self = StreamRequest::$variant(msg); + }}; + } + + match tag { + 1 => merge!(Close), + 2 => merge!(Execute), + 3 => merge!(Batch), + 4 => merge!(Sequence), + 5 => merge!(Describe), + 6 => merge!(StoreSql), + 7 => merge!(CloseSql), + 8 => merge!(GetAutocommit), + _ => skip_field(wire_type, tag, buf, ctx)?, + } + Ok(()) + } + + fn clear(&mut self) { + *self = StreamRequest::None; + } +} + +impl prost::Message for StreamResponse { + fn encode_raw(&self, buf: &mut B) + where + B: BufMut, + Self: Sized, + { + match self { + StreamResponse::Close(msg) => message::encode(1, msg, buf), + StreamResponse::Execute(msg) => message::encode(2, msg, buf), + StreamResponse::Batch(msg) => message::encode(3, msg, buf), + StreamResponse::Sequence(msg) => message::encode(4, msg, buf), + StreamResponse::Describe(msg) => message::encode(5, msg, buf), + StreamResponse::StoreSql(msg) => message::encode(6, msg, buf), + StreamResponse::CloseSql(msg) => message::encode(7, msg, buf), + StreamResponse::GetAutocommit(msg) => message::encode(8, msg, buf), + } + } + + fn encoded_len(&self) -> usize { + match self { + StreamResponse::Close(msg) => message::encoded_len(1, msg), + StreamResponse::Execute(msg) => message::encoded_len(2, msg), + StreamResponse::Batch(msg) => message::encoded_len(3, msg), + StreamResponse::Sequence(msg) => message::encoded_len(4, msg), + StreamResponse::Describe(msg) => message::encoded_len(5, msg), + StreamResponse::StoreSql(msg) => message::encoded_len(6, msg), + StreamResponse::CloseSql(msg) => message::encoded_len(7, msg), + StreamResponse::GetAutocommit(msg) => message::encoded_len(8, msg), + } + } + + fn merge_field( + &mut self, + _tag: u32, + _wire_type: WireType, + _buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + panic!("StreamResponse can only be encoded, not decoded") + } + + fn clear(&mut self) { + panic!("StreamResponse can only be encoded, not decoded") + } +} diff --git a/sqld/src/hrana/http/request.rs b/sqld/src/hrana/http/request.rs index afd7c1c3..4b4d0432 100644 --- a/sqld/src/hrana/http/request.rs +++ b/sqld/src/hrana/http/request.rs @@ -7,7 +7,7 @@ use crate::connection::Connection; /// An error from executing a [`proto::StreamRequest`] #[derive(thiserror::Error, Debug)] -pub enum StreamResponseError { +enum StreamResponseError { #[error("The server already stores {count} SQL texts, it cannot store more")] SqlTooMany { count: usize }, #[error(transparent)] @@ -20,8 +20,9 @@ pub async fn handle( stream_guard: &mut stream::Guard<'_, D>, auth: Authenticated, request: proto::StreamRequest, + version: Version, ) -> Result { - let result = match try_handle(stream_guard, auth, request).await { + let result = match try_handle(stream_guard, auth, request, version).await { 
Ok(response) => proto::StreamResult::Ok { response }, Err(err) => { let resp_err = err.downcast::()?; @@ -39,8 +40,21 @@ async fn try_handle( stream_guard: &mut stream::Guard<'_, D>, auth: Authenticated, request: proto::StreamRequest, + version: Version, ) -> Result { + macro_rules! ensure_version { + ($min_version:expr, $what:expr) => { + if version < $min_version { + bail!(ProtocolError::NotSupported { + what: $what, + min_version: $min_version, + }) + } + }; + } + Ok(match request { + proto::StreamRequest::None => bail!(ProtocolError::NoneStreamRequest), proto::StreamRequest::Close(_req) => { stream_guard.close_db(); proto::StreamResponse::Close(proto::CloseStreamResp {}) @@ -48,8 +62,8 @@ async fn try_handle( proto::StreamRequest::Execute(req) => { let db = stream_guard.get_db()?; let sqls = stream_guard.sqls(); - let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2) - .map_err(catch_stmt_error)?; + let query = + stmt::proto_stmt_to_query(&req.stmt, sqls, version).map_err(catch_stmt_error)?; let result = stmt::execute_stmt(db, auth, query) .await .map_err(catch_stmt_error)?; @@ -58,7 +72,7 @@ async fn try_handle( proto::StreamRequest::Batch(req) => { let db = stream_guard.get_db()?; let sqls = stream_guard.sqls(); - let pgm = batch::proto_batch_to_program(&req.batch, sqls, Version::Hrana2)?; + let pgm = batch::proto_batch_to_program(&req.batch, sqls, version)?; let result = batch::execute_batch(db, auth, pgm) .await .map_err(catch_batch_error)?; @@ -67,8 +81,7 @@ async fn try_handle( proto::StreamRequest::Sequence(req) => { let db = stream_guard.get_db()?; let sqls = stream_guard.sqls(); - let sql = - stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; + let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, version)?; let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; batch::execute_sequence(db, auth, pgm) .await @@ -79,8 +92,7 @@ async fn try_handle( proto::StreamRequest::Describe(req) => { let db = stream_guard.get_db()?; let sqls = stream_guard.sqls(); - let sql = - stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; + let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, version)?; let result = stmt::describe_stmt(db, auth, sql.into()) .await .map_err(catch_stmt_error)?; @@ -102,6 +114,12 @@ async fn try_handle( sqls.remove(&req.sql_id); proto::StreamResponse::CloseSql(proto::CloseSqlStreamResp {}) } + proto::StreamRequest::GetAutocommit(_req) => { + ensure_version!(Version::Hrana3, "The `get_autocommit` request"); + let db = stream_guard.get_db()?; + let is_autocommit = db.is_autocommit().await?; + proto::StreamResponse::GetAutocommit(proto::GetAutocommitStreamResp { is_autocommit }) + } }) } diff --git a/sqld/src/hrana/http/stream.rs b/sqld/src/hrana/http/stream.rs index 43cf831f..326f8332 100644 --- a/sqld/src/hrana/http/stream.rs +++ b/sqld/src/hrana/http/stream.rs @@ -1,7 +1,3 @@ -//! Stream allows connections to be grouped together using a baton value. -//! A baton value is sent by sqld to the client to be used in subsequent -//! requests. - use anyhow::{anyhow, Context, Result}; use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; use hmac::Mac as _; @@ -58,7 +54,7 @@ enum Handle { struct Stream { /// The database connection that corresponds to this stream. This is `None` after the `"close"` /// request was executed. - db: Option, + db: Option>, /// The cache of SQL texts stored on the server with `"store_sql"` requests. 
sqls: HashMap, /// Stream id of this stream. The id is generated randomly (it should be unguessable). @@ -108,8 +104,8 @@ impl ServerStreamState { /// otherwise we create a new stream. pub async fn acquire<'srv, D: Connection>( server: &'srv Server, + connection_maker: Arc>, baton: Option<&str>, - db_factory: Arc>, ) -> Result> { let stream = match baton { Some(baton) => { @@ -152,14 +148,14 @@ pub async fn acquire<'srv, D: Connection>( stream } None => { - let db = db_factory + let db = connection_maker .create() .await .context("Could not create a database connection")?; let mut state = server.stream_state.lock(); let stream = Box::new(Stream { - db: Some(db), + db: Some(Arc::new(db)), sqls: HashMap::new(), stream_id: gen_stream_id(&mut state), // initializing the sequence number randomly makes it much harder to exploit @@ -185,7 +181,12 @@ pub async fn acquire<'srv, D: Connection>( impl<'srv, D: Connection> Guard<'srv, D> { pub fn get_db(&self) -> Result<&D, ProtocolError> { let stream = self.stream.as_ref().unwrap(); - stream.db.as_ref().ok_or(ProtocolError::BatonStreamClosed) + stream.db.as_deref().ok_or(ProtocolError::BatonStreamClosed) + } + + pub fn get_db_owned(&self) -> Result, ProtocolError> { + let stream = self.stream.as_ref().unwrap(); + stream.db.clone().ok_or(ProtocolError::BatonStreamClosed) } /// Closes the database connection. The next call to [`Guard::release()`] will then remove the diff --git a/sqld/src/hrana/mod.rs b/sqld/src/hrana/mod.rs index daee7ec8..4c78d669 100644 --- a/sqld/src/hrana/mod.rs +++ b/sqld/src/hrana/mod.rs @@ -1,8 +1,10 @@ use std::fmt; pub mod batch; +mod cursor; pub mod http; pub mod proto; +mod protobuf; mod result_builder; pub mod stmt; pub mod ws; @@ -11,25 +13,27 @@ pub mod ws; pub enum Version { Hrana1, Hrana2, + Hrana3, } -impl fmt::Display for Version { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Version::Hrana1 => write!(f, "hrana1"), - Version::Hrana2 => write!(f, "hrana2"), - } - } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Encoding { + Json, + Protobuf, } /// An unrecoverable protocol error that should close the WebSocket or HTTP stream. A correct /// client should never trigger any of these errors. 
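One note on the version gating used by the batch and request handlers above: feature checks compare with `<`, relying on `Version` being ordered by declaration (`Hrana1 < Hrana2 < Hrana3`). The `ensure_version!` macro in `http/request.rs` expands to roughly this hypothetical helper:

```rust
// Sketch only; the code uses the macro form. `what` names the gated feature that is
// reported back in the NotSupported protocol error.
fn ensure_version(version: Version, min_version: Version, what: &'static str) -> anyhow::Result<()> {
    if version < min_version {
        anyhow::bail!(ProtocolError::NotSupported { what, min_version });
    }
    Ok(())
}
```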
#[derive(thiserror::Error, Debug)] pub enum ProtocolError { - #[error("Cannot deserialize client message: {source}")] - Deserialize { source: serde_json::Error }, - #[error("Received a binary WebSocket message, which is not supported")] + #[error("Cannot deserialize client message from JSON: {source}")] + JsonDeserialize { source: serde_json::Error }, + #[error("Could not decode client message from Protobuf: {source}")] + ProtobufDecode { source: prost::DecodeError }, + #[error("Received a binary WebSocket message, which is not supported in this encoding")] BinaryWebSocketMessage, + #[error("Received a text WebSocket message, which is not supported in this encoding")] + TextWebSocketMessage, #[error("Received a request before hello message")] RequestBeforeHello, @@ -50,6 +54,13 @@ pub enum ProtocolError { #[error("Invalid reference to step in a batch condition")] BatchCondBadStep, + #[error("Stream {stream_id} already has an open cursor")] + CursorAlreadyOpen { stream_id: i32 }, + #[error("Cursor {cursor_id} not found")] + CursorNotFound { cursor_id: i32 }, + #[error("Cursor {cursor_id} already exists")] + CursorExists { cursor_id: i32 }, + #[error("Received an invalid baton")] BatonInvalid, #[error("Received a baton that has already been used")] @@ -65,4 +76,25 @@ pub enum ProtocolError { #[error("{0}")] ResponseTooLarge(String), + + #[error("BatchCond type not recognized")] + NoneBatchCond, + #[error("Value type not recognized")] + NoneValue, + #[error("ClientMsg type not recognized")] + NoneClientMsg, + #[error("Request type not recognized")] + NoneRequest, + #[error("StreamRequest type not recognized")] + NoneStreamRequest, +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Version::Hrana1 => write!(f, "hrana1"), + Version::Hrana2 => write!(f, "hrana2"), + Version::Hrana3 => write!(f, "hrana3"), + } + } } diff --git a/sqld/src/hrana/proto.rs b/sqld/src/hrana/proto.rs index 8d544a07..760aae9f 100644 --- a/sqld/src/hrana/proto.rs +++ b/sqld/src/hrana/proto.rs @@ -4,97 +4,189 @@ use bytes::Bytes; use serde::{Deserialize, Serialize}; use std::sync::Arc; -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct Error { + #[prost(string, tag = "1")] pub message: String, + #[prost(string, tag = "2")] pub code: String, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct Stmt { #[serde(default)] + #[prost(string, optional, tag = "1")] pub sql: Option, #[serde(default)] + #[prost(int32, optional, tag = "2")] pub sql_id: Option, #[serde(default)] + #[prost(message, repeated, tag = "3")] pub args: Vec, #[serde(default)] + #[prost(message, repeated, tag = "4")] pub named_args: Vec, #[serde(default)] + #[prost(bool, optional, tag = "5")] pub want_rows: Option, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct NamedArg { + #[prost(string, tag = "1")] pub name: String, + #[prost(message, required, tag = "2")] pub value: Value, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct StmtResult { + #[prost(message, repeated, tag = "1")] pub cols: Vec, - pub rows: Vec>, + #[prost(message, repeated, tag = "2")] + pub rows: Vec, + #[prost(uint64, tag = "3")] pub affected_row_count: u64, #[serde(with = "option_i64_as_str")] + #[prost(sint64, optional, tag = "4")] pub last_insert_rowid: Option, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct Col { + #[prost(string, optional, tag = "1")] pub name: Option, + 
#[prost(string, optional, tag = "2")] pub decltype: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, prost::Message)] +#[serde(transparent)] +pub struct Row { + #[prost(message, repeated, tag = "1")] + pub values: Vec, +} + +#[derive(Deserialize, prost::Message)] pub struct Batch { + #[prost(message, repeated, tag = "1")] pub steps: Vec, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct BatchStep { - pub stmt: Stmt, #[serde(default)] + #[prost(message, optional, tag = "1")] pub condition: Option, + #[prost(message, required, tag = "2")] + pub stmt: Stmt, } -#[derive(Serialize, Debug)] +#[derive(Serialize, Debug, Default)] pub struct BatchResult { pub step_results: Vec>, pub step_errors: Vec>, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Default)] #[serde(tag = "type", rename_all = "snake_case")] pub enum BatchCond { - Ok { step: i32 }, - Error { step: i32 }, - Not { cond: Box }, - And { conds: Vec }, - Or { conds: Vec }, + #[serde(skip_deserializing)] + #[default] + None, + Ok { + step: u32, + }, + Error { + step: u32, + }, + Not { + cond: Box, + }, + And(BatchCondList), + Or(BatchCondList), + IsAutocommit {}, +} + +#[derive(Deserialize, prost::Message)] +pub struct BatchCondList { + #[prost(message, repeated, tag = "1")] + pub conds: Vec, +} + +#[derive(Serialize, Debug, Default)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum CursorEntry { + #[serde(skip_deserializing)] + #[default] + None, + StepBegin(StepBeginEntry), + StepEnd(StepEndEntry), + StepError(StepErrorEntry), + Row { + row: Row, + }, + Error { + error: Error, + }, +} + +#[derive(Serialize, prost::Message)] +pub struct StepBeginEntry { + #[prost(uint32, tag = "1")] + pub step: u32, + #[prost(message, repeated, tag = "2")] + pub cols: Vec, +} + +#[derive(Serialize, prost::Message)] +pub struct StepEndEntry { + #[prost(uint64, tag = "1")] + pub affected_row_count: u64, + #[prost(sint64, optional, tag = "2")] + pub last_insert_rowid: Option, +} + +#[derive(Serialize, prost::Message)] +pub struct StepErrorEntry { + #[prost(uint32, tag = "1")] + pub step: u32, + #[prost(message, required, tag = "2")] + pub error: Error, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct DescribeResult { + #[prost(message, repeated, tag = "1")] pub params: Vec, + #[prost(message, repeated, tag = "2")] pub cols: Vec, + #[prost(bool, tag = "3")] pub is_explain: bool, + #[prost(bool, tag = "4")] pub is_readonly: bool, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct DescribeParam { + #[prost(string, optional, tag = "1")] pub name: Option, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct DescribeCol { + #[prost(string, tag = "1")] pub name: String, + #[prost(string, optional, tag = "2")] pub decltype: Option, } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Default, Clone, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Value { + #[serde(skip_deserializing)] + #[default] + None, Null, Integer { #[serde(with = "i64_as_str")] diff --git a/sqld/src/hrana/protobuf.rs b/sqld/src/hrana/protobuf.rs new file mode 100644 index 00000000..8a57417d --- /dev/null +++ b/sqld/src/hrana/protobuf.rs @@ -0,0 +1,335 @@ +use super::proto::{BatchCond, BatchCondList, BatchResult, CursorEntry, Value}; +use ::bytes::{Buf, BufMut, Bytes}; +use prost::encoding::{ + bytes, double, message, sint64, skip_field, string, uint32, DecodeContext, WireType, +}; +use 
prost::DecodeError; +use std::mem::replace; +use std::sync::Arc; + +impl prost::Message for BatchResult { + fn encode_raw(&self, buf: &mut B) + where + B: BufMut, + Self: Sized, + { + vec_as_map::encode(1, &self.step_results, buf); + vec_as_map::encode(2, &self.step_errors, buf); + } + + fn encoded_len(&self) -> usize { + vec_as_map::encoded_len(1, &self.step_results) + + vec_as_map::encoded_len(2, &self.step_errors) + } + + fn merge_field( + &mut self, + _tag: u32, + _wire_type: WireType, + _buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + panic!("BatchResult can only be encoded, not decoded") + } + + fn clear(&mut self) { + self.step_results.clear(); + self.step_errors.clear(); + } +} + +impl prost::Message for BatchCond { + fn encode_raw(&self, _buf: &mut B) + where + B: BufMut, + Self: Sized, + { + panic!("BatchCond can only be decoded, not encoded") + } + + fn encoded_len(&self) -> usize { + panic!("BatchCond can only be decoded, not encoded") + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + match tag { + 1 => { + let mut step = 0; + uint32::merge(wire_type, &mut step, buf, ctx)?; + *self = BatchCond::Ok { step } + } + 2 => { + let mut step = 0; + uint32::merge(wire_type, &mut step, buf, ctx)?; + *self = BatchCond::Error { step } + } + 3 => { + let mut cond = match replace(self, BatchCond::None) { + BatchCond::Not { cond } => cond, + _ => Box::new(BatchCond::None), + }; + message::merge(wire_type, &mut *cond, buf, ctx)?; + *self = BatchCond::Not { cond }; + } + 4 => { + let mut cond_list = match replace(self, BatchCond::None) { + BatchCond::And(cond_list) => cond_list, + _ => BatchCondList::default(), + }; + message::merge(wire_type, &mut cond_list, buf, ctx)?; + *self = BatchCond::And(cond_list); + } + 5 => { + let mut cond_list = match replace(self, BatchCond::None) { + BatchCond::Or(cond_list) => cond_list, + _ => BatchCondList::default(), + }; + message::merge(wire_type, &mut cond_list, buf, ctx)?; + *self = BatchCond::Or(cond_list); + } + 6 => { + skip_field(wire_type, tag, buf, ctx)?; + *self = BatchCond::IsAutocommit {}; + } + _ => { + skip_field(wire_type, tag, buf, ctx)?; + } + } + Ok(()) + } + + fn clear(&mut self) { + *self = BatchCond::None; + } +} + +impl prost::Message for CursorEntry { + fn encode_raw(&self, buf: &mut B) + where + B: BufMut, + Self: Sized, + { + match self { + CursorEntry::None => {} + CursorEntry::StepBegin(entry) => message::encode(1, entry, buf), + CursorEntry::StepEnd(entry) => message::encode(2, entry, buf), + CursorEntry::StepError(entry) => message::encode(3, entry, buf), + CursorEntry::Row { row } => message::encode(4, row, buf), + CursorEntry::Error { error } => message::encode(5, error, buf), + } + } + + fn encoded_len(&self) -> usize { + match self { + CursorEntry::None => 0, + CursorEntry::StepBegin(entry) => message::encoded_len(1, entry), + CursorEntry::StepEnd(entry) => message::encoded_len(2, entry), + CursorEntry::StepError(entry) => message::encoded_len(3, entry), + CursorEntry::Row { row } => message::encoded_len(4, row), + CursorEntry::Error { error } => message::encoded_len(5, error), + } + } + + fn merge_field( + &mut self, + _tag: u32, + _wire_type: WireType, + _buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + panic!("CursorEntry can only be encoded, not decoded") + } + + fn clear(&mut self) 
{ + *self = CursorEntry::None; + } +} + +impl prost::Message for Value { + fn encode_raw(&self, buf: &mut B) + where + B: BufMut, + Self: Sized, + { + match self { + Value::None => {} + Value::Null => empty_message::encode(1, buf), + Value::Integer { value } => sint64::encode(2, value, buf), + Value::Float { value } => double::encode(3, value, buf), + Value::Text { value } => arc_str::encode(4, value, buf), + Value::Blob { value } => bytes::encode(5, value, buf), + } + } + + fn encoded_len(&self) -> usize { + match self { + Value::None => 0, + Value::Null => empty_message::encoded_len(1), + Value::Integer { value } => sint64::encoded_len(2, value), + Value::Float { value } => double::encoded_len(3, value), + Value::Text { value } => arc_str::encoded_len(4, value), + Value::Blob { value } => bytes::encoded_len(5, value), + } + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + match tag { + 1 => { + skip_field(wire_type, tag, buf, ctx)?; + *self = Value::Null + } + 2 => { + let mut value = 0; + sint64::merge(wire_type, &mut value, buf, ctx)?; + *self = Value::Integer { value }; + } + 3 => { + let mut value = 0.; + double::merge(wire_type, &mut value, buf, ctx)?; + *self = Value::Float { value }; + } + 4 => { + let mut value = String::new(); + string::merge(wire_type, &mut value, buf, ctx)?; + // TODO: this makes an unnecessary copy + let value: Arc = value.into(); + *self = Value::Text { value }; + } + 5 => { + let mut value = Bytes::new(); + bytes::merge(wire_type, &mut value, buf, ctx)?; + *self = Value::Blob { value }; + } + _ => { + skip_field(wire_type, tag, buf, ctx)?; + } + } + Ok(()) + } + + fn clear(&mut self) { + *self = Value::None; + } +} + +mod vec_as_map { + use bytes::BufMut; + use prost::encoding::{ + encode_key, encode_varint, encoded_len_varint, key_len, message, uint32, WireType, + }; + + pub fn encode(tag: u32, values: &[Option], buf: &mut B) + where + B: BufMut, + M: prost::Message, + { + for (index, msg) in values.iter().enumerate() { + if let Some(msg) = msg { + encode_map_entry(tag, index as u32, msg, buf); + } + } + } + + pub fn encoded_len(tag: u32, values: &[Option]) -> usize + where + M: prost::Message, + { + values + .iter() + .enumerate() + .map(|(index, msg)| match msg { + Some(msg) => encoded_map_entry_len(tag, index as u32, msg), + None => 0, + }) + .sum() + } + + fn encode_map_entry(tag: u32, key: u32, value: &M, buf: &mut B) + where + B: BufMut, + M: prost::Message, + { + encode_key(tag, WireType::LengthDelimited, buf); + + let entry_key_len = uint32::encoded_len(1, &key); + let entry_value_len = message::encoded_len(2, value); + + encode_varint((entry_key_len + entry_value_len) as u64, buf); + uint32::encode(1, &key, buf); + message::encode(2, value, buf); + } + + fn encoded_map_entry_len(tag: u32, key: u32, value: &M) -> usize + where + M: prost::Message, + { + let entry_key_len = uint32::encoded_len(1, &key); + let entry_value_len = message::encoded_len(2, value); + let entry_len = entry_key_len + entry_value_len; + key_len(tag) + encoded_len_varint(entry_len as u64) + entry_len + } +} + +mod empty_message { + use bytes::BufMut; + use prost::encoding::{encode_key, encode_varint, encoded_len_varint, key_len, WireType}; + + pub fn encode(tag: u32, buf: &mut B) + where + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(0, buf); + } + + pub fn encoded_len(tag: u32) -> usize { + key_len(tag) + 
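To make the hand-rolled `prost::Message` impl for `Value` above concrete: it gives `Value` a oneof-style wire format (field 1 = null, 2 = integer, 3 = float, 4 = text, 5 = blob), so the ordinary prost entry points round-trip it. A small sketch, assuming `Value` from `hrana::proto` is in scope:

```rust
use prost::Message as _;

// Encode a Value with the manual impl above and decode it back; decode() relies on
// Value's default variant, which is why the enum gained `#[default] None`.
fn roundtrip_value(value: &Value) -> Value {
    let bytes = value.encode_to_vec();
    Value::decode(bytes.as_slice()).expect("wire format produced by encode_raw")
}
```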
encoded_len_varint(0) + } +} + +mod arc_str { + use bytes::BufMut; + use prost::encoding::{encode_key, encode_varint, encoded_len_varint, key_len, WireType}; + use std::sync::Arc; + + pub fn encode(tag: u32, value: &Arc, buf: &mut B) + where + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(value.len() as u64, buf); + buf.put_slice(value.as_bytes()); + } + + pub fn encoded_len(tag: u32, value: &Arc) -> usize { + key_len(tag) + encoded_len_varint(value.len() as u64) + value.len() + } +} diff --git a/sqld/src/hrana/result_builder.rs b/sqld/src/hrana/result_builder.rs index 34ee928e..c26b52f1 100644 --- a/sqld/src/hrana/result_builder.rs +++ b/sqld/src/hrana/result_builder.rs @@ -16,7 +16,7 @@ use super::proto; pub struct SingleStatementBuilder { has_step: bool, cols: Vec, - rows: Vec>, + rows: Vec, err: Option, affected_row_count: u64, last_insert_rowid: Option, @@ -25,11 +25,19 @@ pub struct SingleStatementBuilder { max_total_response_size: u64, } -struct SizeFormatter(u64); +struct SizeFormatter { + size: u64, +} + +impl SizeFormatter { + fn new() -> Self { + Self { size: 0 } + } +} impl io::Write for SizeFormatter { fn write(&mut self, buf: &[u8]) -> io::Result { - self.0 += buf.len() as u64; + self.size += buf.len() as u64; Ok(buf.len()) } @@ -40,17 +48,17 @@ impl io::Write for SizeFormatter { impl fmt::Write for SizeFormatter { fn write_str(&mut self, s: &str) -> fmt::Result { - self.0 += s.len() as u64; + self.size += s.len() as u64; Ok(()) } } -fn value_json_size(v: &ValueRef) -> u64 { - let mut f = SizeFormatter(0); +pub fn value_json_size(v: &ValueRef) -> u64 { + let mut f = SizeFormatter::new(); match v { ValueRef::Null => write!(&mut f, r#"{{"type":"null"}}"#).unwrap(), - ValueRef::Integer(i) => write!(&mut f, r#"{{"type":"integer", "value": "{i}"}}"#).unwrap(), - ValueRef::Real(x) => write!(&mut f, r#"{{"type":"integer","value": {x}"}}"#).unwrap(), + ValueRef::Integer(i) => write!(&mut f, r#"{{"type":"integer","value":"{i}"}}"#).unwrap(), + ValueRef::Real(x) => write!(&mut f, r#"{{"type":"float","value":{x}"}}"#).unwrap(), ValueRef::Text(s) => { // error will be caught later. if let Ok(s) = std::str::from_utf8(s) { @@ -59,8 +67,23 @@ fn value_json_size(v: &ValueRef) -> u64 { } ValueRef::Blob(b) => return b.len() as u64, } + f.size +} - f.0 +pub fn value_to_proto(v: ValueRef) -> Result { + Ok(match v { + ValueRef::Null => proto::Value::Null, + ValueRef::Integer(value) => proto::Value::Integer { value }, + ValueRef::Real(value) => proto::Value::Float { value }, + ValueRef::Text(s) => proto::Value::Text { + value: String::from_utf8(s.to_vec()) + .map_err(QueryResultBuilderError::from_any)? 
+ .into(), + }, + ValueRef::Blob(d) => proto::Value::Blob { + value: Bytes::copy_from_slice(d), + }, + }) } impl Drop for SingleStatementBuilder { @@ -123,10 +146,10 @@ impl QueryResultBuilder for SingleStatementBuilder { fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { assert!(self.err.is_none()); - let mut f = SizeFormatter(0); + let mut f = SizeFormatter::new(); write!(&mut f, "{error}").unwrap(); TOTAL_RESPONSE_SIZE.fetch_sub(self.current_size as usize, Ordering::Relaxed); - self.current_size = f.0; + self.current_size = f.size; TOTAL_RESPONSE_SIZE.fetch_add(self.current_size as usize, Ordering::Relaxed); self.err = Some(error); @@ -163,7 +186,9 @@ impl QueryResultBuilder for SingleStatementBuilder { fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { assert!(self.err.is_none()); - self.rows.push(Vec::with_capacity(self.cols.len())); + self.rows.push(proto::Row { + values: Vec::with_capacity(self.cols.len()), + }); Ok(()) } @@ -177,24 +202,12 @@ impl QueryResultBuilder for SingleStatementBuilder { } self.inc_current_size(estimate_size)?; - - let val = match v { - ValueRef::Null => proto::Value::Null, - ValueRef::Integer(value) => proto::Value::Integer { value }, - ValueRef::Real(value) => proto::Value::Float { value }, - ValueRef::Text(s) => proto::Value::Text { - value: String::from_utf8(s.to_vec()) - .map_err(QueryResultBuilderError::from_any)? - .into(), - }, - ValueRef::Blob(d) => proto::Value::Blob { - value: Bytes::copy_from_slice(d), - }, - }; + let val = value_to_proto(v)?; self.rows .last_mut() .expect("row must be initialized") + .values .push(val); Ok(()) @@ -227,8 +240,8 @@ impl QueryResultBuilder for SingleStatementBuilder { } } -fn estimate_cols_json_size(c: &Column) -> u64 { - let mut f = SizeFormatter(0); +pub fn estimate_cols_json_size(c: &Column) -> u64 { + let mut f = SizeFormatter::new(); write!( &mut f, r#"{{"name":"{}","decltype":"{}"}}"#, @@ -236,7 +249,7 @@ fn estimate_cols_json_size(c: &Column) -> u64 { c.decl_ty.unwrap_or("null") ) .unwrap(); - f.0 + f.size } #[derive(Debug, Default)] diff --git a/sqld/src/hrana/stmt.rs b/sqld/src/hrana/stmt.rs index 49e50f1e..13892150 100644 --- a/sqld/src/hrana/stmt.rs +++ b/sqld/src/hrana/stmt.rs @@ -98,14 +98,20 @@ pub fn proto_stmt_to_query( } let params = if proto_stmt.named_args.is_empty() { - let values = proto_stmt.args.iter().map(proto_value_to_value).collect(); + let values = proto_stmt + .args + .iter() + .map(proto_value_to_value) + .collect::, _>>()?; Params::Positional(values) } else if proto_stmt.args.is_empty() { let values = proto_stmt .named_args .iter() - .map(|arg| (arg.name.clone(), proto_value_to_value(&arg.value))) - .collect(); + .map(|arg| { + proto_value_to_value(&arg.value).map(|arg_value| (arg.name.clone(), arg_value)) + }) + .collect::, _>>()?; Params::Named(values) } else { bail!(StmtError::ArgsBothPositionalAndNamed) @@ -143,14 +149,15 @@ pub fn proto_sql_to_sql<'s>( } } -fn proto_value_to_value(proto_value: &proto::Value) -> Value { - match proto_value { +fn proto_value_to_value(proto_value: &proto::Value) -> Result { + Ok(match proto_value { + proto::Value::None => return Err(ProtocolError::NoneValue), proto::Value::Null => Value::Null, proto::Value::Integer { value } => Value::Integer(*value), proto::Value::Float { value } => Value::Real(*value), proto::Value::Text { value } => Value::Text(value.as_ref().into()), proto::Value::Blob { value } => Value::Blob(value.as_ref().into()), - } + }) } fn proto_value_from_value(value: Value) -> 
proto::Value { @@ -232,7 +239,7 @@ pub fn stmt_error_from_sqld_error(sqld_error: SqldError) -> Result hrana::proto::Error { - hrana::proto::Error { + proto::Error { message: error.to_string(), code: error.code().into(), } @@ -287,12 +294,6 @@ fn sqlite_error_code(code: rusqlite::ffi::ErrorCode) -> &'static str { } } -impl From<&proto::Value> for Value { - fn from(proto_value: &proto::Value) -> Value { - proto_value_to_value(proto_value) - } -} - impl From for proto::Value { fn from(value: Value) -> proto::Value { proto_value_from_value(value) diff --git a/sqld/src/hrana/ws/conn.rs b/sqld/src/hrana/ws/conn.rs index b804b335..db673212 100644 --- a/sqld/src/hrana/ws/conn.rs +++ b/sqld/src/hrana/ws/conn.rs @@ -1,6 +1,3 @@ -//! This file contains functions to deal with the connection of the Hrana protocol -//! over web sockets - use std::borrow::Cow; use std::future::Future; use std::pin::Pin; @@ -19,7 +16,7 @@ use crate::connection::MakeConnection; use crate::database::Database; use crate::namespace::MakeNamespace; -use super::super::{ProtocolError, Version}; +use super::super::{Encoding, ProtocolError, Version}; use super::handshake::WebSocket; use super::{handshake, proto, session, Server, Upgrade}; @@ -31,6 +28,8 @@ struct Conn { ws_closed: bool, /// The version of the protocol that has been negotiated in the WebSocket handshake. version: Version, + /// The encoding of messages that has been negotiated in the WebSocket handshake. + encoding: Encoding, /// After a successful authentication, this contains the session-level state of the connection. session: Option::Connection>>, /// Join set for all tasks that were spawned to handle the connection. @@ -54,14 +53,19 @@ pub(super) async fn handle_tcp( socket: tokio::net::TcpStream, conn_id: u64, ) -> Result<()> { - let (ws, version, ns) = handshake::handshake_tcp( + let handshake::Output { + ws, + version, + encoding, + namespace, + } = handshake::handshake_tcp( socket, server.disable_default_namespace, server.disable_namespaces, ) .await .context("Could not perform the WebSocket handshake on TCP connection")?; - handle_ws(server, ws, version, conn_id, ns).await + handle_ws(server, ws, version, encoding, conn_id, namespace).await } pub(super) async fn handle_upgrade( @@ -69,20 +73,26 @@ pub(super) async fn handle_upgrade( upgrade: Upgrade, conn_id: u64, ) -> Result<()> { - let (ws, version, ns) = handshake::handshake_upgrade( + let handshake::Output { + ws, + version, + encoding, + namespace, + } = handshake::handshake_upgrade( upgrade, server.disable_default_namespace, server.disable_namespaces, ) .await .context("Could not perform the WebSocket handshake on HTTP connection")?; - handle_ws(server, ws, version, conn_id, ns).await + handle_ws(server, ws, version, encoding, conn_id, namespace).await } async fn handle_ws( server: Arc>, ws: WebSocket, version: Version, + encoding: Encoding, conn_id: u64, namespace: Bytes, ) -> Result<()> { @@ -96,6 +106,7 @@ async fn handle_ws( ws, ws_closed: false, version, + encoding, session: None, join_set: tokio::task::JoinSet::new(), responses: FuturesUnordered::new(), @@ -160,22 +171,23 @@ async fn handle_msg( ) -> Result { match client_msg { tungstenite::Message::Text(client_msg) => { - // client messages are received as text WebSocket messages that encode the `ClientMsg` - // in JSON - let client_msg: proto::ClientMsg = match serde_json::from_str(&client_msg) { - Ok(client_msg) => client_msg, - Err(err) => bail!(ProtocolError::Deserialize { source: err }), - }; + if conn.encoding != Encoding::Json { + 
bail!(ProtocolError::TextWebSocketMessage) + } - match client_msg { - proto::ClientMsg::Hello { jwt } => handle_hello_msg(conn, jwt).await, - proto::ClientMsg::Request { - request_id, - request, - } => handle_request_msg(conn, request_id, request).await, + let client_msg: proto::ClientMsg = serde_json::from_str(&client_msg) + .map_err(|err| ProtocolError::JsonDeserialize { source: err })?; + handle_client_msg(conn, client_msg).await + } + tungstenite::Message::Binary(client_msg) => { + if conn.encoding != Encoding::Protobuf { + bail!(ProtocolError::BinaryWebSocketMessage) } + + let client_msg = ::decode(client_msg.as_slice()) + .map_err(|err| ProtocolError::ProtobufDecode { source: err })?; + handle_client_msg(conn, client_msg).await } - tungstenite::Message::Binary(_) => bail!(ProtocolError::BinaryWebSocketMessage), tungstenite::Message::Ping(ping_data) => { let pong_msg = tungstenite::Message::Pong(ping_data); conn.ws @@ -190,6 +202,21 @@ async fn handle_msg( } } +async fn handle_client_msg( + conn: &mut Conn, + client_msg: proto::ClientMsg, +) -> Result { + tracing::trace!("Received client msg: {:?}", client_msg); + match client_msg { + proto::ClientMsg::None => bail!(ProtocolError::NoneClientMsg), + proto::ClientMsg::Hello(msg) => handle_hello_msg(conn, msg.jwt).await, + proto::ClientMsg::Request(msg) => match msg.request { + Some(request) => handle_request_msg(conn, msg.request_id, request).await, + None => bail!(ProtocolError::NoneRequest), + }, + } +} + async fn handle_hello_msg( conn: &mut Conn, jwt: Option, @@ -202,12 +229,16 @@ async fn handle_hello_msg( match hello_res { Ok(_) => { - send_msg(conn, &proto::ServerMsg::HelloOk {}).await?; + send_msg(conn, &proto::ServerMsg::HelloOk(proto::HelloOkMsg {})).await?; Ok(true) } Err(err) => match downcast_error(err) { Ok(error) => { - send_msg(conn, &proto::ServerMsg::HelloError { error }).await?; + send_msg( + conn, + &proto::ServerMsg::HelloError(proto::HelloErrorMsg { error }), + ) + .await?; Ok(false) } Err(err) => Err(err), @@ -225,6 +256,7 @@ async fn handle_request_msg( }; let response_rx = session::handle_request( + &conn.server, session, &mut conn.join_set, request, @@ -250,15 +282,19 @@ impl Future for ResponseFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { match ready!(Pin::new(&mut self.response_rx).poll(cx)) { - Ok(Ok(response)) => Poll::Ready(Ok(proto::ServerMsg::ResponseOk { - request_id: self.request_id, - response, - })), - Ok(Err(err)) => match downcast_error(err) { - Ok(error) => Poll::Ready(Ok(proto::ServerMsg::ResponseError { + Ok(Ok(response)) => { + Poll::Ready(Ok(proto::ServerMsg::ResponseOk(proto::ResponseOkMsg { request_id: self.request_id, - error, - })), + response: Some(response), + }))) + } + Ok(Err(err)) => match downcast_error(err) { + Ok(error) => Poll::Ready(Ok(proto::ServerMsg::ResponseError( + proto::ResponseErrorMsg { + request_id: self.request_id, + error, + }, + ))), Err(err) => Poll::Ready(Err(err)), }, Err(_recv_err) => { @@ -284,12 +320,21 @@ fn downcast_error(err: anyhow::Error) -> Result { } async fn send_msg(conn: &mut Conn, msg: &proto::ServerMsg) -> Result<()> { - let msg = serde_json::to_string(&msg).context("Could not serialize response message")?; - let msg = tungstenite::Message::Text(msg); + let msg = match conn.encoding { + Encoding::Json => { + let msg = + serde_json::to_string(&msg).context("Could not serialize response message")?; + tungstenite::Message::Text(msg) + } + Encoding::Protobuf => { + let msg = ::encode_to_vec(msg); + 
tungstenite::Message::Binary(msg) + } + }; conn.ws .send(msg) .await - .context("Could not send response to the WebSocket") + .context("Could not send message to the WebSocket") } async fn close(conn: &mut Conn, code: CloseCode, reason: String) { @@ -323,8 +368,10 @@ async fn close(conn: &mut Conn, code: CloseCode, reason: St fn protocol_error_to_close_code(err: &ProtocolError) -> CloseCode { match err { - ProtocolError::Deserialize { .. } => CloseCode::Invalid, + ProtocolError::JsonDeserialize { .. } => CloseCode::Invalid, + ProtocolError::ProtobufDecode { .. } => CloseCode::Invalid, ProtocolError::BinaryWebSocketMessage => CloseCode::Unsupported, + ProtocolError::TextWebSocketMessage => CloseCode::Unsupported, _ => CloseCode::Policy, } } diff --git a/sqld/src/hrana/ws/handshake.rs b/sqld/src/hrana/ws/handshake.rs index feead5c4..9e25a713 100644 --- a/sqld/src/hrana/ws/handshake.rs +++ b/sqld/src/hrana/ws/handshake.rs @@ -1,5 +1,3 @@ -//! This file handles web socket handshakes. - use anyhow::{anyhow, bail, Context as _, Result}; use bytes::Bytes; use futures::{SinkExt as _, StreamExt as _}; @@ -8,7 +6,7 @@ use tungstenite::http; use crate::http::db_factory::namespace_from_headers; -use super::super::Version; +use super::super::{Encoding, Version}; use super::Upgrade; #[derive(Debug)] @@ -17,12 +15,32 @@ pub enum WebSocket { Upgraded(tokio_tungstenite::WebSocketStream), } +#[derive(Debug, Copy, Clone)] +enum Subproto { + Hrana1, + Hrana2, + Hrana3, + Hrana3Protobuf, +} + +#[derive(Debug)] +pub struct Output { + pub ws: WebSocket, + pub version: Version, + pub encoding: Encoding, + pub namespace: Bytes, +} + pub async fn handshake_tcp( socket: tokio::net::TcpStream, disable_default_ns: bool, disable_namespaces: bool, -) -> Result<(WebSocket, Version, Bytes)> { - let mut version = None; +) -> Result { + socket + .set_nodelay(true) + .context("Could not disable Nagle's algorithm")?; + + let mut subproto = None; let mut namespace = None; let callback = |req: &http::Request<()>, resp: http::Response<()>| { let (mut resp_parts, _) = resp.into_parts(); @@ -36,9 +54,9 @@ pub async fn handshake_tcp( Err(e) => return Err(http::Response::from_parts(resp_parts, Some(e.to_string()))), }; - match negotiate_version(req.headers(), &mut resp_parts.headers) { - Ok(version_) => { - version = Some(version_); + match negotiate_subproto(req.headers(), &mut resp_parts.headers) { + Ok(subproto_) => { + subproto = Some(subproto_); Ok(http::Response::from_parts(resp_parts, ())) } Err(resp_body) => Err(http::Response::from_parts(resp_parts, Some(resp_body))), @@ -48,21 +66,29 @@ pub async fn handshake_tcp( let ws_config = Some(get_ws_config()); let stream = tokio_tungstenite::accept_hdr_async_with_config(socket, callback, ws_config).await?; - Ok((WebSocket::Tcp(stream), version.unwrap(), namespace.unwrap())) + + let (version, encoding) = subproto.unwrap().version_encoding(); + Ok(Output { + ws: WebSocket::Tcp(stream), + version, + encoding, + namespace: namespace.unwrap(), + }) } pub async fn handshake_upgrade( upgrade: Upgrade, disable_default_ns: bool, disable_namespaces: bool, -) -> Result<(WebSocket, Version, Bytes)> { +) -> Result { let mut req = upgrade.request; - let ns = namespace_from_headers(req.headers(), disable_default_ns, disable_namespaces)?; + let namespace = namespace_from_headers(req.headers(), disable_default_ns, disable_namespaces)?; let ws_config = Some(get_ws_config()); - let (mut resp, stream_fut_version_res) = match hyper_tungstenite::upgrade(&mut req, ws_config) { - Ok((mut resp, 
stream_fut)) => match negotiate_version(req.headers(), resp.headers_mut()) { - Ok(version) => (resp, Ok((stream_fut, version, ns))), + let (mut resp, stream_fut_subproto_res) = match hyper_tungstenite::upgrade(&mut req, ws_config) + { + Ok((mut resp, stream_fut)) => match negotiate_subproto(req.headers(), resp.headers_mut()) { + Ok(subproto) => (resp, Ok((stream_fut, subproto))), Err(msg) => { *resp.status_mut() = http::StatusCode::BAD_REQUEST; *resp.body_mut() = hyper::Body::from(msg.clone()); @@ -92,51 +118,78 @@ pub async fn handshake_upgrade( bail!("Could not send the HTTP upgrade response") } - let (stream_fut, version, ns) = stream_fut_version_res?; + let (stream_fut, subproto) = stream_fut_subproto_res?; let stream = stream_fut .await .context("Could not upgrade HTTP request to a WebSocket")?; - Ok((WebSocket::Upgraded(stream), version, ns)) + + let (version, encoding) = subproto.version_encoding(); + Ok(Output { + ws: WebSocket::Upgraded(stream), + version, + encoding, + namespace, + }) } -fn negotiate_version( +fn negotiate_subproto( req_headers: &http::HeaderMap, resp_headers: &mut http::HeaderMap, -) -> Result { +) -> Result { if let Some(protocol_hdr) = req_headers.get("sec-websocket-protocol") { - let supported_by_client = protocol_hdr + let client_subprotos = protocol_hdr .to_str() .unwrap_or("") .split(',') - .map(|p| p.trim()); - - let mut hrana1_supported = false; - let mut hrana2_supported = false; - for protocol_str in supported_by_client { - hrana1_supported |= protocol_str.eq_ignore_ascii_case("hrana1"); - hrana2_supported |= protocol_str.eq_ignore_ascii_case("hrana2"); - } - - let version = if hrana2_supported { - Version::Hrana2 - } else if hrana1_supported { - Version::Hrana1 - } else { - return Err("Only 'hrana1' and 'hrana2' subprotocols are supported".into()); + .map(|p| p.trim()) + .collect::>(); + + let server_subprotos = [ + Subproto::Hrana3Protobuf, + Subproto::Hrana3, + Subproto::Hrana2, + Subproto::Hrana1, + ]; + + let Some(subproto) = select_subproto(&client_subprotos, &server_subprotos) else { + let supported = server_subprotos + .iter() + .copied() + .map(|s| s.as_str()) + .collect::>() + .join(" "); + return Err(format!("Only these WebSocket subprotocols are supported: {}", supported)) }; + tracing::debug!( + "Client subprotocols {:?}, selected {:?}", + client_subprotos, + subproto + ); + resp_headers.append( "sec-websocket-protocol", - http::HeaderValue::from_str(&version.to_string()).unwrap(), + http::HeaderValue::from_str(subproto.as_str()).unwrap(), ); - Ok(version) + Ok(subproto) } else { // Sec-WebSocket-Protocol header not present, assume that the client wants hrana1 // According to RFC 6455, we must not set the Sec-WebSocket-Protocol response header - Ok(Version::Hrana1) + Ok(Subproto::Hrana1) } } +fn select_subproto(client_subprotos: &[&str], server_subprotos: &[Subproto]) -> Option { + for &server_subproto in server_subprotos.iter() { + for client_subproto in client_subprotos.iter() { + if client_subproto.eq_ignore_ascii_case(server_subproto.as_str()) { + return Some(server_subproto); + } + } + } + None +} + fn get_ws_config() -> tungstenite::protocol::WebSocketConfig { tungstenite::protocol::WebSocketConfig { max_send_queue: Some(1 << 20), @@ -159,3 +212,23 @@ impl WebSocket { } } } + +impl Subproto { + fn as_str(self) -> &'static str { + match self { + Self::Hrana1 => "hrana1", + Self::Hrana2 => "hrana2", + Self::Hrana3 => "hrana3", + Self::Hrana3Protobuf => "hrana3-protobuf", + } + } + + fn version_encoding(self) -> (Version, Encoding) { + 
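A sketch of the behaviour that `negotiate_subproto`/`select_subproto` above implement: the server's preference list is ordered from most to least capable, so the best subprotocol the client offers wins. Written as a hypothetical unit test against the private items in this module:

```rust
#[test]
fn picks_most_capable_subprotocol_offered_by_client() {
    // Client offers hrana2 and hrana3-protobuf; the server preference list is scanned
    // in order, so hrana3-protobuf is selected and maps to (Hrana3, Protobuf).
    let client = ["hrana2", "hrana3-protobuf"];
    let server = [
        Subproto::Hrana3Protobuf,
        Subproto::Hrana3,
        Subproto::Hrana2,
        Subproto::Hrana1,
    ];
    let picked = select_subproto(&client, &server).unwrap();
    assert!(matches!(picked, Subproto::Hrana3Protobuf));
    assert!(matches!(
        picked.version_encoding(),
        (Version::Hrana3, Encoding::Protobuf)
    ));
}
```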
match self { + Self::Hrana1 => (Version::Hrana1, Encoding::Json), + Self::Hrana2 => (Version::Hrana2, Encoding::Json), + Self::Hrana3 => (Version::Hrana3, Encoding::Json), + Self::Hrana3Protobuf => (Version::Hrana3, Encoding::Protobuf), + } + } +} diff --git a/sqld/src/hrana/ws/mod.rs b/sqld/src/hrana/ws/mod.rs index 007b4a14..6c625499 100644 --- a/sqld/src/hrana/ws/mod.rs +++ b/sqld/src/hrana/ws/mod.rs @@ -12,12 +12,14 @@ pub mod proto; mod conn; mod handshake; +mod protobuf; mod session; struct Server { namespaces: Arc>, auth: Arc, idle_kicker: Option, + max_response_size: u64, next_conn_id: AtomicU64, disable_default_namespace: bool, disable_namespaces: bool, @@ -35,9 +37,11 @@ pub struct Upgrade { pub response_tx: oneshot::Sender>, } +#[allow(clippy::too_many_arguments)] pub async fn serve( auth: Arc, idle_kicker: Option, + max_response_size: u64, mut accept_rx: mpsc::Receiver, mut upgrade_rx: mpsc::Receiver, namespaces: Arc>, @@ -47,6 +51,7 @@ pub async fn serve( let server = Arc::new(Server { auth, idle_kicker, + max_response_size, next_conn_id: AtomicU64::new(0), namespaces, disable_default_namespace, diff --git a/sqld/src/hrana/ws/proto.rs b/sqld/src/hrana/ws/proto.rs index 6bb88367..27b12b63 100644 --- a/sqld/src/hrana/ws/proto.rs +++ b/sqld/src/hrana/ws/proto.rs @@ -3,125 +3,267 @@ pub use super::super::proto::*; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Default)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ClientMsg { - Hello { jwt: Option }, - Request { request_id: i32, request: Request }, + #[serde(skip_deserializing)] + #[default] + None, + Hello(HelloMsg), + Request(RequestMsg), +} + +#[derive(Deserialize, prost::Message)] +pub struct HelloMsg { + #[prost(string, optional, tag = "1")] + pub jwt: Option, +} + +#[derive(Deserialize, prost::Message)] +pub struct RequestMsg { + #[prost(int32, tag = "1")] + pub request_id: i32, + #[prost(oneof = "Request", tags = "2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13")] + pub request: Option, } #[derive(Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ServerMsg { - HelloOk {}, - HelloError { error: Error }, - ResponseOk { request_id: i32, response: Response }, - ResponseError { request_id: i32, error: Error }, + HelloOk(HelloOkMsg), + HelloError(HelloErrorMsg), + ResponseOk(ResponseOkMsg), + ResponseError(ResponseErrorMsg), +} + +#[derive(Serialize, prost::Message)] +pub struct HelloOkMsg {} + +#[derive(Serialize, prost::Message)] +pub struct HelloErrorMsg { + #[prost(message, required, tag = "1")] + pub error: Error, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, prost::Message)] +pub struct ResponseOkMsg { + #[prost(int32, tag = "1")] + pub request_id: i32, + #[prost(oneof = "Response", tags = "2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13")] + pub response: Option, +} + +#[derive(Serialize, prost::Message)] +pub struct ResponseErrorMsg { + #[prost(int32, tag = "1")] + pub request_id: i32, + #[prost(message, required, tag = "2")] + pub error: Error, +} + +#[derive(Deserialize, prost::Oneof)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Request { + #[prost(message, tag = "2")] OpenStream(OpenStreamReq), + #[prost(message, tag = "3")] CloseStream(CloseStreamReq), + #[prost(message, tag = "4")] Execute(ExecuteReq), + #[prost(message, tag = "5")] Batch(BatchReq), + #[prost(message, tag = "6")] + OpenCursor(OpenCursorReq), + #[prost(message, tag = "7")] + CloseCursor(CloseCursorReq), + #[prost(message, tag = "8")] + 
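+    // The cursor requests and `get_autocommit` are new in Hrana 3; session.rs rejects
+    // them on connections negotiated with an older protocol version.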
FetchCursor(FetchCursorReq), + #[prost(message, tag = "9")] Sequence(SequenceReq), + #[prost(message, tag = "10")] Describe(DescribeReq), + #[prost(message, tag = "11")] StoreSql(StoreSqlReq), + #[prost(message, tag = "12")] CloseSql(CloseSqlReq), + #[prost(message, tag = "13")] + GetAutocommit(GetAutocommitReq), } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Oneof)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Response { + #[prost(message, tag = "2")] OpenStream(OpenStreamResp), + #[prost(message, tag = "3")] CloseStream(CloseStreamResp), + #[prost(message, tag = "4")] Execute(ExecuteResp), + #[prost(message, tag = "5")] Batch(BatchResp), + #[prost(message, tag = "6")] + OpenCursor(OpenCursorResp), + #[prost(message, tag = "7")] + CloseCursor(CloseCursorResp), + #[prost(message, tag = "8")] + FetchCursor(FetchCursorResp), + #[prost(message, tag = "9")] Sequence(SequenceResp), + #[prost(message, tag = "10")] Describe(DescribeResp), + #[prost(message, tag = "11")] StoreSql(StoreSqlResp), + #[prost(message, tag = "12")] CloseSql(CloseSqlResp), + #[prost(message, tag = "13")] + GetAutocommit(GetAutocommitResp), } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct OpenStreamReq { + #[prost(int32, tag = "1")] pub stream_id: i32, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct OpenStreamResp {} -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct CloseStreamReq { + #[prost(int32, tag = "1")] pub stream_id: i32, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct CloseStreamResp {} -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct ExecuteReq { + #[prost(int32, tag = "1")] pub stream_id: i32, + #[prost(message, required, tag = "2")] pub stmt: Stmt, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct ExecuteResp { + #[prost(message, required, tag = "1")] pub result: StmtResult, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct BatchReq { + #[prost(int32, tag = "1")] pub stream_id: i32, + #[prost(message, required, tag = "2")] pub batch: Batch, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct BatchResp { + #[prost(message, required, tag = "1")] pub result: BatchResult, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] +pub struct OpenCursorReq { + #[prost(int32, tag = "1")] + pub stream_id: i32, + #[prost(int32, tag = "2")] + pub cursor_id: i32, + #[prost(message, required, tag = "3")] + pub batch: Batch, +} + +#[derive(Serialize, prost::Message)] +pub struct OpenCursorResp {} + +#[derive(Deserialize, prost::Message)] +pub struct CloseCursorReq { + #[prost(int32, tag = "1")] + pub cursor_id: i32, +} + +#[derive(Serialize, prost::Message)] +pub struct CloseCursorResp {} + +#[derive(Deserialize, prost::Message)] +pub struct FetchCursorReq { + #[prost(int32, tag = "1")] + pub cursor_id: i32, + #[prost(uint32, tag = "2")] + pub max_count: u32, +} + +#[derive(Serialize, prost::Message)] +pub struct FetchCursorResp { + #[prost(message, repeated, tag = "1")] + pub entries: Vec, + #[prost(bool, tag = "2")] + pub done: bool, +} + +#[derive(Deserialize, prost::Message)] pub struct SequenceReq { + #[prost(int32, tag = "1")] pub stream_id: i32, #[serde(default)] + #[prost(string, optional, tag = "2")] pub sql: Option, #[serde(default)] + #[prost(int32, optional, tag = "3")] pub sql_id: Option, } -#[derive(Serialize, 
Debug)] +#[derive(Serialize, prost::Message)] pub struct SequenceResp {} -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct DescribeReq { + #[prost(int32, tag = "1")] pub stream_id: i32, #[serde(default)] + #[prost(string, optional, tag = "2")] pub sql: Option, #[serde(default)] + #[prost(int32, optional, tag = "3")] pub sql_id: Option, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct DescribeResp { + #[prost(message, required, tag = "1")] pub result: DescribeResult, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct StoreSqlReq { + #[prost(int32, tag = "1")] pub sql_id: i32, + #[prost(string, required, tag = "2")] pub sql: String, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct StoreSqlResp {} -#[derive(Deserialize, Debug)] +#[derive(Deserialize, prost::Message)] pub struct CloseSqlReq { + #[prost(int32, tag = "1")] pub sql_id: i32, } -#[derive(Serialize, Debug)] +#[derive(Serialize, prost::Message)] pub struct CloseSqlResp {} + +#[derive(Deserialize, prost::Message)] +pub struct GetAutocommitReq { + #[prost(int32, tag = "1")] + pub stream_id: i32, +} + +#[derive(Serialize, prost::Message)] +pub struct GetAutocommitResp { + #[prost(bool, required, tag = "1")] + pub is_autocommit: bool, +} diff --git a/sqld/src/hrana/ws/protobuf.rs b/sqld/src/hrana/ws/protobuf.rs new file mode 100644 index 00000000..cf0b3a7c --- /dev/null +++ b/sqld/src/hrana/ws/protobuf.rs @@ -0,0 +1,100 @@ +use super::proto::{ClientMsg, HelloMsg, RequestMsg, ServerMsg}; +use ::bytes::{Buf, BufMut}; +use prost::encoding::{message, skip_field, DecodeContext, WireType}; +use prost::DecodeError; +use std::mem::replace; + +impl prost::Message for ClientMsg { + fn encode_raw(&self, _buf: &mut B) + where + B: BufMut, + Self: Sized, + { + panic!("ClientMsg can only be decoded, not encoded") + } + + fn encoded_len(&self) -> usize { + panic!("ClientMsg can only be decoded, not encoded") + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + match tag { + 1 => { + let mut msg = match replace(self, ClientMsg::None) { + ClientMsg::Hello(msg) => msg, + _ => HelloMsg::default(), + }; + message::merge(wire_type, &mut msg, buf, ctx)?; + *self = ClientMsg::Hello(msg); + } + 2 => { + let mut msg = match replace(self, ClientMsg::None) { + ClientMsg::Request(msg) => msg, + _ => RequestMsg::default(), + }; + message::merge(wire_type, &mut msg, buf, ctx)?; + *self = ClientMsg::Request(msg); + } + _ => { + skip_field(wire_type, tag, buf, ctx)?; + } + } + Ok(()) + } + + fn clear(&mut self) { + *self = ClientMsg::None; + } +} + +impl prost::Message for ServerMsg { + fn encode_raw(&self, buf: &mut B) + where + B: BufMut, + Self: Sized, + { + match self { + ServerMsg::HelloOk(msg) => message::encode(1, msg, buf), + ServerMsg::HelloError(msg) => message::encode(2, msg, buf), + ServerMsg::ResponseOk(msg) => message::encode(3, msg, buf), + ServerMsg::ResponseError(msg) => message::encode(4, msg, buf), + } + } + + fn encoded_len(&self) -> usize { + match self { + ServerMsg::HelloOk(msg) => message::encoded_len(1, msg), + ServerMsg::HelloError(msg) => message::encoded_len(2, msg), + ServerMsg::ResponseOk(msg) => message::encoded_len(3, msg), + ServerMsg::ResponseError(msg) => message::encoded_len(4, msg), + } + } + + fn merge_field( + &mut self, + _tag: u32, + _wire_type: WireType, + _buf: &mut B, + _ctx: 
DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + panic!("ServerMsg can only be encoded, not decoded") + } + + fn clear(&mut self) { + panic!("ServerMsg can only be encoded, not decoded") + } +} diff --git a/sqld/src/hrana/ws/session.rs b/sqld/src/hrana/ws/session.rs index 461bd74b..f2e3ee72 100644 --- a/sqld/src/hrana/ws/session.rs +++ b/sqld/src/hrana/ws/session.rs @@ -5,7 +5,7 @@ use anyhow::{anyhow, bail, Context as _, Result}; use futures::future::BoxFuture; use tokio::sync::{mpsc, oneshot}; -use super::super::{batch, stmt, ProtocolError, Version}; +use super::super::{batch, cursor, stmt, ProtocolError, Version}; use super::{proto, Server}; use crate::auth::{AuthError, Authenticated}; use crate::connection::{Connection, MakeConnection}; @@ -18,10 +18,12 @@ pub struct Session { version: Version, streams: HashMap>, sqls: HashMap, + cursors: HashMap, } struct StreamHandle { job_tx: mpsc::Sender>, + cursor_id: Option, } /// An arbitrary job that is executed on a [`Stream`]. @@ -40,7 +42,9 @@ struct Stream { /// The database handle is `None` when the stream is created, and normally set to `Some` by the /// first job executed on the stream by the [`proto::OpenStreamReq`] request. However, if that /// request returns an error, the following requests may encounter a `None` here. - db: Option, + db: Option>, + /// Handle to an open cursor, if any. + cursor_hnd: Option>, } /// An error which can be converted to a Hrana [Error][proto::Error]. @@ -50,6 +54,8 @@ pub enum ResponseError { Auth { source: AuthError }, #[error("Stream {stream_id} has failed to open")] StreamNotOpen { stream_id: i32 }, + #[error("Cursor {cursor_id} has failed to open")] + CursorNotOpen { cursor_id: i32 }, #[error("The server already stores {count} SQL texts, it cannot store more")] SqlTooMany { count: usize }, #[error(transparent)] @@ -73,6 +79,7 @@ pub(super) fn handle_initial_hello( version, streams: HashMap::new(), sqls: HashMap::new(), + cursors: HashMap::new(), }) } @@ -95,11 +102,12 @@ pub(super) fn handle_repeated_hello( Ok(()) } -pub(super) async fn handle_request( - session: &mut Session, +pub(super) async fn handle_request( + server: &Server, + session: &mut Session<::Connection>, join_set: &mut tokio::task::JoinSet<()>, req: proto::Request, - connection_maker: Arc>, + connection_maker: Arc::Connection>>, ) -> Result>> { // TODO: this function has rotten: it is too long and contains too much duplicated code. It // should be refactored at the next opportunity, together with code in stmt.rs and batch.rs @@ -154,6 +162,17 @@ pub(super) async fn handle_request( }; } + macro_rules! 
get_stream_cursor_hnd { + ($stream:expr, $cursor_id:expr) => { + match $stream.cursor_hnd.as_mut() { + Some(cursor_hnd) => cursor_hnd, + None => bail!(ResponseError::CursorNotOpen { + cursor_id: $cursor_id, + }), + } + }; + } + match req { proto::Request::OpenStream(req) => { let stream_id = req.stream_id; @@ -161,17 +180,22 @@ pub(super) async fn handle_request( bail!(ProtocolError::StreamExists { stream_id }) } - let mut stream_hnd = stream_spawn(join_set, Stream { db: None }); + let mut stream_hnd = stream_spawn( + join_set, + Stream { + db: None, + cursor_hnd: None, + }, + ); stream_respond!(&mut stream_hnd, async move |stream| { let db = connection_maker .create() .await .context("Could not create a database connection")?; - stream.db = Some(db); + stream.db = Some(Arc::new(db)); Ok(proto::Response::OpenStream(proto::OpenStreamResp {})) }); - session.streams.insert(stream_id, stream_hnd); } proto::Request::CloseStream(req) => { @@ -180,6 +204,10 @@ pub(super) async fn handle_request( bail!(ProtocolError::StreamNotFound { stream_id }) }; + if let Some(cursor_id) = stream_hnd.cursor_id { + session.cursors.remove(&cursor_id); + } + stream_respond!(&mut stream_hnd, async move |_stream| { Ok(proto::Response::CloseStream(proto::CloseStreamResp {})) }); @@ -194,7 +222,7 @@ pub(super) async fn handle_request( stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - let result = stmt::execute_stmt(db, auth, query) + let result = stmt::execute_stmt(&**db, auth, query) .await .map_err(catch_stmt_error)?; Ok(proto::Response::Execute(proto::ExecuteResp { result })) @@ -210,7 +238,7 @@ pub(super) async fn handle_request( stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - let result = batch::execute_batch(db, auth, pgm) + let result = batch::execute_batch(&**db, auth, pgm) .await .map_err(catch_batch_error)?; Ok(proto::Response::Batch(proto::BatchResp { result })) @@ -232,7 +260,7 @@ pub(super) async fn handle_request( stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - batch::execute_sequence(db, auth, pgm) + batch::execute_sequence(&**db, auth, pgm) .await .map_err(catch_stmt_error) .map_err(catch_batch_error)?; @@ -255,7 +283,7 @@ pub(super) async fn handle_request( stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - let result = stmt::describe_stmt(db, auth, sql) + let result = stmt::describe_stmt(&**db, auth, sql) .await .map_err(catch_stmt_error)?; Ok(proto::Response::Describe(proto::DescribeResp { result })) @@ -280,6 +308,98 @@ pub(super) async fn handle_request( session.sqls.remove(&req.sql_id); respond!(proto::Response::CloseSql(proto::CloseSqlResp {})); } + proto::Request::OpenCursor(req) => { + ensure_version!(Version::Hrana3, "The `open_cursor` request"); + + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + if stream_hnd.cursor_id.is_some() { + bail!(ProtocolError::CursorAlreadyOpen { stream_id }) + } + + let cursor_id = req.cursor_id; + if session.cursors.contains_key(&cursor_id) { + bail!(ProtocolError::CursorExists { cursor_id }) + } + + let pgm = batch::proto_batch_to_program(&req.batch, &session.sqls, session.version) + .map_err(catch_stmt_error)?; + let auth = session.authenticated; + + let mut cursor_hnd = cursor::CursorHandle::spawn(join_set); + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + cursor_hnd.open(db.clone(), auth, 
pgm); + stream.cursor_hnd = Some(cursor_hnd); + Ok(proto::Response::OpenCursor(proto::OpenCursorResp {})) + }); + session.cursors.insert(cursor_id, stream_id); + stream_hnd.cursor_id = Some(cursor_id); + } + proto::Request::CloseCursor(req) => { + ensure_version!(Version::Hrana3, "The `close_cursor` request"); + + let cursor_id = req.cursor_id; + let Some(stream_id) = session.cursors.remove(&cursor_id) else { + bail!(ProtocolError::CursorNotFound { cursor_id }) + }; + + let stream_hnd = get_stream_mut!(stream_id); + assert_eq!(stream_hnd.cursor_id, Some(cursor_id)); + stream_hnd.cursor_id = None; + + stream_respond!(stream_hnd, async move |stream| { + stream.cursor_hnd = None; + Ok(proto::Response::CloseCursor(proto::CloseCursorResp {})) + }); + } + proto::Request::FetchCursor(req) => { + ensure_version!(Version::Hrana3, "The `fetch_cursor` request"); + + let cursor_id = req.cursor_id; + let Some(&stream_id) = session.cursors.get(&cursor_id) else { + bail!(ProtocolError::CursorNotFound { cursor_id }) + }; + + let stream_hnd = get_stream_mut!(stream_id); + assert_eq!(stream_hnd.cursor_id, Some(cursor_id)); + + let max_count = req.max_count as usize; + let max_total_size = server.max_response_size / 8; + stream_respond!(stream_hnd, async move |stream| { + let cursor_hnd = get_stream_cursor_hnd!(stream, cursor_id); + + let mut entries = Vec::new(); + let mut total_size = 0; + let mut done = false; + while entries.len() < max_count && total_size < max_total_size { + let Some(sized_entry) = cursor_hnd.fetch().await? else { + done = true; + break + }; + entries.push(sized_entry.entry); + total_size += sized_entry.size; + } + + Ok(proto::Response::FetchCursor(proto::FetchCursorResp { + entries, + done, + })) + }); + } + proto::Request::GetAutocommit(req) => { + ensure_version!(Version::Hrana3, "The `get_autocommit` request"); + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + let is_autocommit = db.is_autocommit().await?; + Ok(proto::Response::GetAutocommit(proto::GetAutocommitResp { + is_autocommit, + })) + }); + } } Ok(resp_rx) } @@ -298,7 +418,10 @@ fn stream_spawn( let _: Result<_, _> = job.resp_tx.send(res); } }); - StreamHandle { job_tx } + StreamHandle { + job_tx, + cursor_id: None, + } } async fn stream_respond( @@ -336,6 +459,7 @@ impl ResponseError { Self::Auth { source } => source.code(), Self::SqlTooMany { .. } => "SQL_STORE_TOO_MANY", Self::StreamNotOpen { .. } => "STREAM_NOT_OPEN", + Self::CursorNotOpen { .. 
} => "CURSOR_NOT_OPEN", Self::Stmt(err) => err.code(), Self::Batch(err) => err.code(), } diff --git a/sqld/src/http/hrana_over_http_1.rs b/sqld/src/http/hrana_over_http_1.rs index 9d5b2daa..e981d793 100644 --- a/sqld/src/http/hrana_over_http_1.rs +++ b/sqld/src/http/hrana_over_http_1.rs @@ -104,7 +104,7 @@ where let res: Result<_> = async move { let req_body = hyper::body::to_bytes(req.into_body()).await?; let req_body = serde_json::from_slice(&req_body) - .map_err(|e| hrana::ProtocolError::Deserialize { source: e })?; + .map_err(|e| hrana::ProtocolError::JsonDeserialize { source: e })?; let db = db_factory .create() diff --git a/sqld/src/http/mod.rs b/sqld/src/http/mod.rs index 8f5591af..a7e97888 100644 --- a/sqld/src/http/mod.rs +++ b/sqld/src/http/mod.rs @@ -112,9 +112,9 @@ fn parse_queries(queries: Vec) -> crate::Result> { Ok(out) } -async fn handle_query( +async fn handle_query( auth: Authenticated, - MakeConnectionExtractor(connection_maker): MakeConnectionExtractor, + MakeConnectionExtractor(connection_maker): MakeConnectionExtractor, Json(query): Json, ) -> Result { let batch = parse_queries(query.statements)?; @@ -177,21 +177,6 @@ async fn handle_version() -> Response { Response::new(Body::from(version)) } -async fn handle_hrana_v2( - MakeConnectionExtractor(connection_maker): MakeConnectionExtractor< - ::Connection, - >, - AxumState(state): AxumState>, - auth: Authenticated, - req: Request, -) -> Result, Error> { - let server = state.hrana_http_srv; - - let res = server.handle_pipeline(auth, req, connection_maker).await?; - - Ok(res) -} - async fn handle_fallback() -> impl IntoResponse { (StatusCode::NOT_FOUND).into_response() } @@ -265,6 +250,25 @@ where tracing::debug!("got request: {} {}", req.method(), req.uri()); } + macro_rules! handle_hrana { + ($endpoint:expr, $version:expr, $encoding:expr,) => {{ + async fn handle_hrana( + AxumState(state): AxumState>, + MakeConnectionExtractor(connection_maker): MakeConnectionExtractor< + ::Connection, + >, + auth: Authenticated, + req: Request, + ) -> Result, Error> { + Ok(state + .hrana_http_srv + .handle_request(connection_maker, auth, req, $endpoint, $version, $encoding) + .await?) 
+ } + handle_hrana + }}; + } + let app = Router::new() .route("/", post(handle_query)) .route("/", get(handle_upgrade)) @@ -276,7 +280,48 @@ where .route("/v1/execute", post(hrana_over_http_1::handle_execute)) .route("/v1/batch", post(hrana_over_http_1::handle_batch)) .route("/v2", get(crate::hrana::http::handle_index)) - .route("/v2/pipeline", post(handle_hrana_v2)) + .route( + "/v2/pipeline", + post(handle_hrana!( + hrana::http::Endpoint::Pipeline, + hrana::Version::Hrana2, + hrana::Encoding::Json, + )), + ) + .route("/v3", get(crate::hrana::http::handle_index)) + .route( + "/v3/pipeline", + post(handle_hrana!( + hrana::http::Endpoint::Pipeline, + hrana::Version::Hrana3, + hrana::Encoding::Json, + )), + ) + .route( + "/v3/cursor", + post(handle_hrana!( + hrana::http::Endpoint::Cursor, + hrana::Version::Hrana3, + hrana::Encoding::Json, + )), + ) + .route("/v3-protobuf", get(crate::hrana::http::handle_index)) + .route( + "/v3-protobuf/pipeline", + post(handle_hrana!( + hrana::http::Endpoint::Pipeline, + hrana::Version::Hrana3, + hrana::Encoding::Protobuf, + )), + ) + .route( + "/v3-protobuf/cursor", + post(handle_hrana!( + hrana::http::Endpoint::Cursor, + hrana::Version::Hrana3, + hrana::Encoding::Protobuf, + )), + ) .with_state(state); let layered_app = app diff --git a/sqld/src/lib.rs b/sqld/src/lib.rs index 749793a1..59bbac09 100644 --- a/sqld/src/lib.rs +++ b/sqld/src/lib.rs @@ -185,11 +185,13 @@ where let idle_kicker = idle_shutdown_layer.clone().map(|isl| isl.into_kicker()); let disable_default_namespace = config.disable_default_namespace; let disable_namespaces = config.disable_namespaces; + let max_response_size = config.max_response_size; join_set.spawn(async move { hrana::ws::serve( auth, idle_kicker, + max_response_size, hrana_accept_rx, hrana_upgrade_rx, namespaces, diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index 19a730c0..8751a170 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -171,6 +171,7 @@ pub mod rpc { .map(TryInto::try_into) .collect::>()?, }, + Some(cond::Cond::IsAutocommit(_)) => Self::IsAutocommit, None => anyhow::bail!("invalid condition"), }; @@ -248,6 +249,9 @@ pub mod rpc { connection::program::Cond::And { conds } => cond::Cond::And(AndCond { conds: conds.into_iter().map(|c| c.into()).collect(), }), + connection::program::Cond::IsAutocommit => { + cond::Cond::IsAutocommit(IsAutocommitCond {}) + } }; Self { cond: Some(cond) }
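+        // `IsAutocommit` carries no payload, so the proxy RPC conversion maps it
+        // one-to-one in both directions.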