Allow using system instructions

This commit is contained in:
2024-04-19 18:23:17 +01:00
parent 3d298fabbc
commit 5fde27b70d
4 changed files with 69 additions and 3 deletions

View File

@@ -0,0 +1,50 @@
use std::sync::Arc;
use gemini_rs::prelude::*;
use gcp_auth::AuthenticationManager;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt().init();
let authentication_manager = Arc::new(AuthenticationManager::new().await?);
let api_endpoint = std::env::var("API_ENDPOINT")?;
let project_id = std::env::var("PROJECT_ID")?;
let location_id = std::env::var("LOCATION_ID")?;
let gemini = GeminiClient::new(
authentication_manager,
api_endpoint,
project_id,
location_id,
);
let system_instruction = "Answer as if you were Winston Churchill";
let prompt = "What is the airspeed of an unladen swallow?";
let request = GenerateContentRequest {
contents: vec![Content {
role: "user".to_string(),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
system_instruction: Some(Content {
role: "system".to_string(),
parts: Some(vec![Part::Text(system_instruction.to_string())]),
}),
..Default::default()
};
let result = gemini
.generate_content(&request, "gemini-1.0-pro-002")
.await?;
if let GenerateContentResponse::Ok {
candidates,
usage_metadata: _,
} = result
{
println!("Response: {:?}", candidates[0].get_text().unwrap());
}
Ok(())
}

View File

@@ -127,6 +127,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
.collect(),
generation_config: None,
tools: None,
system_instruction: None,
};
let response = self.generate_content(&request, model).await?;
@@ -154,6 +155,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
}],
generation_config: generation_config.cloned(),
tools: None,
system_instruction: None,
};
let response = self.generate_content(&request, "gemini-pro").await?;

View File

@@ -22,10 +22,22 @@ pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
#[serde(rename = "type.googleapis.com/google.rpc.BadRequest")]
BadRequest {
#[serde(rename = "fieldViolations")]
field_violations: Vec<FieldViolation>,
},
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
pub service: String,
pub consumer: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FieldViolation {
pub field: String,
pub description: String,
}

View File

@@ -4,11 +4,12 @@ use serde::{Deserialize, Serialize};
use super::{Content, Error, Part};
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct GenerateContentRequest {
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub tools: Option<Vec<Tools>>,
pub system_instruction: Option<Content>,
}
impl GenerateContentRequest {
@@ -20,6 +21,7 @@ impl GenerateContentRequest {
}],
generation_config,
tools: None,
system_instruction: None,
}
}
}