-
Notifications
You must be signed in to change notification settings - Fork 1
/
clustering.py
284 lines (253 loc) · 10.9 KB
/
clustering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# clustering.py
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import streamlit as st
def recursive_clustering(embeddings, idea_texts, chat_model, min_cluster_size=5, node_id=None, level=0, parent_labels=None, current_node_id=0):
    """
    Perform recursive clustering on embeddings and generate labels using LLM.

    Args:
        embeddings: Sequence of embedding vectors, one per idea.
        idea_texts: Sequence of idea strings aligned with ``embeddings``.
        chat_model: LLM passed through to ``generate_cluster_label``.
        min_cluster_size: Clusters at or below this size are not split further.
        node_id: Parent node ID, or ``None`` when called for the root.
        level: Recursion depth; used only to cycle the node color palette.
        parent_labels: Labels of ancestor clusters, forwarded to the LLM so
            child labels stay distinct from them.
        current_node_id: Next free integer used to mint unique node IDs.

    Returns:
        nodes_data: List of node dictionaries.
        edges_data: List of edge dictionaries.
        current_node_id: The next available node ID (integer).
    """
    nodes_data = []
    edges_data = []
    # Define colors for different levels (cycled by recursion depth).
    level_colors = ["#FF7F50", "#87CEFA", "#32CD32", "#BA55D3",
                    "#FFD700", "#FF69B4", "#CD5C5C", "#4B0082"]
    color = level_colors[level % len(level_colors)]
    n_samples = len(embeddings)

    def _add_idea_leaves():
        # Emit one leaf (box) node per idea and link it to the current
        # cluster node. Shared by all three base cases below.
        nonlocal current_node_id
        for idea in idea_texts:
            idea_node_id = f"idea_{current_node_id}"
            nodes_data.append({
                'id': idea_node_id,
                'title': idea.replace('\n', '<br>'),
                'label': idea,
                'shape': 'box',
                'color': color,
                'font': {'size': 16, 'multi': True},
                'full_text': idea  # Store full text for markdown
                # No description for idea nodes
            })
            edges_data.append({
                'source': node_id,
                'to': idea_node_id  # Use 'to' instead of 'target'
            })
            current_node_id += 1

    # Every call — root or not — creates one cluster (ellipse) node labelled
    # by the LLM; the only root/non-root difference is the parent edge.
    cluster_label, cluster_description = generate_cluster_label(idea_texts, chat_model, parent_labels)
    cluster_node_id = f"cluster_{current_node_id}"
    nodes_data.append({
        'id': cluster_node_id,
        'title': cluster_description.replace('\n', '<br>'),
        'label': cluster_label,
        'shape': 'ellipse',
        'color': color,
        'font': {'size': 16, 'multi': True},
        'full_text': cluster_label,  # Store full text for markdown
        'description': cluster_description  # Store description
    })
    current_node_id += 1
    if node_id is not None:
        # Non-root cluster: add edge from parent node to this cluster node.
        edges_data.append({
            'source': node_id,
            'to': cluster_node_id  # Use 'to' instead of 'target'
        })
    node_id = cluster_node_id  # This cluster becomes the parent for children.

    if n_samples <= min_cluster_size or n_samples <= 2:
        # Base case: cluster is small enough — emit ideas as leaves.
        _add_idea_leaves()
        return nodes_data, edges_data, current_node_id

    # Determine the optimal number of clusters using silhouette score.
    optimal_k = determine_optimal_clusters(
        embeddings, min_k=2, max_k=min(10, n_samples - 1))
    if optimal_k < 2:
        # Cannot cluster further, proceed to base case.
        _add_idea_leaves()
        return nodes_data, edges_data, current_node_id

    # Perform KMeans clustering.
    try:
        kmeans = KMeans(n_clusters=optimal_k, random_state=42)
        labels = kmeans.fit_predict(embeddings)
    except Exception:
        # If clustering fails, proceed to base case.
        _add_idea_leaves()
        return nodes_data, edges_data, current_node_id

    # Update parent labels so child clusters avoid repeating this label.
    current_parent_labels = parent_labels.copy() if parent_labels else []
    current_parent_labels.append(cluster_label)

    # Group embeddings and idea_texts by cluster, then recurse.
    for cluster_num in np.unique(labels):
        cluster_indices = [
            idx for idx, label in enumerate(labels) if label == cluster_num]
        cluster_embeddings = [embeddings[idx] for idx in cluster_indices]
        cluster_idea_texts = [idea_texts[idx] for idx in cluster_indices]
        child_nodes_data, child_edges_data, current_node_id = recursive_clustering(
            cluster_embeddings,
            cluster_idea_texts,
            chat_model,
            min_cluster_size=min_cluster_size,
            node_id=node_id,  # Pass current cluster node ID as parent
            level=level + 1,
            parent_labels=current_parent_labels,
            current_node_id=current_node_id  # Pass the updated node ID
        )
        # Add child nodes and edges to the main lists.
        nodes_data.extend(child_nodes_data)
        edges_data.extend(child_edges_data)

    return nodes_data, edges_data, current_node_id
def determine_optimal_clusters(embeddings, min_k=2, max_k=10):
    """
    Pick the KMeans cluster count that maximizes the silhouette score.

    Args:
        embeddings: List of embeddings.
        min_k: Minimum number of clusters to try (floored at 2).
        max_k: Maximum number of clusters to try (capped at n_samples - 1).

    Returns:
        The optimal number of clusters, or 1 when clustering is not possible
        or no candidate k produced a usable silhouette score.
    """
    n_samples = len(embeddings)
    # With two or fewer points there is nothing to split.
    if n_samples <= 2:
        return 1
    lower = max(2, min_k)
    upper = min(max_k, n_samples - 1)
    if lower > upper:
        # Empty search range — cannot cluster further.
        return 1

    best_k, best_score = 1, -1
    for k in range(lower, upper + 1):
        try:
            candidate_labels = KMeans(n_clusters=k, random_state=42).fit_predict(embeddings)
            # Silhouette score is undefined for fewer than 2 distinct clusters.
            if len(np.unique(candidate_labels)) < 2:
                continue
            score = silhouette_score(embeddings, candidate_labels)
        except Exception:
            # Covers cases such as n_samples < n_clusters — skip this k.
            continue
        if score > best_score:
            best_k, best_score = k, score

    # Return 1 if no better k was found.
    return best_k if best_score > -1 else 1
def generate_cluster_label(idea_texts_list, chat_model, parent_labels=None):
    """
    Generate a label and description for a cluster using the LLM.

    Args:
        idea_texts_list: List of idea texts in the cluster.
        chat_model: The LLM model to use.
        parent_labels: List of parent cluster labels.

    Returns:
        A tuple of (label string, description string) generated by the LLM.
        Falls back to ("No Label", "No Description") when the response
        contains no parseable JSON or invalid label/description values.
    """
    import json
    import re

    # Prepare the prompt.
    ideas_concatenated = '\n'.join(idea_texts_list[:50])  # Limit to first 50 ideas to control prompt length
    parent_labels_text = ', '.join(parent_labels) if parent_labels else ''
    prompt_template = """
You are an assistant that labels clusters of ideas. Given the following ideas:
{IDEAS}
Parent cluster labels: {PARENT_LABELS}
Considering the parent labels to avoid redundancy, provide:
1. A concise label (a few words) that summarizes the main theme of these ideas and is distinct from the parent labels.
2. A brief description (2-3 sentences) that captures the essence of this cluster.
Return your answer in JSON format **without any additional text**, and ensure that the output is ONLY the JSON and nothing else.
Example output:
{{
"label": "Cluster Label",
"description": "Cluster Description"
}}
"""
    # Use the LLM to generate the label and description.
    chain = LLMChain(
        prompt=PromptTemplate.from_template(prompt_template),
        llm=chat_model
    )
    response = chain.run(IDEAS=ideas_concatenated, PARENT_LABELS=parent_labels_text)

    # Defaults used whenever the response cannot be parsed or a value is
    # missing/invalid.
    label = "No Label"
    description = "No Description"
    try:
        # Extract JSON object from the response using regex.
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match:
            response_json = json.loads(json_match.group(0))
            raw_label = response_json.get('label')
            raw_description = response_json.get('description')
            # BUGFIX: the JSON may contain "label": null or a non-string
            # value; calling .strip() on that crashed with AttributeError.
            # Coerce to str and keep the default for explicit nulls.
            if raw_label is not None:
                label = str(raw_label)
            if raw_description is not None:
                description = str(raw_description)
        else:
            # If no JSON found, keep default values and surface the response.
            st.error("LLM did not return JSON format as expected.")
            st.write("LLM Response:")
            st.write(response)
    except json.JSONDecodeError as e:
        st.error(f"Error parsing LLM response: {e}")
        st.write("LLM Response:")
        st.write(response)
    return label.strip(), description.strip()