Adds safety setting
This commit is contained in:
42
examples/safety-setting.rs
Normal file
42
examples/safety-setting.rs
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
use std::vec;
|
||||||
|
|
||||||
|
use gemini_rs::prelude::*;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let authentication_manager = gcp_auth::provider().await?;
|
||||||
|
let api_endpoint = std::env::var("API_ENDPOINT")?;
|
||||||
|
let project_id = std::env::var("PROJECT_ID")?;
|
||||||
|
let location_id = std::env::var("LOCATION_ID")?;
|
||||||
|
|
||||||
|
let gemini = GeminiClient::new(
|
||||||
|
authentication_manager,
|
||||||
|
api_endpoint,
|
||||||
|
project_id,
|
||||||
|
location_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
let prompt = "What is the airspeed of an unladen swallow?";
|
||||||
|
|
||||||
|
let request = GenerateContentRequest {
|
||||||
|
contents: vec![Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
|
}],
|
||||||
|
safety_settings: Some(vec![SafetySetting {
|
||||||
|
category: HarmCategory::HateSpeech,
|
||||||
|
threshold: HarmBlockThreshold::BlockNone,
|
||||||
|
method: None,
|
||||||
|
}]),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("{}", serde_json::to_string_pretty(&request).unwrap());
|
||||||
|
|
||||||
|
let result = gemini
|
||||||
|
.generate_content(&request, "gemini-1.0-pro-002")
|
||||||
|
.await?;
|
||||||
|
println!("Response: {:?}", result);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -159,6 +159,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
generation_config: None,
|
generation_config: None,
|
||||||
tools: None,
|
tools: None,
|
||||||
system_instruction: None,
|
system_instruction: None,
|
||||||
|
safety_settings: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self.generate_content(&request, model).await?;
|
let response = self.generate_content(&request, model).await?;
|
||||||
@@ -187,6 +188,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
generation_config: generation_config.cloned(),
|
generation_config: generation_config.cloned(),
|
||||||
tools: None,
|
tools: None,
|
||||||
system_instruction: None,
|
system_instruction: None,
|
||||||
|
safety_settings: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self.generate_content(&request, "gemini-pro").await?;
|
let response = self.generate_content(&request, "gemini-pro").await?;
|
||||||
|
|||||||
@@ -6,10 +6,16 @@ use super::{Content, Part, VertexApiError};
|
|||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
|
|
||||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct GenerateContentRequest {
|
pub struct GenerateContentRequest {
|
||||||
pub contents: Vec<Content>,
|
pub contents: Vec<Content>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub generation_config: Option<GenerationConfig>,
|
pub generation_config: Option<GenerationConfig>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tools: Option<Vec<Tools>>,
|
pub tools: Option<Vec<Tools>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub safety_settings: Option<Vec<SafetySetting>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub system_instruction: Option<Content>,
|
pub system_instruction: Option<Content>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,14 +29,17 @@ impl GenerateContentRequest {
|
|||||||
generation_config,
|
generation_config,
|
||||||
tools: None,
|
tools: None,
|
||||||
system_instruction: None,
|
system_instruction: None,
|
||||||
|
safety_settings: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||||
pub struct Tools {
|
pub struct Tools {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||||
#[serde(rename = "googleSearchRetrieval")]
|
#[serde(rename = "googleSearchRetrieval")]
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub google_search_retrieval: Option<GoogleSearchRetrieval>,
|
pub google_search_retrieval: Option<GoogleSearchRetrieval>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,21 +52,79 @@ pub struct GoogleSearchRetrieval {
|
|||||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct GenerationConfig {
|
pub struct GenerationConfig {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub max_output_tokens: Option<i32>,
|
pub max_output_tokens: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub top_k: Option<i32>,
|
pub top_k: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stop_sequences: Option<Vec<String>>,
|
pub stop_sequences: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub candidate_count: Option<u32>,
|
pub candidate_count: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub response_mime_type: Option<String>,
|
pub response_mime_type: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SafetySetting {
|
||||||
|
pub category: HarmCategory,
|
||||||
|
pub threshold: HarmBlockThreshold,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub method: Option<HarmBlockMethod>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub enum HarmCategory {
|
||||||
|
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
|
||||||
|
Unspecified,
|
||||||
|
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
|
||||||
|
HateSpeech,
|
||||||
|
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
|
||||||
|
DangerousContent,
|
||||||
|
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
|
||||||
|
Harassment,
|
||||||
|
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
|
||||||
|
SexuallyExplicit,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub enum HarmBlockThreshold {
|
||||||
|
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
|
||||||
|
Unspecified,
|
||||||
|
#[serde(rename = "BLOCK_LOW_AND_ABOVE")]
|
||||||
|
BlockLowAndAbove,
|
||||||
|
#[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
|
||||||
|
BlockMediumAndAbove,
|
||||||
|
#[serde(rename = "BLOCK_ONLY_HIGH")]
|
||||||
|
BlockOnlyHigh,
|
||||||
|
#[serde(rename = "BLOCK_NONE")]
|
||||||
|
BlockNone,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub enum HarmBlockMethod {
|
||||||
|
#[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")]
|
||||||
|
Unspecified, // HARM_BLOCK_METHOD_UNSPECIFIED
|
||||||
|
#[serde(rename = "SEVERITY")]
|
||||||
|
Severity, // SEVERITY
|
||||||
|
#[serde(rename = "PROBABILITY")]
|
||||||
|
Probability, // PROBABILITY
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct Candidate {
|
pub struct Candidate {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub content: Option<Content>,
|
pub content: Option<Content>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub citation_metadata: Option<CitationMetadata>,
|
pub citation_metadata: Option<CitationMetadata>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub safety_ratings: Option<Vec<SafetyRating>>,
|
pub safety_ratings: Option<Vec<SafetyRating>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user