diff --git a/examples/safety-setting.rs b/examples/safety-setting.rs new file mode 100644 index 0000000..f321265 --- /dev/null +++ b/examples/safety-setting.rs @@ -0,0 +1,42 @@ +use std::vec; + +use gemini_rs::prelude::*; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let authentication_manager = gcp_auth::provider().await?; + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let gemini = GeminiClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let prompt = "What is the airspeed of an unladen swallow?"; + + let request = GenerateContentRequest { + contents: vec![Content { + role: Some("user".to_string()), + parts: Some(vec![Part::Text(prompt.to_string())]), + }], + safety_settings: Some(vec![SafetySetting { + category: HarmCategory::HateSpeech, + threshold: HarmBlockThreshold::BlockNone, + method: None, + }]), + ..Default::default() + }; + + println!("{}", serde_json::to_string_pretty(&request).unwrap()); + + let result = gemini + .generate_content(&request, "gemini-1.0-pro-002") + .await?; + println!("Response: {:?}", result); + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index d69d61c..89a57d5 100644 --- a/src/client.rs +++ b/src/client.rs @@ -159,6 +159,7 @@ impl GeminiClient { generation_config: None, tools: None, system_instruction: None, + safety_settings: None, }; let response = self.generate_content(&request, model).await?; @@ -187,6 +188,7 @@ impl GeminiClient { generation_config: generation_config.cloned(), tools: None, system_instruction: None, + safety_settings: None, }; let response = self.generate_content(&request, "gemini-pro").await?; diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index baac2d9..8dd1194 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -6,10 +6,16 @@ use super::{Content, Part, VertexApiError}; use crate::error::Result; #[derive(Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { pub contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] pub generation_config: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub safety_settings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub system_instruction: Option, } @@ -23,14 +29,17 @@ impl GenerateContentRequest { generation_config, tools: None, system_instruction: None, + safety_settings: None, } } } #[derive(Clone, Default, Serialize, Deserialize)] pub struct Tools { + #[serde(skip_serializing_if = "Option::is_none")] pub function_declarations: Option>, #[serde(rename = "googleSearchRetrieval")] + #[serde(skip_serializing_if = "Option::is_none")] pub google_search_retrieval: Option, } @@ -43,21 +52,79 @@ pub struct GoogleSearchRetrieval { #[derive(Clone, Debug, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] pub max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub stop_sequences: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub candidate_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub response_mime_type: Option, } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SafetySetting { + pub category: HarmCategory, + pub threshold: HarmBlockThreshold, + #[serde(skip_serializing_if = "Option::is_none")] + pub method: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum HarmCategory { + #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")] + Unspecified, + #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")] + HateSpeech, + #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")] + DangerousContent, + #[serde(rename = "HARM_CATEGORY_HARASSMENT")] + Harassment, + #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")] + SexuallyExplicit, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum HarmBlockThreshold { + #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] + Unspecified, + #[serde(rename = "BLOCK_LOW_AND_ABOVE")] + BlockLowAndAbove, + #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")] + BlockMediumAndAbove, + #[serde(rename = "BLOCK_ONLY_HIGH")] + BlockOnlyHigh, + #[serde(rename = "BLOCK_NONE")] + BlockNone, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum HarmBlockMethod { + #[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")] + Unspecified, // HARM_BLOCK_METHOD_UNSPECIFIED + #[serde(rename = "SEVERITY")] + Severity, // SEVERITY + #[serde(rename = "PROBABILITY")] + Probability, // PROBABILITY +} + #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Candidate { + #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub citation_metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub safety_ratings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub finish_reason: Option, }