diff --git a/components/Chat/Chat.tsx b/components/Chat/Chat.tsx index 9f199d97..133a8219 100644 --- a/components/Chat/Chat.tsx +++ b/components/Chat/Chat.tsx @@ -1,69 +1,40 @@ import React, { useEffect, useState } from "react"; import ChatResponse from "@/components/Chat/Response/Response"; -import { StreamingMessage } from "@/types/components/chat"; import { Work } from "@nulib/dcapi-types"; import { prepareQuestion } from "@/lib/chat-helpers"; +import useChatSocket from "@/hooks/useChatSocket"; +import useQueryParams from "@/hooks/useQueryParams"; + +const Chat = () => { + const { searchTerm: question } = useQueryParams(); + const { authToken, isConnected, message, sendMessage } = useChatSocket(); -const Chat = ({ - authToken, - chatSocket, - question, -}: { - authToken: string; - chatSocket?: WebSocket; - question?: string; -}) => { - const [isReadyStateOpen, setIsReadyStateOpen] = useState(false); const [isStreamingComplete, setIsStreamingComplete] = useState(false); const [sourceDocuments, setSourceDocuments] = useState([]); const [streamedAnswer, setStreamedAnswer] = useState(""); - const handleReadyStateChange = () => { - setIsReadyStateOpen(chatSocket?.readyState === 1); - }; - - // Handle web socket stream updates - const handleMessageUpdate = (event: MessageEvent) => { - const data: StreamingMessage = JSON.parse(event.data); - // console.log("handleMessageUpdate", data); - - if (data.source_documents) { - setSourceDocuments(data.source_documents); - } else if (data.token) { - setStreamedAnswer((prev) => { - return prev + data.token; - }); - } else if (data.answer) { - setStreamedAnswer(data.answer); - setIsStreamingComplete(true); - } - }; - useEffect(() => { - if (question && isReadyStateOpen && chatSocket) { + if (question && isConnected && authToken) { const preparedQuestion = prepareQuestion(question, authToken); - chatSocket?.send(JSON.stringify(preparedQuestion)); + sendMessage(preparedQuestion); } - }, [chatSocket, isReadyStateOpen, prepareQuestion]); + }, [authToken, isConnected, question, sendMessage]); useEffect(() => { - if (chatSocket) { - chatSocket.addEventListener("message", handleMessageUpdate); - chatSocket.addEventListener("open", handleReadyStateChange); - chatSocket.addEventListener("close", handleReadyStateChange); - chatSocket.addEventListener("error", handleReadyStateChange); - } + if (!message) return; - return () => { - if (chatSocket) { - chatSocket.removeEventListener("message", handleMessageUpdate); - chatSocket.removeEventListener("open", handleReadyStateChange); - chatSocket.removeEventListener("close", handleReadyStateChange); - chatSocket.removeEventListener("error", handleReadyStateChange); - } - }; - }, [chatSocket, chatSocket?.url]); + if (message.source_documents) { + setSourceDocuments(message.source_documents); + } else if (message.token) { + setStreamedAnswer((prev) => { + return prev + message.token; + }); + } else if (message.answer) { + setStreamedAnswer(message.answer); + setIsStreamingComplete(true); + } + }, [message]); if (!question) return null; diff --git a/components/Chat/Wrapper.tsx b/components/Chat/Wrapper.tsx deleted file mode 100644 index f579a8e5..00000000 --- a/components/Chat/Wrapper.tsx +++ /dev/null @@ -1,16 +0,0 @@ -import Chat from "@/components/Chat/Chat"; -import useChatSocket from "@/hooks/useChatSocket"; -import useQueryParams from "@/hooks/useQueryParams"; - -const ChatWrapper = () => { - const { searchTerm: question } = useQueryParams(); - const { authToken, chatSocket } = useChatSocket(); - - if (!authToken || !chatSocket || !question) return null; - - return ( - - ); -}; - -export default ChatWrapper; diff --git a/hooks/useChatSocket.ts b/hooks/useChatSocket.ts index 25184f6d..91457db0 100644 --- a/hooks/useChatSocket.ts +++ b/hooks/useChatSocket.ts @@ -1,11 +1,17 @@ -import { useEffect, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; import { DCAPI_CHAT_ENDPOINT } from "@/lib/constants/endpoints"; +import { StreamingMessage } from "@/types/components/chat"; import axios from "axios"; const useChatSocket = () => { - const [chatSocket, setChatSocket] = useState(null); const [authToken, setAuthToken] = useState(null); + const [url, setUrl] = useState(null); + + const socketRef = useRef(null); + const [message, setMessage] = useState(); + + const [isConnected, setIsConnected] = useState(false); useEffect(() => { axios({ @@ -15,24 +21,57 @@ const useChatSocket = () => { }) .then((response) => { const { auth: authToken, endpoint } = response.data; - if (!authToken || !endpoint) return; - const socket = new WebSocket(endpoint); - setAuthToken(authToken); - setChatSocket(socket); - - return () => { - if (socket) socket.close(); - }; + setUrl(endpoint); }) .catch((error) => { console.error(error); }); }, []); - return { authToken, chatSocket }; + useEffect(() => { + if (!url) return; + + socketRef.current = new WebSocket(url); + + socketRef.current.onopen = () => { + console.log("WebSocket connected"); + setIsConnected(true); + }; + + socketRef.current.onclose = () => { + console.log("WebSocket disconnected"); + setIsConnected(false); + }; + + socketRef.current.onerror = (error) => { + console.error("WebSocket error", error); + }; + + socketRef.current.onmessage = (event: MessageEvent) => { + const data = JSON.parse(event.data); + setMessage(data); + }; + + return () => { + socketRef.current?.close(); + }; + }, [url]); + + const sendMessage = useCallback((data: object) => { + if (socketRef.current && socketRef.current.readyState === WebSocket.OPEN) { + socketRef.current.send(JSON.stringify(data)); + } + }, []); + + return { + authToken, + isConnected, + message, + sendMessage, + }; }; export default useChatSocket; diff --git a/pages/search.tsx b/pages/search.tsx index cb511908..3fabc93d 100644 --- a/pages/search.tsx +++ b/pages/search.tsx @@ -7,7 +7,7 @@ import React, { useEffect, useState } from "react"; import { ApiSearchRequestBody } from "@/types/api/request"; import { ApiSearchResponse } from "@/types/api/response"; -import ChatWrapper from "@/components/Chat/Wrapper"; +import Chat from "@/components/Chat/Chat"; import Container from "@/components/Shared/Container"; import { DC_API_SEARCH_URL } from "@/lib/constants/endpoints"; import Facets from "@/components/Facets/Facets"; @@ -189,7 +189,7 @@ const SearchPage: NextPage = () => { /> )} - {showChatResponse && } + {showChatResponse && }