diff --git a/examples/structured_output.rs b/examples/structured_output.rs new file mode 100644 index 0000000..c18b02b --- /dev/null +++ b/examples/structured_output.rs @@ -0,0 +1,41 @@ +use std::{env, error::Error, io::Write}; + +use futures_util::StreamExt; +use ollama_rs::{OllamaClient, types::generate::GenerateRequest}; +use serde_json::json; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let _ = dotenvy::dotenv(); + let server_address = env::var("OLLAMA_SERVER")?; + let ollama_client = OllamaClient::new(server_address); + let output_schema = json!({ + "type": "object", + "properties": { + "thought": { "type": "string", "description": "The thought that led to the response"}, + "response": { "type": "string", "description": "The response to the user"} + } + }); + let request = GenerateRequest::builder("dolphin3:8b") + .system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.") + .stream(false) + .format(output_schema) + .prompt("Why is the sky blue?") + .build(); + + let mut stream = ollama_client.generate(request); + while let Some(response) = stream.next().await { + match response { + Ok(token) => { + print!("{}", token.response); + std::io::stdout().flush()?; + if token.done { + break; + } + } + Err(e) => println!("Error: {}", e), + } + } + + Ok(()) +} diff --git a/src/types/chat.rs b/src/types/chat.rs index 24e90ae..247cf75 100644 --- a/src/types/chat.rs +++ b/src/types/chat.rs @@ -63,6 +63,9 @@ pub struct ChatRequest { #[serde(skip_serializing_if = "Vec::is_empty")] pub tools: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, } impl ChatRequest { @@ -92,6 +95,7 @@ impl ChatRequestBuilder { stream: None, options: None, tools: vec![], + format: None, }, } } @@ -116,6 +120,11 @@ impl ChatRequestBuilder { self } + pub fn format(mut self, json_schema: Value) -> Self { + self.chat_request.format = Some(json_schema); + self + } + pub fn build(self) -> ChatRequest { self.chat_request }