From c4e9f78ec9ddb035644c6c78e2957f9c62bc977c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Sun, 9 Jun 2024 10:07:00 +0100 Subject: [PATCH] Make content role optional --- examples/count-tokens.rs | 2 +- examples/system_instruction.rs | 6 +++--- src/client.rs | 4 ++-- src/types/common.rs | 2 +- src/types/generate_content.rs | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/count-tokens.rs b/examples/count-tokens.rs index d9d8d17..7541a46 100644 --- a/examples/count-tokens.rs +++ b/examples/count-tokens.rs @@ -17,7 +17,7 @@ async fn main() -> Result<(), Box> { let prompt = "What is the airspeed of an unladen swallow?"; let request = CountTokensRequest { contents: Content { - role: "user".to_string(), + role: Some("user".to_string()), parts: Some(vec![Part::Text(prompt.to_string())]), }, }; diff --git a/examples/system_instruction.rs b/examples/system_instruction.rs index ec603f3..4def033 100644 --- a/examples/system_instruction.rs +++ b/examples/system_instruction.rs @@ -15,16 +15,16 @@ async fn main() -> Result<(), Box> { location_id, ); - let system_instruction = "Answer as if you were Winston Churchill"; + let system_instruction = "Answer as if you were Yoda"; let prompt = "What is the airspeed of an unladen swallow?"; let request = GenerateContentRequest { contents: vec![Content { - role: "user".to_string(), + role: Some("user".to_string()), parts: Some(vec![Part::Text(prompt.to_string())]), }], system_instruction: Some(Content { - role: "system".to_string(), + role: None, parts: Some(vec![Part::Text(system_instruction.to_string())]), }), ..Default::default() diff --git a/src/client.rs b/src/client.rs index 58c05ef..ebd6f0a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -152,7 +152,7 @@ impl GeminiClient { contents: messages .iter() .map(|m| Content { - role: m.role.to_string(), + role: Some(m.role.to_string()), parts: Some(vec![Part::Text(m.text.clone())]), }) .collect(), @@ -181,7 +181,7 @@ impl GeminiClient { ) -> Result { let request = GenerateContentRequest { contents: vec![Content { - role: "user".to_string(), + role: Some("user".to_string()), parts: Some(vec![Part::Text(prompt.to_string())]), }], generation_config: generation_config.cloned(), diff --git a/src/types/common.rs b/src/types/common.rs index 1d6a43b..7a4d0dc 100644 --- a/src/types/common.rs +++ b/src/types/common.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Content { - pub role: String, + pub role: Option, pub parts: Option>, } diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index a90a7ea..3b20214 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -17,7 +17,7 @@ impl GenerateContentRequest { pub fn from_prompt(prompt: &str, generation_config: Option) -> Self { GenerateContentRequest { contents: vec![Content { - role: "user".to_string(), + role: Some("user".to_string()), parts: Some(vec![Part::Text(prompt.to_string())]), }], generation_config,