Adds code for image generation
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,3 +7,4 @@ Cargo.lock
|
||||
.cargo/config.toml
|
||||
.DS_Store
|
||||
/.idea
|
||||
output.jpg
|
||||
|
||||
@@ -13,12 +13,14 @@ reqwest = { version = "0.12", features = ["json", "gzip"] }
|
||||
reqwest-eventsource = "0.6"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = { version = "1"}
|
||||
serde_with = { version = "3.9", features = ["base64"]}
|
||||
tracing = "0.1"
|
||||
tokio = { version = "1" }
|
||||
|
||||
[dev-dependencies]
|
||||
console = "0.15.8"
|
||||
dialoguer = "0.11.0"
|
||||
image = "0.25.2"
|
||||
indicatif = "0.17.8"
|
||||
tokio = { version = "1.37.0", features = ["full"] }
|
||||
tracing-subscriber = "0.3.18"
|
||||
|
||||
45
examples/generate_image.rs
Normal file
45
examples/generate_image.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use std::{error::Error, io::Cursor};
|
||||
|
||||
use gemini_rs::prelude::{
|
||||
GeminiClient, PredictImageRequest, PredictImageRequestParameters, PredictImageRequestPrompt,
|
||||
};
|
||||
use image::{ImageFormat, ImageReader};
|
||||
|
||||
#[tokio::main]
|
||||
pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
tracing_subscriber::fmt().init();
|
||||
let authentication_manager = gcp_auth::provider().await?;
|
||||
let api_endpoint = std::env::var("API_ENDPOINT")?;
|
||||
let project_id = std::env::var("PROJECT_ID")?;
|
||||
let location_id = std::env::var("LOCATION_ID")?;
|
||||
|
||||
let gemini = GeminiClient::new(
|
||||
authentication_manager,
|
||||
api_endpoint,
|
||||
project_id,
|
||||
location_id,
|
||||
);
|
||||
|
||||
let prompt = "
|
||||
Create an image of a tuxedo cat riding a rocket to the moon.";
|
||||
let request = PredictImageRequest {
|
||||
instances: vec![PredictImageRequestPrompt {
|
||||
prompt: prompt.to_string(),
|
||||
}],
|
||||
parameters: PredictImageRequestParameters {
|
||||
sample_count: 1,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
let mut result = gemini
|
||||
.predict_image(&request, "imagen-3.0-fast-generate-001")
|
||||
.await?;
|
||||
|
||||
let result = result.predictions.pop().unwrap();
|
||||
|
||||
let format = ImageFormat::from_mime_type(result.mime_type).unwrap();
|
||||
let img =
|
||||
ImageReader::with_format(Cursor::new(result.bytes_base64_encoded), format).decode()?;
|
||||
img.save("output.jpg")?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -11,6 +11,7 @@ use crate::prelude::{
|
||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
||||
TextEmbeddingResponse,
|
||||
};
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
@@ -252,4 +253,27 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
tracing::debug!("count_tokens response: {:?}", txt_json);
|
||||
Ok(serde_json::from_str(&txt_json)?)
|
||||
}
|
||||
|
||||
pub async fn predict_image(
|
||||
&self,
|
||||
request: &PredictImageRequest,
|
||||
model: &str,
|
||||
) -> Result<PredictImageResponse> {
|
||||
let endpoint_url = format!(
|
||||
"https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
|
||||
self.api_endpoint, self.project_id, self.location_id, model,
|
||||
);
|
||||
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let txt_json = resp.text().await?;
|
||||
Ok(serde_json::from_str(&txt_json)?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ mod common;
|
||||
mod count_tokens;
|
||||
mod error;
|
||||
mod generate_content;
|
||||
mod predict_image;
|
||||
mod text_embeddings;
|
||||
|
||||
pub use common::*;
|
||||
pub use count_tokens::*;
|
||||
pub use error::*;
|
||||
pub use generate_content::*;
|
||||
pub use predict_image::*;
|
||||
pub use text_embeddings::*;
|
||||
|
||||
142
src/types/predict_image.rs
Normal file
142
src/types/predict_image.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::base64::Base64;
|
||||
use serde_with::serde_as;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageRequest {
|
||||
pub instances: Vec<PredictImageRequestPrompt>,
|
||||
pub parameters: PredictImageRequestParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageRequestPrompt {
|
||||
/// The text prompt for the image.
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
|
||||
/// - `imagegeneration@006`: up to 128 tokens.
|
||||
/// - `imagegeneration@005`: up to 128 tokens.
|
||||
/// - `imagegeneration@002`: up to 64 tokens.
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageRequestParameters {
|
||||
/// The number of images to generate. The default value is 4.
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: 1 to 4.
|
||||
/// - `imagen-3.0-fast-generate-001`: 1 to 4.
|
||||
/// - `imagegeneration@006`: 1 to 4.
|
||||
/// - `imagegeneration@005`: 1 to 4.
|
||||
/// - `imagegeneration@002`: 1 to 8.
|
||||
pub sample_count: i32,
|
||||
|
||||
/// The random seed for image generation. This is not available when addWatermark is set to
|
||||
/// true.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<u32>,
|
||||
|
||||
/// A description of what to discourage in the generated images.
|
||||
/// The following models support this parameter:
|
||||
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
|
||||
/// - `imagegeneration@006`: up to 128 tokens.
|
||||
/// - `imagegeneration@005`: up to 128 tokens.
|
||||
/// - `imagegeneration@002`: up to 64 tokens.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub negative_prompt: Option<String>,
|
||||
|
||||
/// The aspect ratio for the image. The default value is "1:1".
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagen-3.0-fast-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagegeneration@006`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagegeneration@005`: "1:1" or "9:16".
|
||||
/// - `imagegeneration@002`: "1:1".
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub aspect_ratio: Option<String>,
|
||||
|
||||
/// Describes the output image format in an `PredictImageRequestParametersOutputOptions
|
||||
/// object.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_options: Option<PredictImageRequestParametersOutputOptions>,
|
||||
|
||||
/// Describes the style for the generated images. The following values are supported:
|
||||
/// - "photograph"
|
||||
/// - "digital_art"
|
||||
/// - "landscape"
|
||||
/// - "sketch"
|
||||
/// - "watercolor"
|
||||
/// - "cyberpunk"
|
||||
/// - "pop_art"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sample_image_style: Option<String>,
|
||||
|
||||
/// Allow generation of people by the model. The following values are supported:
|
||||
/// - `"dont_allow"`: Disallow the inclusion of people or faces in images.
|
||||
/// - `"allow_adult"`: Allow generation of adults only.
|
||||
/// - `"allow_all"`: Allow generation of people of all ages.
|
||||
///
|
||||
/// The default value is `"allow_adult"`.
|
||||
///
|
||||
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||
/// `imagegeneration@006` only.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub person_generation: Option<String>,
|
||||
|
||||
/// Adds a filter level to safety filtering. The following values are supported:
|
||||
/// - `"block_most"`: Strongest filtering level, most strict blocking.
|
||||
/// - `"block_some"`: Block some problematic prompts and responses.
|
||||
/// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May
|
||||
/// increase objectionable content generated by Imagen.
|
||||
/// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this
|
||||
/// feature is restricted.
|
||||
///
|
||||
/// The default value is `"block_some"`.
|
||||
///
|
||||
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||
/// `imagegeneration@006` only.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub safety_setting: Option<String>,
|
||||
|
||||
/// Add an invisible watermark to the generated images. The default value is `false` for the
|
||||
/// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the
|
||||
/// `imagen-3.0-fast-generate-001`, `imagegeneration@006`, and imagegeneration@006 models.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub add_watermark: Option<bool>,
|
||||
|
||||
/// Cloud Storage URI to store the generated images.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub storage_uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageRequestParametersOutputOptions {
|
||||
/// The image format that the output should be saved as. The following values are supported:
|
||||
///
|
||||
/// - "image/png": Save as a PNG image
|
||||
/// - "image/jpeg": Save as a JPEG image
|
||||
///
|
||||
/// The default value is "image/png".v
|
||||
pub mime_type: Option<String>,
|
||||
|
||||
/// The level of compression if the output type is "image/jpeg".
|
||||
/// Accepted values are 0 through 100. The default value is 75.
|
||||
pub compression_quality: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageResponse {
|
||||
pub predictions: Vec<PredictImageResponsePrediction>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageResponsePrediction {
|
||||
#[serde_as(as = "Base64")]
|
||||
pub bytes_base64_encoded: Vec<u8>,
|
||||
pub mime_type: String,
|
||||
}
|
||||
Reference in New Issue
Block a user