Adds safety setting

This commit is contained in:
2024-08-15 14:44:59 +01:00
parent 60c20958a4
commit cb954ea5db
3 changed files with 111 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
use std::vec;
use gemini_rs::prelude::*;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let authentication_manager = gcp_auth::provider().await?;
let api_endpoint = std::env::var("API_ENDPOINT")?;
let project_id = std::env::var("PROJECT_ID")?;
let location_id = std::env::var("LOCATION_ID")?;
let gemini = GeminiClient::new(
authentication_manager,
api_endpoint,
project_id,
location_id,
);
let prompt = "What is the airspeed of an unladen swallow?";
let request = GenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
safety_settings: Some(vec![SafetySetting {
category: HarmCategory::HateSpeech,
threshold: HarmBlockThreshold::BlockNone,
method: None,
}]),
..Default::default()
};
println!("{}", serde_json::to_string_pretty(&request).unwrap());
let result = gemini
.generate_content(&request, "gemini-1.0-pro-002")
.await?;
println!("Response: {:?}", result);
Ok(())
}

View File

@@ -159,6 +159,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
generation_config: None,
tools: None,
system_instruction: None,
safety_settings: None,
};
let response = self.generate_content(&request, model).await?;
@@ -187,6 +188,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
generation_config: generation_config.cloned(),
tools: None,
system_instruction: None,
safety_settings: None,
};
let response = self.generate_content(&request, "gemini-pro").await?;

View File

@@ -6,10 +6,16 @@ use super::{Content, Part, VertexApiError};
use crate::error::Result;
#[derive(Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
pub contents: Vec<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
pub generation_config: Option<GenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tools>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_settings: Option<Vec<SafetySetting>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<Content>,
}
@@ -23,14 +29,17 @@ impl GenerateContentRequest {
generation_config,
tools: None,
system_instruction: None,
safety_settings: None,
}
}
}
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct Tools {
#[serde(skip_serializing_if = "Option::is_none")]
pub function_declarations: Option<Vec<FunctionDeclaration>>,
#[serde(rename = "googleSearchRetrieval")]
#[serde(skip_serializing_if = "Option::is_none")]
pub google_search_retrieval: Option<GoogleSearchRetrieval>,
}
@@ -43,21 +52,79 @@ pub struct GoogleSearchRetrieval {
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub candidate_count: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_mime_type: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: HarmCategory,
pub threshold: HarmBlockThreshold,
#[serde(skip_serializing_if = "Option::is_none")]
pub method: Option<HarmBlockMethod>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmCategory {
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
Unspecified,
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
HateSpeech,
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
DangerousContent,
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
Harassment,
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
SexuallyExplicit,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
Unspecified,
#[serde(rename = "BLOCK_LOW_AND_ABOVE")]
BlockLowAndAbove,
#[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
BlockMediumAndAbove,
#[serde(rename = "BLOCK_ONLY_HIGH")]
BlockOnlyHigh,
#[serde(rename = "BLOCK_NONE")]
BlockNone,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmBlockMethod {
#[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")]
Unspecified, // HARM_BLOCK_METHOD_UNSPECIFIED
#[serde(rename = "SEVERITY")]
Severity, // SEVERITY
#[serde(rename = "PROBABILITY")]
Probability, // PROBABILITY
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
pub citation_metadata: Option<CitationMetadata>,
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_ratings: Option<Vec<SafetyRating>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
}