diff --git a/src/Connection.php b/src/Connection.php index 14ce02f..3935ebe 100644 --- a/src/Connection.php +++ b/src/Connection.php @@ -15,7 +15,7 @@ namespace Drift\DBAL; -use Doctrine\DBAL\DBALException; +use Doctrine\DBAL\Exception as DBALException; use Doctrine\DBAL\Exception\InvalidArgumentException; use Doctrine\DBAL\Exception\TableExistsException; use Doctrine\DBAL\Exception\TableNotFoundException; @@ -33,20 +33,9 @@ */ class Connection { - /** - * @var Driver - */ - private $driver; - - /** - * @var Credentials - */ - private $credentials; - - /** - * @var AbstractPlatform - */ - private $platform; + private Driver $driver; + private Credentials $credentials; + private AbstractPlatform $platform; /** * Connection constructor. @@ -112,6 +101,16 @@ public function connect() ->connect($this->credentials); } + /** + * Close. + */ + public function close() + { + $this + ->driver + ->close(); + } + /** * Creates QueryBuilder. * diff --git a/src/Driver/Driver.php b/src/Driver/Driver.php index a3328fd..9a30568 100644 --- a/src/Driver/Driver.php +++ b/src/Driver/Driver.php @@ -45,5 +45,21 @@ public function query( array $parameters ): PromiseInterface; - public function insert(QueryBuilder $queryBuilder, string $table, array $values): PromiseInterface; + /** + * @param QueryBuilder $queryBuilder + * @param string $table + * @param array $values + * + * @return PromiseInterface + */ + public function insert( + QueryBuilder $queryBuilder, + string $table, + array $values + ): PromiseInterface; + + /** + * @return void + */ + public function close(): void; } diff --git a/src/Driver/Mysql/MysqlDriver.php b/src/Driver/Mysql/MysqlDriver.php index 1ef61ae..1db6495 100644 --- a/src/Driver/Mysql/MysqlDriver.php +++ b/src/Driver/Mysql/MysqlDriver.php @@ -86,4 +86,12 @@ public function query( throw $this->exceptionConverter->convert(new DoctrineException($exception->getMessage(), null, $exception->getCode()), new Query($sql, $parameters, [])); }); } + + /** + * @return void + */ + public function close(): void + { + $this->connection->close(); + } } diff --git a/src/Driver/PostgreSQL/PostgreSQLDriver.php b/src/Driver/PostgreSQL/PostgreSQLDriver.php index 56399fb..2ddf5bd 100644 --- a/src/Driver/PostgreSQL/PostgreSQLDriver.php +++ b/src/Driver/PostgreSQL/PostgreSQLDriver.php @@ -17,6 +17,7 @@ use Doctrine\DBAL\Driver\API\ExceptionConverter as ExceptionConverterInterface; use Doctrine\DBAL\Driver\API\PostgreSQL\ExceptionConverter; +use Doctrine\DBAL\Exception; use Doctrine\DBAL\Query; use Doctrine\DBAL\Query\QueryBuilder; use Drift\DBAL\Credentials; @@ -24,20 +25,23 @@ use Drift\DBAL\Driver\Exception as DoctrineException; use Drift\DBAL\Result; use PgAsync\Client; +use PgAsync\Connection; use PgAsync\ErrorException; use React\EventLoop\LoopInterface; use React\Promise\Deferred; use React\Promise\PromiseInterface; +use function React\Promise\reject; /** * Class PostgreSQLDriver. */ class PostgreSQLDriver extends AbstractDriver { - private Client $client; + private Connection $connection; private LoopInterface $loop; private EmptyDoctrinePostgreSQLDriver $doctrineDriver; private ExceptionConverterInterface $exceptionConverter; + private bool $isClosed = false; /** * @param LoopInterface $loop @@ -54,13 +58,15 @@ public function __construct(LoopInterface $loop) */ public function connect(Credentials $credentials, array $options = []) { - $this->client = new Client([ - 'host' => $credentials->getHost(), - 'port' => $credentials->getPort(), - 'user' => $credentials->getUser(), - 'password' => $credentials->getPassword(), - 'database' => $credentials->getDbName(), - ], $this->loop); + $this->connection = + (new Client([ + 'host' => $credentials->getHost(), + 'port' => $credentials->getPort(), + 'user' => $credentials->getUser(), + 'password' => $credentials->getPassword(), + 'database' => $credentials->getDbName(), + ], $this->loop)) + ->getIdleConnection(); } /** @@ -70,6 +76,10 @@ public function query( string $sql, array $parameters ): PromiseInterface { + if ($this->isClosed) { + return reject(new Exception('Connection closed')); + } + /** * We should fix the parametrization. */ @@ -82,7 +92,7 @@ public function query( $deferred = new Deferred(); $this - ->client + ->connection ->executeStatement($sql, $parameters) ->subscribe(function ($row) use (&$results) { $results[] = $row; @@ -123,6 +133,10 @@ public function query( */ public function insert(QueryBuilder $queryBuilder, string $table, array $values): PromiseInterface { + if ($this->isClosed) { + return reject(new Exception('Connection closed')); + } + $queryBuilder = $this->createInsertQuery($queryBuilder, $table, $values); $query = 'SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = ?'; @@ -146,9 +160,20 @@ public function insert(QueryBuilder $queryBuilder, string $table, array $values) ->query($queryBuilder->getSQL().$returningPart, $queryBuilder->getParameters()) ->then(function (Result $result) use ($fields) { return 0 === count($fields) - ? new Result() + ? new Result(0, null, null) : new Result([], \intval($result->fetchFirstRow()[$fields[0]]), 1); }); }); } + + /** + * @return void + */ + public function close(): void + { + $this->isClosed = true; + $this + ->connection + ->disconnect(); + } } diff --git a/src/Driver/SQLite/SQLiteDriver.php b/src/Driver/SQLite/SQLiteDriver.php index 96a7ae0..c313e81 100644 --- a/src/Driver/SQLite/SQLiteDriver.php +++ b/src/Driver/SQLite/SQLiteDriver.php @@ -82,4 +82,12 @@ public function query( throw $this->exceptionConverter->convert(new DoctrineException($exception->getMessage()), new Query($sql, $parameters, [])); }); } + + /** + * @return void + */ + public function close(): void + { + $this->database->close(); + } } diff --git a/tests/ConnectionTest.php b/tests/ConnectionTest.php index 8bd9162..feae045 100644 --- a/tests/ConnectionTest.php +++ b/tests/ConnectionTest.php @@ -16,6 +16,7 @@ namespace Drift\DBAL\Tests; use function Clue\React\Block\await; +use Doctrine\DBAL\Exception as DBALException; use Doctrine\DBAL\Exception\InvalidArgumentException; use Doctrine\DBAL\Exception\TableExistsException; use Doctrine\DBAL\Exception\TableNotFoundException; @@ -606,4 +607,38 @@ public function testAffectedRows() await($promise, $loop, self::MAX_TIMEOUT); } + + /** + * Test close connection. + */ + public function testCloseConnection() + { + $loop = $this->createLoop(); + $connection = $this->getConnection($loop); + $promise = $this + ->resetInfrastructure($connection, true) + ->then(function (Connection $connection) { + return $connection->insert('test', [ + 'field1' => 'val1', + 'field2' => 'val2', + ]); + }) + ->then(function (Result $result) use ($connection) { + $this->assertEquals(1, $result->getAffectedRows()); + $connection->close(); + + return $connection->insert('test', [ + 'field1' => 'val1', + 'field2' => 'val2', + ]); + }) + ->then(function () { + $this->fail('An exception should have been thrown'); + }) + ->otherwise(function (DBALException $exception) { + // Good catch + }); + + await($promise, $loop, self::MAX_TIMEOUT); + } }