Skip to content

Commit

Permalink
ref: changes system prompt configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
micheleriva committed Oct 16, 2024
1 parent acf8695 commit 46ea465
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 23 deletions.
29 changes: 26 additions & 3 deletions src/answerSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ export type AnswerParams<UserContext = unknown> = {
onRelatedQueries?: (relatedQueries: string[]) => void
onNewInteractionStarted?: (interactionId: string) => void
onStateChange?: (state: Interaction[]) => void
}
},
systemPrompts?: string[]
}

export type Interaction<T = AnyDocument> = {
Expand Down Expand Up @@ -62,6 +63,8 @@ export class AnswerSession {
private conversationID: string
private lastInteractionParams?: AskParams
public state: Interaction[] = []
private systemPrompts?: string[]


constructor(params: AnswerParams) {
// @ts-expect-error - sorry again TypeScript :-)
Expand Down Expand Up @@ -192,7 +195,7 @@ export class AnswerSession {
requestBody.append('interactionId', interactionId)
requestBody.append('alias', this.oramaClient.getAlias() ?? '')

const systemPromptConfiguration = this.oramaClient.getSystemPromptConfiguration()
const systemPromptConfiguration = this.getSystemPromptConfiguration()
if (systemPromptConfiguration) {
requestBody.append('systemPrompts', JSON.stringify(systemPromptConfiguration))
}
Expand Down Expand Up @@ -244,7 +247,8 @@ export class AnswerSession {
if (done) break
buffer += decoder.decode(value, { stream: true })

let endOfMessageIndex
// biome-ignore lint/suspicious/noImplicitAnyLet: <explanation>
let endOfMessageIndex

// biome-ignore lint/suspicious/noAssignInExpressions: this saves a variable allocation on each iteration
while ((endOfMessageIndex = buffer.indexOf('\n\n')) !== -1) {
Expand Down Expand Up @@ -346,4 +350,23 @@ export class AnswerSession {
}
}
}

/**
* Methods associated with custom system prompts
*/
public setSystemPromptConfiguration(config: { systemPrompts: string[] }) {
if (Array.isArray(config.systemPrompts)) {
if (!config.systemPrompts.every((prompt) => typeof prompt === 'string')) {
throw new Error('Invalid system prompt configuration')
}

this.systemPrompts = config.systemPrompts
}

return this
}

public getSystemPromptConfiguration(): string[] | undefined {
return this.systemPrompts
}
}
24 changes: 4 additions & 20 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export type AnswerSessionParams = {
onNewInteractionStarted?: (interactionId: string) => void
onStateChange?: (state: Interaction[]) => void
}
systemPrompts?: string[]
}

export { AnswerSession, Message }
Expand All @@ -69,7 +70,6 @@ export class OramaClient {
private readonly collector?: Collector
private readonly cache?: Cache<Results<AnyDocument>>
private readonly profile: Profile
private systemPrompts?: string[]
private searchDebounceTimer?: any // NodeJS.Timer
private searchRequestCounter = 0
private blockSearchTillAuth = false
Expand Down Expand Up @@ -273,15 +273,16 @@ export class OramaClient {
initialMessages: params?.initialMessages || [],
oramaClient: this,
events: params?.events,
userContext: params?.userContext
userContext: params?.userContext,
systemPrompts: params?.systemPrompts ?? []
})
}

public startHeartBeat(config: HeartBeatConfig): void {
this.heartbeat?.stop()
this.heartbeat = new HeartBeat({
...config,
endpoint: this.endpoint + `/health?api-key=${this.api_key}`
endpoint: `${this.endpoint}/health?api-key=${this.api_key}`
})
this.heartbeat.start()
}
Expand Down Expand Up @@ -431,21 +432,4 @@ export class OramaClient {
public reset(): void {
this.profile.reset()
}

/**
* Methods associated with custom system prompts
*/
public setSystemPromptConfiguration(config: { systemPrompts: string[] }): void {
if (Array.isArray(config.systemPrompts)) {
if (!config.systemPrompts.every((prompt) => typeof prompt === 'string')) {
throw new Error('Invalid system prompt configuration')
}

this.systemPrompts = config.systemPrompts
}
}

public getSystemPromptConfiguration(): string[] | undefined {
return this.systemPrompts
}
}
29 changes: 29 additions & 0 deletions tests/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -320,4 +320,33 @@ await t.test('regenerate last answer', async t => {

assert.equal(state.length, 2)
assert.equal(state[state.length - 1].query, 'labrador')
})

t.test('can use custom system prompts', async t => {
if (!process.env.ORAMA_E2E_ENDPOINT || !process.env.ORAMA_E2E_API_KEY) {
if (!process.env.ORAMA_E2E_ENDPOINT || !process.env.ORAMA_E2E_API_KEY) {
t.skip('ORAMA_E2E_ENDPOINT and ORAMA_E2E_API_KEY are not set. E2e tests will be skipped.')
return
}
}

const client = new OramaClient({
endpoint: process.env.ORAMA_E2E_ENDPOINT!,
api_key: process.env.ORAMA_E2E_API_KEY!
})

const session = client
.createAnswerSession({
systemPrompts: ['sp_italian-prompt-chc4o0']
})

session.setSystemPromptConfiguration({
systemPrompts: ['sp_italian-prompt-chc4o0']
})

const res = await session.ask({
term: 'what is Orama?'
})

assert.ok(res)
})

0 comments on commit 46ea465

Please sign in to comment.