Skip to content

Commit

Permalink
feat: batch prediction samples (#3884)
Browse files Browse the repository at this point in the history
  • Loading branch information
gryczj authored Sep 27, 2024
1 parent 4b67f18 commit 231f010
Show file tree
Hide file tree
Showing 4 changed files with 367 additions and 0 deletions.
113 changes: 113 additions & 0 deletions ai-platform/snippets/batch-code-predict.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

async function main(projectId, inputUri, outputUri, jobDisplayName) {
// [START generativeaionvertexai_batch_code_predict]
// Imports the aiplatform library
const aiplatformLib = require('@google-cloud/aiplatform');
const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;

/**
* TODO(developer): Uncomment/update these variables before running the sample.
*/
// projectId = 'YOUR_PROJECT_ID';

// Optional: URI of the input dataset.
// Could be a BigQuery table or a Google Cloud Storage file.
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
// inputUri =
// 'gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl';

// Optional: URI where the output will be stored.
// Could be a BigQuery table or a Google Cloud Storage file.
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
// outputUri = 'gs://batch-bucket-testing/batch_code_predict_output';

// The name of batch prediction job
// jobDisplayName = `Batch code prediction job: ${new Date().getMilliseconds()}`;

// The name of pre-trained model
const codeModel = 'code-bison';
const location = 'us-central1';

// Construct your modelParameters
const parameters = {
maxOutputTokens: '200',
temperature: '0.2',
};
const parametersValue = aiplatformLib.helpers.toValue(parameters);
// Configure the parent resource
const parent = `projects/${projectId}/locations/${location}`;
const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${codeModel}`;

// Specifies the location of the api endpoint
const clientOptions = {
apiEndpoint: `${location}-aiplatform.googleapis.com`,
};

// Instantiates a client
const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);

// Perform batch code prediction using a pre-trained code generation model.
// Example of using Google Cloud Storage bucket as the input and output data source
async function callBatchCodePredicton() {
const gcsSource = new aiplatform.GcsSource({
uris: [inputUri],
});

const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
gcsSource,
instancesFormat: 'jsonl',
});

const gcsDestination = new aiplatform.GcsDestination({
outputUriPrefix: outputUri,
});

const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
gcsDestination,
predictionsFormat: 'jsonl',
});

const batchPredictionJob = new aiplatform.BatchPredictionJob({
displayName: jobDisplayName,
model: modelName,
inputConfig,
outputConfig,
modelParameters: parametersValue,
});

const request = {
parent,
batchPredictionJob,
};

// Create batch prediction job request
const [response] = await jobServiceClient.createBatchPredictionJob(request);

console.log('Raw response: ', JSON.stringify(response, null, 2));
}

await callBatchCodePredicton();
// [END generativeaionvertexai_batch_code_predict]
}

main(...process.argv.slice(2)).catch(err => {
console.error(err.message);
process.exitCode = 1;
});
115 changes: 115 additions & 0 deletions ai-platform/snippets/batch-text-predict.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

async function main(projectId, inputUri, outputUri, jobDisplayName) {
// [START generativeaionvertexai_batch_text_predict]
// Imports the aiplatform library
const aiplatformLib = require('@google-cloud/aiplatform');
const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;

/**
* TODO(developer): Uncomment/update these variables before running the sample.
*/
// projectId = 'YOUR_PROJECT_ID';

// Optional: URI of the input dataset.
// Could be a BigQuery table or a Google Cloud Storage file.
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
// inputUri =
// 'gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl';

// Optional: URI where the output will be stored.
// Could be a BigQuery table or a Google Cloud Storage file.
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
// outputUri = 'gs://batch-bucket-testing/batch_text_predict_output';

// The name of batch prediction job
// jobDisplayName = `Batch text prediction job: ${new Date().getMilliseconds()}`;

// The name of pre-trained model
const textModel = 'text-bison';
const location = 'us-central1';

// Construct your modelParameters
const parameters = {
maxOutputTokens: '200',
temperature: '0.2',
topP: '0.95',
topK: '40',
};
const parametersValue = aiplatformLib.helpers.toValue(parameters);
// Configure the parent resource
const parent = `projects/${projectId}/locations/${location}`;
const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${textModel}`;

// Specifies the location of the api endpoint
const clientOptions = {
apiEndpoint: `${location}-aiplatform.googleapis.com`,
};

// Instantiates a client
const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);

// Perform batch text prediction using a pre-trained text generation model.
// Example of using Google Cloud Storage bucket as the input and output data source
async function callBatchTextPredicton() {
const gcsSource = new aiplatform.GcsSource({
uris: [inputUri],
});

const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
gcsSource,
instancesFormat: 'jsonl',
});

const gcsDestination = new aiplatform.GcsDestination({
outputUriPrefix: outputUri,
});

const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
gcsDestination,
predictionsFormat: 'jsonl',
});

const batchPredictionJob = new aiplatform.BatchPredictionJob({
displayName: jobDisplayName,
model: modelName,
inputConfig,
outputConfig,
modelParameters: parametersValue,
});

const request = {
parent,
batchPredictionJob,
};

// Create batch prediction job request
const [response] = await jobServiceClient.createBatchPredictionJob(request);

console.log('Raw response: ', JSON.stringify(response, null, 2));
}

await callBatchTextPredicton();
// [END generativeaionvertexai_batch_text_predict]
}

main(...process.argv.slice(2)).catch(err => {
console.error(err.message);
process.exitCode = 1;
});
70 changes: 70 additions & 0 deletions ai-platform/snippets/test/batch-code-predict.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

const {assert} = require('chai');
const {after, describe, it} = require('mocha');
const uuid = require('uuid').v4;
const cp = require('child_process');
const {JobServiceClient} = require('@google-cloud/aiplatform');

const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});

describe('Batch code predict', async () => {
const displayName = `batch-code-predict-job-${uuid()}`;
const location = 'us-central1';
const inputUri =
'gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl';
const outputUri = 'gs://ucaip-samples-test-output/';
const jobServiceClient = new JobServiceClient({
apiEndpoint: `${location}-aiplatform.googleapis.com`,
});
const projectId = process.env.CAIP_PROJECT_ID;
let batchPredictionJobId;

after(async () => {
const name = jobServiceClient.batchPredictionJobPath(
projectId,
location,
batchPredictionJobId
);

const cancelRequest = {
name,
};

jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => {
const deleteRequest = {
name,
};

return jobServiceClient.deleteBatchPredictionJob(deleteRequest);
});
});

it('should create job with code prediction', async () => {
const response = execSync(
`node ./batch-code-predict.js ${projectId} ${inputUri} ${outputUri} ${displayName}`
);

assert.match(response, new RegExp(displayName));

batchPredictionJobId = response
.split('/locations/us-central1/batchPredictionJobs/')[1]
.split('\n')[0];
});
});
69 changes: 69 additions & 0 deletions ai-platform/snippets/test/batch-text-predict.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

const {assert} = require('chai');
const {after, describe, it} = require('mocha');
const uuid = require('uuid').v4;
const cp = require('child_process');
const {JobServiceClient} = require('@google-cloud/aiplatform');

const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});

describe('Batch text predict', async () => {
const displayName = `batch-text-predict-job-${uuid()}`;
const location = 'us-central1';
const inputUri =
'gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl';
const outputUri = 'gs://ucaip-samples-test-output/';
const jobServiceClient = new JobServiceClient({
apiEndpoint: `${location}-aiplatform.googleapis.com`,
});
const projectId = process.env.CAIP_PROJECT_ID;
let batchPredictionJobId;

after(async () => {
const name = jobServiceClient.batchPredictionJobPath(
projectId,
location,
batchPredictionJobId
);

const cancelRequest = {
name,
};

jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => {
const deleteRequest = {
name,
};

return jobServiceClient.deleteBatchPredictionJob(deleteRequest);
});
});

it('should create job with text prediction', async () => {
const response = execSync(
`node ./batch-text-predict.js ${projectId} ${inputUri} ${outputUri} ${displayName}`
);

assert.match(response, new RegExp(displayName));
batchPredictionJobId = response
.split('/locations/us-central1/batchPredictionJobs/')[1]
.split('\n')[0];
});
});

0 comments on commit 231f010

Please sign in to comment.