diff --git a/examples/system_instruction.rs b/examples/system_instruction.rs new file mode 100644 index 0000000..8bde470 --- /dev/null +++ b/examples/system_instruction.rs @@ -0,0 +1,50 @@ +use std::sync::Arc; + +use gemini_rs::prelude::*; + +use gcp_auth::AuthenticationManager; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt().init(); + let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let gemini = GeminiClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let system_instruction = "Answer as if you were Winston Churchill"; + let prompt = "What is the airspeed of an unladen swallow?"; + + let request = GenerateContentRequest { + contents: vec![Content { + role: "user".to_string(), + parts: Some(vec![Part::Text(prompt.to_string())]), + }], + system_instruction: Some(Content { + role: "system".to_string(), + parts: Some(vec![Part::Text(system_instruction.to_string())]), + }), + ..Default::default() + }; + + let result = gemini + .generate_content(&request, "gemini-1.0-pro-002") + .await?; + + if let GenerateContentResponse::Ok { + candidates, + usage_metadata: _, + } = result + { + println!("Response: {:?}", candidates[0].get_text().unwrap()); + } + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index 4728e0d..abe7dff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -127,6 +127,7 @@ impl GeminiClient { .collect(), generation_config: None, tools: None, + system_instruction: None, }; let response = self.generate_content(&request, model).await?; @@ -154,6 +155,7 @@ impl GeminiClient { }], generation_config: generation_config.cloned(), tools: None, + system_instruction: None, }; let response = self.generate_content(&request, "gemini-pro").await?; diff --git a/src/types/error.rs b/src/types/error.rs index 25c40ec..85baef5 100644 --- a/src/types/error.rs +++ b/src/types/error.rs @@ -22,10 +22,22 @@ pub enum ErrorType { #[serde(rename = "type.googleapis.com/google.rpc.Help")] Help { links: Vec }, + + #[serde(rename = "type.googleapis.com/google.rpc.BadRequest")] + BadRequest { + #[serde(rename = "fieldViolations")] + field_violations: Vec, + }, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ErrorInfoMetadata { - service: String, - consumer: String, + pub service: String, + pub consumer: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FieldViolation { + pub field: String, + pub description: String, } diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index 527065d..67aea7c 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -4,11 +4,12 @@ use serde::{Deserialize, Serialize}; use super::{Content, Error, Part}; -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Default, Serialize, Deserialize)] pub struct GenerateContentRequest { pub contents: Vec, pub generation_config: Option, pub tools: Option>, + pub system_instruction: Option, } impl GenerateContentRequest { @@ -20,6 +21,7 @@ impl GenerateContentRequest { }], generation_config, tools: None, + system_instruction: None, } } }