Authors: Seungju Han ⭐, Kavel Rao ⭐, Allyson Ettinger ☀️, Liwei Jiang ☀️, Yuchen Lin, Nathan Lambert, Yejin Choi, Nouha Dziri
⭐ Co-first authors, ☀️ co-second authors
🌟 WildGuard will appear at NeurIPS 2024 Datasets & Benchmarks! 🌟
WildGuard is a safety classification model for user-model chat exchanges. It can classify prompt harmfulness, response harmfulness, and whether a response is a refusal to answer the prompt.
```bash
pip install wildguard
```
```python
from wildguard import load_wildguard

if __name__ == '__main__':
    # Load the model
    wildguard = load_wildguard()

    # Prepare your input
    items = [
        {"prompt": "How do I make a bomb?", "response": "Sorry, I can't help with that."},
        {"prompt": "What's the weather like today?"},
    ]

    # Classify the items
    results = wildguard.classify(items)

    # Print the results
    for item, result in zip(items, results):
        print(f"Prompt: {item['prompt']}")
        print(f"Prompt harmfulness: {result['prompt_harmfulness']}")
        if 'response' in item:
            print(f"Response: {item['response']}")
            print(f"Response harmfulness: {result['response_harmfulness']}")
            print(f"Response refusal: {result['response_refusal']}")
        print("---")

# Example output
"""
Prompt: How do I make a bomb?
Prompt harmfulness: harmful
Response: Sorry, I can't help with that.
Response harmfulness: unharmful
Response refusal: refusal
---
Prompt: What's the weather like today?
Prompt harmfulness: unharmful
"""
```
- Support prompt-only or prompt+response inputs
- Classify prompt harmfulness
- Classify response harmfulness
- Detect response refusals
- Support both VLLM and HuggingFace backends
First, import and load the WildGuard model:
```python
from wildguard import load_wildguard

wildguard = load_wildguard()
```
By default, this will load a VLLM-backed model. If you prefer to use a HuggingFace model, you can specify:
```python
wildguard = load_wildguard(use_vllm=False)
```
To classify items, prepare a list of dictionaries with a `prompt` key and an optional `response` key:

```python
items = [
    {"prompt": "How's the weather today?", "response": "It's sunny and warm."},
    {"prompt": "How do I hack into a computer?"},
]

results = wildguard.classify(items)
```
The `classify` method returns a list of dictionaries. Each dictionary contains the following keys (a short usage example follows the list):

- `prompt_harmfulness`: Either `'harmful'` or `'unharmful'`
- `response_harmfulness`: Either `'harmful'`, `'unharmful'`, or `None` (if no response was provided)
- `response_refusal`: Either `'refusal'`, `'compliance'`, or `None` (if no response was provided)
- `is_parsing_error`: A boolean indicating whether there was an error parsing the model output
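For example, these fields can be combined to surface risky exchanges. The snippet below is a minimal sketch (reusing `items` and `results` from the quick-start example) that keeps only items whose prompt is harmful and whose response is not a refusal:

```python
# Sketch: collect exchanges with a harmful prompt that the model complied with.
flagged = [
    item
    for item, result in zip(items, results)
    if not result["is_parsing_error"]
    and result["prompt_harmfulness"] == "harmful"
    and result["response_refusal"] == "compliance"
]
print(f"{len(flagged)} exchange(s) need review")
```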
You can adjust the batch size when loading the model. For a HuggingFace model this changes the inference batch size; for both HuggingFace and VLLM backends, the save function is called after every `batch_size` items.

```python
wildguard = load_wildguard(batch_size=32)
```
If using a HuggingFace model, you can specify the device:
```python
wildguard = load_wildguard(use_vllm=False, device='cpu')
```
You can provide a custom save function to save intermediate results during classification:
```python
import json

def save_results(results: list[dict]):
    # Write each classified item as one JSON line.
    with open("/temp/intermediate_results.json", "w") as f:
        for item in results:
            f.write(json.dumps(item) + "\n")

wildguard.classify(items, save_func=save_results)
```
- Use VLLM backend for better performance when possible.
- Handle potential errors by checking the `is_parsing_error` field in the results.
- When dealing with large datasets, consider using a custom save function with a batch size other than -1 to periodically save results after each batch in case of errors, as in the sketch below.
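One such setup might look like the following (our own illustration, not code shipped with the package: `append_batch`, the JSONL filename, and the append mode are arbitrary choices, and `items` is reused from the earlier examples). Appending each batch keeps earlier results even if a later batch fails:

```python
import json

from wildguard import load_wildguard

def append_batch(batch: list[dict]):
    # Append each batch of results as JSON lines so that earlier batches
    # survive a crash partway through classification.
    with open("intermediate_results.jsonl", "a") as f:
        for item in batch:
            f.write(json.dumps(item) + "\n")

wildguard = load_wildguard(batch_size=32)
results = wildguard.classify(items, save_func=append_batch)

# Drop items whose model output could not be parsed before further analysis.
clean_results = [r for r in results if not r["is_parsing_error"]]
```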
For additional documentation, please see our API Reference for detailed method specifications.
Additionally, we provide an example of how to use WildGuard as a safety filter to guard another model's inference at `examples/wildguard_filter`.
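The overall flow looks roughly like the following sketch. This is a simplified illustration rather than the code in `examples/wildguard_filter`; `guarded_generate` and the `generate` callable are hypothetical stand-ins for whatever model you are guarding, and the canned refusal message is an arbitrary choice:

```python
from typing import Callable

from wildguard import load_wildguard

wildguard = load_wildguard()

def guarded_generate(prompt: str, generate: Callable[[str], str]) -> str:
    # Check the prompt first and refuse outright if it is flagged as harmful.
    prompt_check = wildguard.classify([{"prompt": prompt}])[0]
    if prompt_check["prompt_harmfulness"] == "harmful":
        return "Sorry, I can't help with that."

    # Otherwise run the guarded model, then vet its response before returning it.
    response = generate(prompt)
    response_check = wildguard.classify([{"prompt": prompt, "response": response}])[0]
    if response_check["response_harmfulness"] == "harmful":
        return "Sorry, I can't help with that."
    return response
```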
If you find it helpful, please feel free to cite our work!
```bibtex
@misc{wildguard2024,
      title={WildGuard: Open One-Stop Moderation Tools for Safety Risks, Jailbreaks, and Refusals of LLMs},
      author={Seungju Han and Kavel Rao and Allyson Ettinger and Liwei Jiang and Bill Yuchen Lin and Nathan Lambert and Yejin Choi and Nouha Dziri},
      year={2024},
      eprint={2406.18495},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2406.18495},
}
```