diff --git a/src/error.rs b/src/error.rs index ab34dcb..36c0c0e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -40,3 +40,48 @@ impl From for OllamaError { Self::LinesCodecError(value) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display_response_parse_error() { + let err: Result = serde_json::from_str("not json"); + let ollama_err = OllamaError::from(err.unwrap_err()); + let display = format!("{}", ollama_err); + assert!(display.starts_with("Response parse error:")); + assert!(!display.ends_with('\n')); + } + + #[test] + fn display_lines_codec_error() { + let err = LinesCodecError::MaxLineLengthExceeded; + let ollama_err = OllamaError::from(err); + let display = format!("{}", ollama_err); + assert!(display.starts_with("Lines codec error:")); + assert!(!display.ends_with('\n')); + } + + #[test] + fn from_serde_json_error() { + let err: Result = serde_json::from_str("not json"); + let ollama_err = OllamaError::from(err.unwrap_err()); + assert!(matches!(ollama_err, OllamaError::ResponseParseError(_))); + } + + #[test] + fn from_lines_codec_error() { + let err = LinesCodecError::MaxLineLengthExceeded; + let ollama_err = OllamaError::from(err); + assert!(matches!(ollama_err, OllamaError::LinesCodecError(_))); + } + + #[test] + fn error_trait_is_implemented() { + let err: Result = serde_json::from_str("not json"); + let ollama_err = OllamaError::from(err.unwrap_err()); + // Verify it implements std::error::Error by using it as a trait object + let _: &dyn Error = &ollama_err; + } +} diff --git a/src/types/chat.rs b/src/types/chat.rs index db1f16e..957423e 100644 --- a/src/types/chat.rs +++ b/src/types/chat.rs @@ -172,3 +172,172 @@ pub struct ToolCallFunction { pub arguments: Value, pub index: usize, } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn role_serializes_lowercase() { + assert_eq!(serde_json::to_value(&Role::User).unwrap(), json!("user")); + assert_eq!( + serde_json::to_value(&Role::System).unwrap(), + json!("system") + ); + assert_eq!( + serde_json::to_value(&Role::Assistant).unwrap(), + json!("assistant") + ); + assert_eq!(serde_json::to_value(&Role::Tool).unwrap(), json!("tool")); + } + + #[test] + fn role_deserializes_lowercase() { + let role: Role = serde_json::from_value(json!("user")).unwrap(); + assert!(matches!(role, Role::User)); + } + + #[test] + fn message_system_constructor() { + let msg = Message::system("you are helpful"); + assert_eq!(msg.content, "you are helpful"); + assert!(matches!(msg.role, Role::System)); + assert!(msg.tool_calls.is_empty()); + } + + #[test] + fn message_user_constructor() { + let msg = Message::user("hello"); + assert_eq!(msg.content, "hello"); + assert!(matches!(msg.role, Role::User)); + } + + #[test] + fn message_tool_response_constructor() { + let value = json!({"temperature": 22.0}); + let msg = Message::tool_response(&value).unwrap(); + assert!(matches!(msg.role, Role::Tool)); + assert_eq!(msg.content, serde_json::to_string(&value).unwrap()); + } + + #[test] + fn message_skips_empty_tool_calls() { + let msg = Message::user("hello"); + let json = serde_json::to_value(&msg).unwrap(); + assert!(!json.as_object().unwrap().contains_key("tool_calls")); + } + + #[test] + fn message_deserializes_without_tool_calls() { + let json = json!({"content": "hi", "role": "user"}); + let msg: Message = serde_json::from_value(json).unwrap(); + assert_eq!(msg.content, "hi"); + assert!(msg.tool_calls.is_empty()); + } + + #[test] + fn chat_request_always_serializes_messages() { + let request = ChatRequest::builder("llama3").build(); + let json = serde_json::to_value(&request).unwrap(); + let obj = json.as_object().unwrap(); + assert!(obj.contains_key("messages")); + assert_eq!(obj["messages"], json!([])); + } + + #[test] + fn chat_request_skips_optional_fields() { + let request = ChatRequest::builder("llama3").build(); + let json = serde_json::to_value(&request).unwrap(); + let obj = json.as_object().unwrap(); + assert!(!obj.contains_key("stream")); + assert!(!obj.contains_key("options")); + assert!(!obj.contains_key("tools")); + assert!(!obj.contains_key("format")); + assert!(!obj.contains_key("think")); + } + + #[test] + fn chat_request_builder_with_messages() { + let messages = vec![Message::system("be helpful"), Message::user("hello")]; + let request = ChatRequest::builder("llama3") + .messages(messages) + .stream(false) + .build(); + + assert_eq!(request.model, "llama3"); + assert_eq!(request.messages.len(), 2); + assert_eq!(request.stream, Some(false)); + } + + #[test] + fn tool_type_serializes_as_type_field() { + let tool = Tool { + tool_type: ToolType::Function, + function: Function { + name: "get_weather".to_string(), + description: "Get weather".to_string(), + parameters: json!({"type": "object"}), + }, + }; + let json = serde_json::to_value(&tool).unwrap(); + let obj = json.as_object().unwrap(); + assert!(obj.contains_key("type")); + assert!(!obj.contains_key("tool_type")); + assert_eq!(obj["type"], json!("function")); + } + + #[test] + fn tool_type_deserializes_from_type_field() { + let json = json!({ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": {"type": "object"} + } + }); + let tool: Tool = serde_json::from_value(json).unwrap(); + assert!(matches!(tool.tool_type, ToolType::Function)); + assert_eq!(tool.function.name, "get_weather"); + } + + #[test] + fn chat_response_deserialize() { + let json = json!({ + "model": "llama3", + "created_at": "2024-01-01T00:00:00Z", + "message": {"content": "Hello!", "role": "assistant"}, + "done": false + }); + let response: ChatResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.model, "llama3"); + assert_eq!(response.message.content, "Hello!"); + assert!(matches!(response.message.role, Role::Assistant)); + assert!(!response.done); + } + + #[test] + fn chat_response_with_tool_calls() { + let json = json!({ + "model": "llama3", + "created_at": "2024-01-01T00:00:00Z", + "message": { + "content": "", + "role": "assistant", + "tool_calls": [{ + "function": { + "name": "get_weather", + "arguments": {"city": "Paris"}, + "index": 0 + } + }] + }, + "done": true + }); + let response: ChatResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.message.tool_calls.len(), 1); + assert_eq!(response.message.tool_calls[0].function.name, "get_weather"); + assert_eq!(response.message.tool_calls[0].function.index, 0); + } +} diff --git a/src/types/common.rs b/src/types/common.rs index 39c7598..1cf52cc 100644 --- a/src/types/common.rs +++ b/src/types/common.rs @@ -123,3 +123,171 @@ pub enum Stop { Single(String), Multiple(Vec), } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn think_bool_true() { + let think: Think = serde_json::from_value(json!(true)).unwrap(); + assert!(matches!(think, Think::Bool(true))); + } + + #[test] + fn think_bool_false() { + let think: Think = serde_json::from_value(json!(false)).unwrap(); + assert!(matches!(think, Think::Bool(false))); + } + + #[test] + fn think_level_high() { + let think: Think = serde_json::from_value(json!("high")).unwrap(); + assert!(matches!(think, Think::Level(ThinkLevel::High))); + } + + #[test] + fn think_level_medium() { + let think: Think = serde_json::from_value(json!("medium")).unwrap(); + assert!(matches!(think, Think::Level(ThinkLevel::Medium))); + } + + #[test] + fn think_level_low() { + let think: Think = serde_json::from_value(json!("low")).unwrap(); + assert!(matches!(think, Think::Level(ThinkLevel::Low))); + } + + #[test] + fn think_bool_round_trip() { + let think = Think::Bool(true); + let json = serde_json::to_value(&think).unwrap(); + assert_eq!(json, json!(true)); + } + + #[test] + fn think_level_round_trip() { + let think = Think::Level(ThinkLevel::High); + let json = serde_json::to_value(&think).unwrap(); + assert_eq!(json, json!("high")); + } + + #[test] + fn stop_single() { + let stop: Stop = serde_json::from_value(json!("end")).unwrap(); + assert!(matches!(stop, Stop::Single(s) if s == "end")); + } + + #[test] + fn stop_multiple() { + let stop: Stop = serde_json::from_value(json!(["end", "stop"])).unwrap(); + match stop { + Stop::Multiple(v) => assert_eq!(v, vec!["end", "stop"]), + _ => panic!("expected Stop::Multiple"), + } + } + + #[test] + fn stop_single_round_trip() { + let stop = Stop::Single("end".to_string()); + let json = serde_json::to_value(&stop).unwrap(); + assert_eq!(json, json!("end")); + } + + #[test] + fn stop_multiple_round_trip() { + let stop = Stop::Multiple(vec!["end".to_string(), "stop".to_string()]); + let json = serde_json::to_value(&stop).unwrap(); + assert_eq!(json, json!(["end", "stop"])); + } + + #[test] + fn options_default_serializes_empty() { + let options = Options::default(); + let json = serde_json::to_value(&options).unwrap(); + assert_eq!(json, json!({})); + } + + #[test] + fn options_skips_none_fields() { + let options = Options::builder().seed(42).temperature(0.5).build(); + let json = serde_json::to_value(&options).unwrap(); + assert_eq!(json, json!({"seed": 42, "temperature": 0.5})); + assert!(!json.as_object().unwrap().contains_key("top_k")); + } + + #[test] + fn options_builder_all_fields() { + let options = Options::builder() + .seed(42) + .temperature(0.7) + .top_k(40) + .top_p(0.9) + .min_p(0.05) + .stop(Stop::Single("end".to_string())) + .num_ctx(4096) + .num_predict(128) + .build(); + + assert_eq!(options.seed, Some(42)); + assert_eq!(options.temperature, Some(0.7)); + assert_eq!(options.top_k, Some(40)); + assert_eq!(options.top_p, Some(0.9)); + assert_eq!(options.min_p, Some(0.05)); + assert!(options.stop.is_some()); + assert_eq!(options.num_ctx, Some(4096)); + assert_eq!(options.num_predict, Some(128)); + } + + #[test] + fn options_round_trip() { + let json = json!({ + "seed": 42, + "temperature": 0.5, + "top_k": 40, + "num_ctx": 4096 + }); + let options: Options = serde_json::from_value(json.clone()).unwrap(); + assert_eq!(options.seed, Some(42)); + assert_eq!(options.temperature, Some(0.5)); + assert_eq!(options.top_k, Some(40)); + assert_eq!(options.num_ctx, Some(4096)); + assert_eq!(options.top_p, None); + + let reserialized = serde_json::to_value(&options).unwrap(); + assert_eq!(reserialized, json); + } + + #[test] + fn model_details_round_trip() { + let json = json!({ + "format": "gguf", + "family": "llama", + "families": ["llama", "clip"], + "parameter_size": "8B", + "quantization_level": "Q4_0" + }); + let details: ModelDetails = serde_json::from_value(json).unwrap(); + assert_eq!(details.format, "gguf"); + assert_eq!(details.family, "llama"); + assert_eq!( + details.families, + Some(vec!["llama".to_string(), "clip".to_string()]) + ); + assert_eq!(details.parameter_size, "8B"); + assert_eq!(details.quantization_level, "Q4_0"); + } + + #[test] + fn model_details_without_families() { + let json = json!({ + "format": "gguf", + "family": "llama", + "parameter_size": "8B", + "quantization_level": "Q4_0" + }); + let details: ModelDetails = serde_json::from_value(json).unwrap(); + assert_eq!(details.families, None); + } +} diff --git a/src/types/generate.rs b/src/types/generate.rs index a943f5a..2602953 100644 --- a/src/types/generate.rs +++ b/src/types/generate.rs @@ -154,3 +154,122 @@ pub struct GenerateResponse { /// Time spent generating tokens in nanoseconds pub eval_duration: Option, } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn builder_minimal() { + let request = GenerateRequest::builder("llama3").build(); + assert_eq!(request.model, "llama3"); + assert!(request.prompt.is_none()); + assert!(request.system.is_none()); + assert!(request.images.is_empty()); + } + + #[test] + fn builder_with_all_fields() { + let request = GenerateRequest::builder("llama3") + .prompt("hello") + .system_prompt("you are helpful") + .suffix("end".to_string()) + .stream(false) + .images(vec!["base64data".to_string()]) + .format(json!("json")) + .think(Think::Bool(true)) + .options(Options::builder().seed(42).build()) + .build(); + + assert_eq!(request.model, "llama3"); + assert_eq!(request.prompt, Some("hello".to_string())); + assert_eq!(request.system, Some("you are helpful".to_string())); + assert_eq!(request.suffix, Some("end".to_string())); + assert_eq!(request.stream, Some(false)); + assert_eq!(request.images, vec!["base64data".to_string()]); + assert!(request.format.is_some()); + assert!(request.think.is_some()); + assert!(request.options.is_some()); + } + + #[test] + fn request_skips_none_fields() { + let request = GenerateRequest::builder("llama3").prompt("hello").build(); + let json = serde_json::to_value(&request).unwrap(); + let obj = json.as_object().unwrap(); + + assert!(obj.contains_key("model")); + assert!(obj.contains_key("prompt")); + assert!(!obj.contains_key("suffix")); + assert!(!obj.contains_key("system")); + assert!(!obj.contains_key("stream")); + assert!(!obj.contains_key("images")); + assert!(!obj.contains_key("format")); + assert!(!obj.contains_key("think")); + assert!(!obj.contains_key("options")); + } + + #[test] + fn request_includes_images_when_nonempty() { + let request = GenerateRequest::builder("llama3") + .images(vec!["abc".to_string()]) + .build(); + let json = serde_json::to_value(&request).unwrap(); + assert!(json.as_object().unwrap().contains_key("images")); + } + + #[test] + fn response_deserialize_streaming_chunk() { + let json = json!({ + "model": "llama3", + "created_at": "2024-01-01T00:00:00Z", + "response": "Hello", + "done": false + }); + let response: GenerateResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.model, "llama3"); + assert_eq!(response.response, "Hello"); + assert!(!response.done); + assert!(response.done_reason.is_none()); + assert!(response.total_duration.is_none()); + } + + #[test] + fn response_deserialize_final_chunk() { + let json = json!({ + "model": "llama3", + "created_at": "2024-01-01T00:00:00Z", + "response": "", + "done": true, + "done_reason": "stop", + "total_duration": 5000000000u64, + "load_duration": 1000000000u64, + "prompt_eval_count": 10, + "prompt_eval_duration": 500000000u64, + "eval_count": 50, + "eval_duration": 3500000000u64 + }); + let response: GenerateResponse = serde_json::from_value(json).unwrap(); + assert!(response.done); + assert_eq!(response.done_reason, Some("stop".to_string())); + assert_eq!(response.total_duration, Some(5_000_000_000)); + assert_eq!(response.eval_count, Some(50)); + } + + #[test] + fn response_deserialize_with_thinking() { + let json = json!({ + "model": "llama3", + "created_at": "2024-01-01T00:00:00Z", + "response": "The answer is 42.", + "thinking": "Let me think about this...", + "done": true + }); + let response: GenerateResponse = serde_json::from_value(json).unwrap(); + assert_eq!( + response.thinking, + Some("Let me think about this...".to_string()) + ); + } +} diff --git a/src/types/ps.rs b/src/types/ps.rs index b5f735c..fe629c7 100644 --- a/src/types/ps.rs +++ b/src/types/ps.rs @@ -18,3 +18,42 @@ pub struct RunningModel { pub size_vram: u64, pub context_length: u32, } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn deserialize_ps_response() { + let json = json!({ + "models": [{ + "name": "llama3:latest", + "model": "llama3:latest", + "size": 4_700_000_000u64, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "llama", + "parameter_size": "8B", + "quantization_level": "Q4_0" + }, + "expires_at": "2024-01-01T01:00:00Z", + "size_vram": 4_700_000_000u64, + "context_length": 8192 + }] + }); + let response: PsResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.models.len(), 1); + assert_eq!(response.models[0].name, "llama3:latest"); + assert_eq!(response.models[0].size_vram, 4_700_000_000); + assert_eq!(response.models[0].context_length, 8192); + } + + #[test] + fn deserialize_empty_models() { + let json = json!({"models": []}); + let response: PsResponse = serde_json::from_value(json).unwrap(); + assert!(response.models.is_empty()); + } +} diff --git a/src/types/pull.rs b/src/types/pull.rs index e90861a..df732fc 100644 --- a/src/types/pull.rs +++ b/src/types/pull.rs @@ -45,3 +45,51 @@ impl PullRequestBuilder { pub struct PullResponse { pub status: String, } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn builder_minimal() { + let request = PullRequest::builder("llama3").build(); + assert_eq!(request.model, "llama3"); + assert!(request.insecure.is_none()); + assert!(request.stream.is_none()); + } + + #[test] + fn builder_with_options() { + let request = PullRequest::builder("llama3") + .stream(true) + .insecure(false) + .build(); + assert_eq!(request.stream, Some(true)); + assert_eq!(request.insecure, Some(false)); + } + + #[test] + fn request_skips_none_fields() { + let request = PullRequest::builder("llama3").build(); + let json = serde_json::to_value(&request).unwrap(); + let obj = json.as_object().unwrap(); + assert!(obj.contains_key("model")); + assert!(!obj.contains_key("insecure")); + assert!(!obj.contains_key("stream")); + } + + #[test] + fn request_includes_set_fields() { + let request = PullRequest::builder("llama3").stream(true).build(); + let json = serde_json::to_value(&request).unwrap(); + assert_eq!(json, json!({"model": "llama3", "stream": true})); + } + + #[test] + fn response_deserialize() { + let json = json!({"status": "pulling manifest"}); + let response: PullResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.status, "pulling manifest"); + } +} diff --git a/src/types/tags.rs b/src/types/tags.rs index 90e209d..4c861de 100644 --- a/src/types/tags.rs +++ b/src/types/tags.rs @@ -16,3 +16,40 @@ pub struct Model { pub digest: String, pub details: ModelDetails, } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn deserialize_tags_response() { + let json = json!({ + "models": [{ + "name": "llama3:latest", + "model": "llama3:latest", + "modified_at": "2024-01-01T00:00:00Z", + "size": 4_700_000_000u64, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "llama", + "parameter_size": "8B", + "quantization_level": "Q4_0" + } + }] + }); + let response: TagsResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.models.len(), 1); + assert_eq!(response.models[0].name, "llama3:latest"); + assert_eq!(response.models[0].size, 4_700_000_000); + assert_eq!(response.models[0].details.family, "llama"); + } + + #[test] + fn deserialize_empty_models() { + let json = json!({"models": []}); + let response: TagsResponse = serde_json::from_value(json).unwrap(); + assert!(response.models.is_empty()); + } +} diff --git a/src/types/version.rs b/src/types/version.rs index 65730e3..ec92427 100644 --- a/src/types/version.rs +++ b/src/types/version.rs @@ -4,3 +4,25 @@ use serde::{Deserialize, Serialize}; pub struct VersionResponse { pub version: String, } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn deserialize() { + let json = json!({"version": "0.6.2"}); + let response: VersionResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.version, "0.6.2"); + } + + #[test] + fn round_trip() { + let response = VersionResponse { + version: "0.6.2".to_string(), + }; + let json = serde_json::to_value(&response).unwrap(); + assert_eq!(json, json!({"version": "0.6.2"})); + } +}