Skip to content

Commit

Permalink
feat: add custom embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Nov 10, 2023
1 parent 5dc422b commit eae0f5b
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 3 deletions.
74 changes: 74 additions & 0 deletions app/src/main/java/org/unitmesh/llmpoc/embedding/CustomEmbedding.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.unitmesh.llmpoc.embedding

import cc.unitmesh.cf.core.utils.IdUtil
import cc.unitmesh.nlp.embedding.Embedding
import cc.unitmesh.nlp.similarity.CosineSimilarity
import cc.unitmesh.rag.store.EmbeddingMatch
import cc.unitmesh.rag.store.EmbeddingStore
import cc.unitmesh.rag.store.Entry
import java.util.*

class CustomEmbedding<Embedded> : EmbeddingStore<Embedded> {
private val entries: MutableList<Entry<Embedded>> = ArrayList()
override fun add(embedding: Embedding): String {
val id: String = IdUtil.uuid()
add(id, embedding)
return id
}

override fun add(id: String, embedding: Embedding) {
add(id, embedding, null)
}

override fun add(embedding: Embedding, embedded: Embedded): String {
val id: String = IdUtil.uuid()
add(id, embedding, embedded)
return id
}

private fun add(id: String, embedding: Embedding, embedded: Embedded?) {
entries.add(Entry(id, embedding, embedded))
}

override fun addAll(embeddings: List<Embedding>): List<String> {
val ids: MutableList<String> = ArrayList()
for (embedding in embeddings) {
ids.add(add(embedding))
}
return ids
}

override fun addAll(embeddings: List<Embedding>, embedded: List<Embedded>): List<String> {
require(embeddings.size == embedded.size) { "The list of embeddings and embedded must have the same size" }
val ids: MutableList<String> = ArrayList()
for (i in embeddings.indices) {
ids.add(add(embeddings[i], embedded[i]))
}
return ids
}

override fun findRelevant(
referenceEmbedding: Embedding,
maxResults: Int,
minScore: Double,
): List<EmbeddingMatch<Embedded>> {
val comparator = Comparator.comparingDouble(EmbeddingMatch<Embedded>::score)
val matches = PriorityQueue(comparator)

for (entry in entries) {
val score = CosineSimilarity.between(entry.embedding, referenceEmbedding)
if (score >= minScore) {
matches.add(EmbeddingMatch(score, entry.id, entry.embedding, entry.embedded!!))

if (matches.size > maxResults) {
matches.poll()
}
}
}

val result = ArrayList(matches)
result.sortWith(comparator)
result.reverse()
return result
}
}
11 changes: 8 additions & 3 deletions app/src/main/java/org/unitmesh/llmpoc/ui/home/HomeFragment.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,21 @@ import androidx.fragment.app.Fragment
import androidx.lifecycle.ViewModelProvider
import androidx.recyclerview.widget.LinearLayoutManager
import androidx.recyclerview.widget.RecyclerView
import cc.unitmesh.nlp.embedding.Embedding
import cc.unitmesh.nlp.similarity.CosineSimilarity
import cc.unitmesh.nlp.similarity.RelevanceScore
import cc.unitmesh.rag.document.Document
import cc.unitmesh.rag.store.EmbeddingMatch
import cc.unitmesh.rag.store.InMemoryEmbeddingStore
import org.unitmesh.llmpoc.R
import org.unitmesh.llmpoc.databinding.FragmentHomeBinding
import org.unitmesh.llmpoc.embedding.CustomEmbedding
import org.unitmesh.llmpoc.embedding.STSemantic
import java.util.*

class HomeFragment : Fragment() {

private lateinit var embeddingStore: InMemoryEmbeddingStore<Document>
private lateinit var embeddingStore: CustomEmbedding<Document>
private lateinit var stSemantic: STSemantic
private var _binding: FragmentHomeBinding? = null

Expand Down Expand Up @@ -89,8 +95,7 @@ class HomeFragment : Fragment() {
}

fun compute() {
// EmbeddingStore 没有 reset 方法,所以每次都需要重新创建
embeddingStore = InMemoryEmbeddingStore()
embeddingStore = CustomEmbedding()

// 获取所有的 TextView
val textViews = mutableListOf<EditText>()
Expand Down

0 comments on commit eae0f5b

Please sign in to comment.